@@ -39,6 +39,11 @@ namespace Tensorflow | |||||
return buffer._handle; | return buffer._handle; | ||||
} | } | ||||
public static implicit operator byte[](Buffer buffer) | |||||
{ | |||||
return buffer.Data; | |||||
} | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
c_api.TF_DeleteBuffer(_handle); | c_api.TF_DeleteBuffer(_handle); | ||||
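For context, a minimal usage sketch of the new implicit `byte[]` conversion; it mirrors the `c_test_util.GetAttrValue` change further down, and the `oper`, `attr_name` and `s` arguments are assumed to come from the existing wrappers:

```csharp
// Sketch only: exercise the implicit Buffer -> byte[] operator.
var buffer = new Buffer();
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
byte[] bytes = buffer;                                // implicit conversion to byte[]
var attr_value = AttrValue.Parser.ParseFrom(buffer);  // overload resolution applies the same conversion
buffer.Dispose();                                     // release the native TF_Buffer
```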
@@ -38,6 +38,16 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | ||||
/// <summary> | |||||
/// Write out a serialized representation of `graph` (as a GraphDef protocol | |||||
/// message) to `output_graph_def` (allocated by TF_NewBuffer()). | |||||
/// </summary> | |||||
/// <param name="graph"></param> | |||||
/// <param name="output_graph_def"></param> | |||||
/// <param name="status"></param> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TF_GraphToGraphDef(IntPtr graph, IntPtr output_graph_def, IntPtr status); | |||||
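A hedged usage sketch of this import (the `GetGraphDef` helper added to `c_test_util` below does the same thing; `graph` is assumed to be an existing `Graph` instance):

```csharp
// Sketch, assuming the Status/Buffer wrappers in this repository.
var s = new Status();
var buffer = new Buffer();
c_api.TF_GraphToGraphDef(graph, buffer, s);          // serialize the graph into the buffer
s.Check();                                           // verify TF_OK
var graph_def = GraphDef.Parser.ParseFrom(buffer);   // implicit Buffer -> byte[]
buffer.Dispose();
```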
/// <summary> | /// <summary> | ||||
/// Returns the number of dimensions of the Tensor referenced by `output` | /// Returns the number of dimensions of the Tensor referenced by `output` | ||||
/// in `graph`. | /// in `graph`. | ||||
@@ -26,15 +26,15 @@ namespace Tensorflow | |||||
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); | public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); | ||||
public int NumInputs => c_api.TF_OperationNumInputs(_handle); | public int NumInputs => c_api.TF_OperationNumInputs(_handle); | ||||
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | ||||
public TF_Input[] OutputConsumers(int index, int max_consumers) | |||||
public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | |||||
{ | { | ||||
IntPtr handle = IntPtr.Zero; | |||||
int size = Marshal.SizeOf<TF_Input>(); | int size = Marshal.SizeOf<TF_Input>(); | ||||
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), ref handle, max_consumers); | |||||
var handle = (TF_Input*)Marshal.AllocHGlobal(size * max_consumers);
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | |||||
var consumers = new TF_Input[num]; | var consumers = new TF_Input[num]; | ||||
for(int i = 0; i < num; i++) | for(int i = 0; i < num; i++) | ||||
{ | { | ||||
consumers[0] = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||||
consumers[i] = handle[i];
} | } | ||||
return consumers; | return consumers; | ||||
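A short usage sketch of the reworked method; the `oper`/`index` field names come from the unit test below, and `op` stands in for any existing `Operation`:

```csharp
// Sketch: enumerate the consumers of output 0 of an operation.
TF_Input[] consumers = op.OutputConsumers(0, max_consumers: 8);
Console.WriteLine($"output 0 has {consumers.Length} consumer(s)");
foreach (var c in consumers)
    Console.WriteLine($"  feeds input #{c.index} of consumer {c.oper}");
```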
@@ -112,7 +112,7 @@ namespace Tensorflow | |||||
/// <param name="max_consumers"></param> | /// <param name="max_consumers"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern int TF_OperationOutputConsumers(TF_Output oper_out, ref IntPtr consumers, int max_consumers); | |||||
public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input * consumers, int max_consumers); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | ||||
@@ -0,0 +1,604 @@ | |||||
// <auto-generated> | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||||
// source: function.proto | |||||
// </auto-generated> | |||||
#pragma warning disable 1591, 0612, 3021 | |||||
#region Designer generated code | |||||
using pb = global::Google.Protobuf; | |||||
using pbc = global::Google.Protobuf.Collections; | |||||
using pbr = global::Google.Protobuf.Reflection; | |||||
using scg = global::System.Collections.Generic; | |||||
namespace Tensorflow { | |||||
/// <summary>Holder for reflection information generated from function.proto</summary> | |||||
public static partial class FunctionReflection { | |||||
#region Descriptor | |||||
/// <summary>File descriptor for function.proto</summary> | |||||
public static pbr::FileDescriptor Descriptor { | |||||
get { return descriptor; } | |||||
} | |||||
private static pbr::FileDescriptor descriptor; | |||||
static FunctionReflection() { | |||||
byte[] descriptorData = global::System.Convert.FromBase64String( | |||||
string.Concat( | |||||
"Cg5mdW5jdGlvbi5wcm90bxIKdGVuc29yZmxvdxoQYXR0cl92YWx1ZS5wcm90", | |||||
"bxoObm9kZV9kZWYucHJvdG8aDG9wX2RlZi5wcm90byJqChJGdW5jdGlvbkRl", | |||||
"ZkxpYnJhcnkSKQoIZnVuY3Rpb24YASADKAsyFy50ZW5zb3JmbG93LkZ1bmN0", | |||||
"aW9uRGVmEikKCGdyYWRpZW50GAIgAygLMhcudGVuc29yZmxvdy5HcmFkaWVu", | |||||
"dERlZiKwAgoLRnVuY3Rpb25EZWYSJAoJc2lnbmF0dXJlGAEgASgLMhEudGVu", | |||||
"c29yZmxvdy5PcERlZhIvCgRhdHRyGAUgAygLMiEudGVuc29yZmxvdy5GdW5j", | |||||
"dGlvbkRlZi5BdHRyRW50cnkSJQoIbm9kZV9kZWYYAyADKAsyEy50ZW5zb3Jm", | |||||
"bG93Lk5vZGVEZWYSLQoDcmV0GAQgAygLMiAudGVuc29yZmxvdy5GdW5jdGlv", | |||||
"bkRlZi5SZXRFbnRyeRpCCglBdHRyRW50cnkSCwoDa2V5GAEgASgJEiQKBXZh", | |||||
"bHVlGAIgASgLMhUudGVuc29yZmxvdy5BdHRyVmFsdWU6AjgBGioKCFJldEVu", | |||||
"dHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgCIAEoCToCOAFKBAgCEAMiOwoL", | |||||
"R3JhZGllbnREZWYSFQoNZnVuY3Rpb25fbmFtZRgBIAEoCRIVCg1ncmFkaWVu", | |||||
"dF9mdW5jGAIgASgJQm4KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IORnVu", | |||||
"Y3Rpb25Qcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZs", | |||||
"b3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z")); | |||||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||||
new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.OpDefReflection.Descriptor, }, | |||||
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDefLibrary), global::Tensorflow.FunctionDefLibrary.Parser, new[]{ "Function", "Gradient" }, null, null, null), | |||||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDef), global::Tensorflow.FunctionDef.Parser, new[]{ "Signature", "Attr", "NodeDef", "Ret" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, null, }), | |||||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GradientDef), global::Tensorflow.GradientDef.Parser, new[]{ "FunctionName", "GradientFunc" }, null, null, null) | |||||
})); | |||||
} | |||||
#endregion | |||||
} | |||||
#region Messages | |||||
/// <summary> | |||||
/// A library is a set of named functions. | |||||
/// </summary> | |||||
public sealed partial class FunctionDefLibrary : pb::IMessage<FunctionDefLibrary> { | |||||
private static readonly pb::MessageParser<FunctionDefLibrary> _parser = new pb::MessageParser<FunctionDefLibrary>(() => new FunctionDefLibrary()); | |||||
private pb::UnknownFieldSet _unknownFields; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public static pb::MessageParser<FunctionDefLibrary> Parser { get { return _parser; } } | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public static pbr::MessageDescriptor Descriptor { | |||||
get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[0]; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||||
get { return Descriptor; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public FunctionDefLibrary() { | |||||
OnConstruction(); | |||||
} | |||||
partial void OnConstruction(); | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public FunctionDefLibrary(FunctionDefLibrary other) : this() { | |||||
function_ = other.function_.Clone(); | |||||
gradient_ = other.gradient_.Clone(); | |||||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public FunctionDefLibrary Clone() { | |||||
return new FunctionDefLibrary(this); | |||||
} | |||||
/// <summary>Field number for the "function" field.</summary> | |||||
public const int FunctionFieldNumber = 1; | |||||
private static readonly pb::FieldCodec<global::Tensorflow.FunctionDef> _repeated_function_codec | |||||
= pb::FieldCodec.ForMessage(10, global::Tensorflow.FunctionDef.Parser); | |||||
private readonly pbc::RepeatedField<global::Tensorflow.FunctionDef> function_ = new pbc::RepeatedField<global::Tensorflow.FunctionDef>(); | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public pbc::RepeatedField<global::Tensorflow.FunctionDef> Function { | |||||
get { return function_; } | |||||
} | |||||
/// <summary>Field number for the "gradient" field.</summary> | |||||
public const int GradientFieldNumber = 2; | |||||
private static readonly pb::FieldCodec<global::Tensorflow.GradientDef> _repeated_gradient_codec | |||||
= pb::FieldCodec.ForMessage(18, global::Tensorflow.GradientDef.Parser); | |||||
private readonly pbc::RepeatedField<global::Tensorflow.GradientDef> gradient_ = new pbc::RepeatedField<global::Tensorflow.GradientDef>(); | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public pbc::RepeatedField<global::Tensorflow.GradientDef> Gradient { | |||||
get { return gradient_; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override bool Equals(object other) { | |||||
return Equals(other as FunctionDefLibrary); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public bool Equals(FunctionDefLibrary other) { | |||||
if (ReferenceEquals(other, null)) { | |||||
return false; | |||||
} | |||||
if (ReferenceEquals(other, this)) { | |||||
return true; | |||||
} | |||||
if(!function_.Equals(other.function_)) return false; | |||||
if(!gradient_.Equals(other.gradient_)) return false; | |||||
return Equals(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override int GetHashCode() { | |||||
int hash = 1; | |||||
hash ^= function_.GetHashCode(); | |||||
hash ^= gradient_.GetHashCode(); | |||||
if (_unknownFields != null) { | |||||
hash ^= _unknownFields.GetHashCode(); | |||||
} | |||||
return hash; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override string ToString() { | |||||
return pb::JsonFormatter.ToDiagnosticString(this); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void WriteTo(pb::CodedOutputStream output) { | |||||
function_.WriteTo(output, _repeated_function_codec); | |||||
gradient_.WriteTo(output, _repeated_gradient_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(output); | |||||
} | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public int CalculateSize() { | |||||
int size = 0; | |||||
size += function_.CalculateSize(_repeated_function_codec); | |||||
size += gradient_.CalculateSize(_repeated_gradient_codec); | |||||
if (_unknownFields != null) { | |||||
size += _unknownFields.CalculateSize(); | |||||
} | |||||
return size; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void MergeFrom(FunctionDefLibrary other) { | |||||
if (other == null) { | |||||
return; | |||||
} | |||||
function_.Add(other.function_); | |||||
gradient_.Add(other.gradient_); | |||||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void MergeFrom(pb::CodedInputStream input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||||
break; | |||||
case 10: { | |||||
function_.AddEntriesFrom(input, _repeated_function_codec); | |||||
break; | |||||
} | |||||
case 18: { | |||||
gradient_.AddEntriesFrom(input, _repeated_gradient_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// A function can be instantiated when the runtime can bind every attr | |||||
/// with a value. When a GraphDef has a call to a function, it must | |||||
/// have binding for every attr defined in the signature. | |||||
/// | |||||
/// TODO(zhifengc): | |||||
/// * device spec, etc. | |||||
/// </summary> | |||||
public sealed partial class FunctionDef : pb::IMessage<FunctionDef> { | |||||
private static readonly pb::MessageParser<FunctionDef> _parser = new pb::MessageParser<FunctionDef>(() => new FunctionDef()); | |||||
private pb::UnknownFieldSet _unknownFields; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public static pb::MessageParser<FunctionDef> Parser { get { return _parser; } } | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public static pbr::MessageDescriptor Descriptor { | |||||
get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[1]; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||||
get { return Descriptor; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public FunctionDef() { | |||||
OnConstruction(); | |||||
} | |||||
partial void OnConstruction(); | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public FunctionDef(FunctionDef other) : this() { | |||||
signature_ = other.signature_ != null ? other.signature_.Clone() : null; | |||||
attr_ = other.attr_.Clone(); | |||||
nodeDef_ = other.nodeDef_.Clone(); | |||||
ret_ = other.ret_.Clone(); | |||||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public FunctionDef Clone() { | |||||
return new FunctionDef(this); | |||||
} | |||||
/// <summary>Field number for the "signature" field.</summary> | |||||
public const int SignatureFieldNumber = 1; | |||||
private global::Tensorflow.OpDef signature_; | |||||
/// <summary> | |||||
/// The definition of the function's name, arguments, return values, | |||||
/// attrs etc. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public global::Tensorflow.OpDef Signature { | |||||
get { return signature_; } | |||||
set { | |||||
signature_ = value; | |||||
} | |||||
} | |||||
/// <summary>Field number for the "attr" field.</summary> | |||||
public const int AttrFieldNumber = 5; | |||||
private static readonly pbc::MapField<string, global::Tensorflow.AttrValue>.Codec _map_attr_codec | |||||
= new pbc::MapField<string, global::Tensorflow.AttrValue>.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 42); | |||||
private readonly pbc::MapField<string, global::Tensorflow.AttrValue> attr_ = new pbc::MapField<string, global::Tensorflow.AttrValue>(); | |||||
/// <summary> | |||||
/// Attributes specific to this function definition. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public pbc::MapField<string, global::Tensorflow.AttrValue> Attr { | |||||
get { return attr_; } | |||||
} | |||||
/// <summary>Field number for the "node_def" field.</summary> | |||||
public const int NodeDefFieldNumber = 3; | |||||
private static readonly pb::FieldCodec<global::Tensorflow.NodeDef> _repeated_nodeDef_codec | |||||
= pb::FieldCodec.ForMessage(26, global::Tensorflow.NodeDef.Parser); | |||||
private readonly pbc::RepeatedField<global::Tensorflow.NodeDef> nodeDef_ = new pbc::RepeatedField<global::Tensorflow.NodeDef>(); | |||||
/// <summary> | |||||
/// By convention, "op" in node_def is resolved by consulting with a | |||||
/// user-defined library first. If not resolved, "func" is assumed to | |||||
/// be a builtin op. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public pbc::RepeatedField<global::Tensorflow.NodeDef> NodeDef { | |||||
get { return nodeDef_; } | |||||
} | |||||
/// <summary>Field number for the "ret" field.</summary> | |||||
public const int RetFieldNumber = 4; | |||||
private static readonly pbc::MapField<string, string>.Codec _map_ret_codec | |||||
= new pbc::MapField<string, string>.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForString(18), 34); | |||||
private readonly pbc::MapField<string, string> ret_ = new pbc::MapField<string, string>(); | |||||
/// <summary> | |||||
/// A mapping from the output arg names from `signature` to the | |||||
/// outputs from `node_def` that should be returned by the function. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public pbc::MapField<string, string> Ret { | |||||
get { return ret_; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override bool Equals(object other) { | |||||
return Equals(other as FunctionDef); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public bool Equals(FunctionDef other) { | |||||
if (ReferenceEquals(other, null)) { | |||||
return false; | |||||
} | |||||
if (ReferenceEquals(other, this)) { | |||||
return true; | |||||
} | |||||
if (!object.Equals(Signature, other.Signature)) return false; | |||||
if (!Attr.Equals(other.Attr)) return false; | |||||
if(!nodeDef_.Equals(other.nodeDef_)) return false; | |||||
if (!Ret.Equals(other.Ret)) return false; | |||||
return Equals(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override int GetHashCode() { | |||||
int hash = 1; | |||||
if (signature_ != null) hash ^= Signature.GetHashCode(); | |||||
hash ^= Attr.GetHashCode(); | |||||
hash ^= nodeDef_.GetHashCode(); | |||||
hash ^= Ret.GetHashCode(); | |||||
if (_unknownFields != null) { | |||||
hash ^= _unknownFields.GetHashCode(); | |||||
} | |||||
return hash; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override string ToString() { | |||||
return pb::JsonFormatter.ToDiagnosticString(this); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void WriteTo(pb::CodedOutputStream output) { | |||||
if (signature_ != null) { | |||||
output.WriteRawTag(10); | |||||
output.WriteMessage(Signature); | |||||
} | |||||
nodeDef_.WriteTo(output, _repeated_nodeDef_codec); | |||||
ret_.WriteTo(output, _map_ret_codec); | |||||
attr_.WriteTo(output, _map_attr_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(output); | |||||
} | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public int CalculateSize() { | |||||
int size = 0; | |||||
if (signature_ != null) { | |||||
size += 1 + pb::CodedOutputStream.ComputeMessageSize(Signature); | |||||
} | |||||
size += attr_.CalculateSize(_map_attr_codec); | |||||
size += nodeDef_.CalculateSize(_repeated_nodeDef_codec); | |||||
size += ret_.CalculateSize(_map_ret_codec); | |||||
if (_unknownFields != null) { | |||||
size += _unknownFields.CalculateSize(); | |||||
} | |||||
return size; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void MergeFrom(FunctionDef other) { | |||||
if (other == null) { | |||||
return; | |||||
} | |||||
if (other.signature_ != null) { | |||||
if (signature_ == null) { | |||||
signature_ = new global::Tensorflow.OpDef(); | |||||
} | |||||
Signature.MergeFrom(other.Signature); | |||||
} | |||||
attr_.Add(other.attr_); | |||||
nodeDef_.Add(other.nodeDef_); | |||||
ret_.Add(other.ret_); | |||||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void MergeFrom(pb::CodedInputStream input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||||
break; | |||||
case 10: { | |||||
if (signature_ == null) { | |||||
signature_ = new global::Tensorflow.OpDef(); | |||||
} | |||||
input.ReadMessage(signature_); | |||||
break; | |||||
} | |||||
case 26: { | |||||
nodeDef_.AddEntriesFrom(input, _repeated_nodeDef_codec); | |||||
break; | |||||
} | |||||
case 34: { | |||||
ret_.AddEntriesFrom(input, _map_ret_codec); | |||||
break; | |||||
} | |||||
case 42: { | |||||
attr_.AddEntriesFrom(input, _map_attr_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// GradientDef defines the gradient function of a function defined in | |||||
/// a function library. | |||||
/// | |||||
/// A gradient function g (specified by gradient_func) for a function f | |||||
/// (specified by function_name) must follow the following: | |||||
/// | |||||
/// The function 'f' must be a numerical function which takes N inputs | |||||
/// and produces M outputs. Its gradient function 'g', which is a | |||||
/// function taking N + M inputs and produces N outputs. | |||||
/// | |||||
/// I.e. if we have | |||||
/// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), | |||||
/// then, g is | |||||
/// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, | |||||
/// dL/dy1, dL/dy2, ..., dL/dy_M), | |||||
/// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the | |||||
/// loss function). dL/dx_i is the partial derivative of L with respect | |||||
/// to x_i. | |||||
/// </summary> | |||||
public sealed partial class GradientDef : pb::IMessage<GradientDef> { | |||||
private static readonly pb::MessageParser<GradientDef> _parser = new pb::MessageParser<GradientDef>(() => new GradientDef()); | |||||
private pb::UnknownFieldSet _unknownFields; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public static pb::MessageParser<GradientDef> Parser { get { return _parser; } } | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public static pbr::MessageDescriptor Descriptor { | |||||
get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[2]; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||||
get { return Descriptor; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public GradientDef() { | |||||
OnConstruction(); | |||||
} | |||||
partial void OnConstruction(); | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public GradientDef(GradientDef other) : this() { | |||||
functionName_ = other.functionName_; | |||||
gradientFunc_ = other.gradientFunc_; | |||||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public GradientDef Clone() { | |||||
return new GradientDef(this); | |||||
} | |||||
/// <summary>Field number for the "function_name" field.</summary> | |||||
public const int FunctionNameFieldNumber = 1; | |||||
private string functionName_ = ""; | |||||
/// <summary> | |||||
/// The function name. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public string FunctionName { | |||||
get { return functionName_; } | |||||
set { | |||||
functionName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||||
} | |||||
} | |||||
/// <summary>Field number for the "gradient_func" field.</summary> | |||||
public const int GradientFuncFieldNumber = 2; | |||||
private string gradientFunc_ = ""; | |||||
/// <summary> | |||||
/// The gradient function's name. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public string GradientFunc { | |||||
get { return gradientFunc_; } | |||||
set { | |||||
gradientFunc_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||||
} | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override bool Equals(object other) { | |||||
return Equals(other as GradientDef); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public bool Equals(GradientDef other) { | |||||
if (ReferenceEquals(other, null)) { | |||||
return false; | |||||
} | |||||
if (ReferenceEquals(other, this)) { | |||||
return true; | |||||
} | |||||
if (FunctionName != other.FunctionName) return false; | |||||
if (GradientFunc != other.GradientFunc) return false; | |||||
return Equals(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override int GetHashCode() { | |||||
int hash = 1; | |||||
if (FunctionName.Length != 0) hash ^= FunctionName.GetHashCode(); | |||||
if (GradientFunc.Length != 0) hash ^= GradientFunc.GetHashCode(); | |||||
if (_unknownFields != null) { | |||||
hash ^= _unknownFields.GetHashCode(); | |||||
} | |||||
return hash; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override string ToString() { | |||||
return pb::JsonFormatter.ToDiagnosticString(this); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void WriteTo(pb::CodedOutputStream output) { | |||||
if (FunctionName.Length != 0) { | |||||
output.WriteRawTag(10); | |||||
output.WriteString(FunctionName); | |||||
} | |||||
if (GradientFunc.Length != 0) { | |||||
output.WriteRawTag(18); | |||||
output.WriteString(GradientFunc); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(output); | |||||
} | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public int CalculateSize() { | |||||
int size = 0; | |||||
if (FunctionName.Length != 0) { | |||||
size += 1 + pb::CodedOutputStream.ComputeStringSize(FunctionName); | |||||
} | |||||
if (GradientFunc.Length != 0) { | |||||
size += 1 + pb::CodedOutputStream.ComputeStringSize(GradientFunc); | |||||
} | |||||
if (_unknownFields != null) { | |||||
size += _unknownFields.CalculateSize(); | |||||
} | |||||
return size; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void MergeFrom(GradientDef other) { | |||||
if (other == null) { | |||||
return; | |||||
} | |||||
if (other.FunctionName.Length != 0) { | |||||
FunctionName = other.FunctionName; | |||||
} | |||||
if (other.GradientFunc.Length != 0) { | |||||
GradientFunc = other.GradientFunc; | |||||
} | |||||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void MergeFrom(pb::CodedInputStream input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||||
break; | |||||
case 10: { | |||||
FunctionName = input.ReadString(); | |||||
break; | |||||
} | |||||
case 18: { | |||||
GradientFunc = input.ReadString(); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endregion | |||||
} | |||||
#endregion Designer generated code |
@@ -0,0 +1,309 @@ | |||||
// <auto-generated> | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||||
// source: graph.proto | |||||
// </auto-generated> | |||||
#pragma warning disable 1591, 0612, 3021 | |||||
#region Designer generated code | |||||
using pb = global::Google.Protobuf; | |||||
using pbc = global::Google.Protobuf.Collections; | |||||
using pbr = global::Google.Protobuf.Reflection; | |||||
using scg = global::System.Collections.Generic; | |||||
namespace Tensorflow { | |||||
/// <summary>Holder for reflection information generated from graph.proto</summary> | |||||
public static partial class GraphReflection { | |||||
#region Descriptor | |||||
/// <summary>File descriptor for graph.proto</summary> | |||||
public static pbr::FileDescriptor Descriptor { | |||||
get { return descriptor; } | |||||
} | |||||
private static pbr::FileDescriptor descriptor; | |||||
static GraphReflection() { | |||||
byte[] descriptorData = global::System.Convert.FromBase64String( | |||||
string.Concat( | |||||
"CgtncmFwaC5wcm90bxIKdGVuc29yZmxvdxoObm9kZV9kZWYucHJvdG8aDmZ1", | |||||
"bmN0aW9uLnByb3RvGg52ZXJzaW9ucy5wcm90byKdAQoIR3JhcGhEZWYSIQoE", | |||||
"bm9kZRgBIAMoCzITLnRlbnNvcmZsb3cuTm9kZURlZhIoCgh2ZXJzaW9ucxgE", | |||||
"IAEoCzIWLnRlbnNvcmZsb3cuVmVyc2lvbkRlZhITCgd2ZXJzaW9uGAMgASgF", | |||||
"QgIYARIvCgdsaWJyYXJ5GAIgASgLMh4udGVuc29yZmxvdy5GdW5jdGlvbkRl", | |||||
"ZkxpYnJhcnlCawoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQgtHcmFwaFBy", | |||||
"b3Rvc1ABWj1naXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5z", | |||||
"b3JmbG93L2dvL2NvcmUvZnJhbWV3b3Jr+AEBYgZwcm90bzM=")); | |||||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||||
new pbr::FileDescriptor[] { global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.FunctionReflection.Descriptor, global::Tensorflow.VersionsReflection.Descriptor, }, | |||||
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphDef), global::Tensorflow.GraphDef.Parser, new[]{ "Node", "Versions", "Version", "Library" }, null, null, null) | |||||
})); | |||||
} | |||||
#endregion | |||||
} | |||||
#region Messages | |||||
/// <summary> | |||||
/// Represents the graph of operations | |||||
/// </summary> | |||||
public sealed partial class GraphDef : pb::IMessage<GraphDef> { | |||||
private static readonly pb::MessageParser<GraphDef> _parser = new pb::MessageParser<GraphDef>(() => new GraphDef()); | |||||
private pb::UnknownFieldSet _unknownFields; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public static pb::MessageParser<GraphDef> Parser { get { return _parser; } } | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public static pbr::MessageDescriptor Descriptor { | |||||
get { return global::Tensorflow.GraphReflection.Descriptor.MessageTypes[0]; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||||
get { return Descriptor; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public GraphDef() { | |||||
OnConstruction(); | |||||
} | |||||
partial void OnConstruction(); | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public GraphDef(GraphDef other) : this() { | |||||
node_ = other.node_.Clone(); | |||||
versions_ = other.versions_ != null ? other.versions_.Clone() : null; | |||||
version_ = other.version_; | |||||
library_ = other.library_ != null ? other.library_.Clone() : null; | |||||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public GraphDef Clone() { | |||||
return new GraphDef(this); | |||||
} | |||||
/// <summary>Field number for the "node" field.</summary> | |||||
public const int NodeFieldNumber = 1; | |||||
private static readonly pb::FieldCodec<global::Tensorflow.NodeDef> _repeated_node_codec | |||||
= pb::FieldCodec.ForMessage(10, global::Tensorflow.NodeDef.Parser); | |||||
private readonly pbc::RepeatedField<global::Tensorflow.NodeDef> node_ = new pbc::RepeatedField<global::Tensorflow.NodeDef>(); | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public pbc::RepeatedField<global::Tensorflow.NodeDef> Node { | |||||
get { return node_; } | |||||
} | |||||
/// <summary>Field number for the "versions" field.</summary> | |||||
public const int VersionsFieldNumber = 4; | |||||
private global::Tensorflow.VersionDef versions_; | |||||
/// <summary> | |||||
/// Compatibility versions of the graph. See core/public/version.h for version | |||||
/// history. The GraphDef version is distinct from the TensorFlow version, and | |||||
/// each release of TensorFlow will support a range of GraphDef versions. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public global::Tensorflow.VersionDef Versions { | |||||
get { return versions_; } | |||||
set { | |||||
versions_ = value; | |||||
} | |||||
} | |||||
/// <summary>Field number for the "version" field.</summary> | |||||
public const int VersionFieldNumber = 3; | |||||
private int version_; | |||||
/// <summary> | |||||
/// Deprecated single version field; use versions above instead. Since all | |||||
/// GraphDef changes before "versions" was introduced were forward | |||||
/// compatible, this field is entirely ignored. | |||||
/// </summary> | |||||
[global::System.ObsoleteAttribute] | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public int Version { | |||||
get { return version_; } | |||||
set { | |||||
version_ = value; | |||||
} | |||||
} | |||||
/// <summary>Field number for the "library" field.</summary> | |||||
public const int LibraryFieldNumber = 2; | |||||
private global::Tensorflow.FunctionDefLibrary library_; | |||||
/// <summary> | |||||
/// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. | |||||
/// | |||||
/// "library" provides user-defined functions. | |||||
/// | |||||
/// Naming: | |||||
/// * library.function.name are in a flat namespace. | |||||
/// NOTE: We may need to change it to be hierarchical to support | |||||
/// different orgs. E.g., | |||||
/// { "/google/nn", { ... }}, | |||||
/// { "/google/vision", { ... }} | |||||
/// { "/org_foo/module_bar", { ... }} | |||||
/// map<string, FunctionDefLib> named_lib; | |||||
/// * If node[i].op is the name of one function in "library", | |||||
/// node[i] is deemed as a function call. Otherwise, node[i].op | |||||
/// must be a primitive operation supported by the runtime. | |||||
/// | |||||
/// Function call semantics: | |||||
/// | |||||
/// * The callee may start execution as soon as some of its inputs | |||||
/// are ready. The caller may want to use Tuple() mechanism to | |||||
/// ensure all inputs are ready in the same time. | |||||
/// | |||||
/// * The consumer of return values may start executing as soon as | |||||
/// the return values the consumer depends on are ready. The | |||||
/// consumer may want to use Tuple() mechanism to ensure the | |||||
/// consumer does not start until all return values of the callee | |||||
/// function are ready. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public global::Tensorflow.FunctionDefLibrary Library { | |||||
get { return library_; } | |||||
set { | |||||
library_ = value; | |||||
} | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override bool Equals(object other) { | |||||
return Equals(other as GraphDef); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public bool Equals(GraphDef other) { | |||||
if (ReferenceEquals(other, null)) { | |||||
return false; | |||||
} | |||||
if (ReferenceEquals(other, this)) { | |||||
return true; | |||||
} | |||||
if(!node_.Equals(other.node_)) return false; | |||||
if (!object.Equals(Versions, other.Versions)) return false; | |||||
if (Version != other.Version) return false; | |||||
if (!object.Equals(Library, other.Library)) return false; | |||||
return Equals(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override int GetHashCode() { | |||||
int hash = 1; | |||||
hash ^= node_.GetHashCode(); | |||||
if (versions_ != null) hash ^= Versions.GetHashCode(); | |||||
if (Version != 0) hash ^= Version.GetHashCode(); | |||||
if (library_ != null) hash ^= Library.GetHashCode(); | |||||
if (_unknownFields != null) { | |||||
hash ^= _unknownFields.GetHashCode(); | |||||
} | |||||
return hash; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override string ToString() { | |||||
return pb::JsonFormatter.ToDiagnosticString(this); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void WriteTo(pb::CodedOutputStream output) { | |||||
node_.WriteTo(output, _repeated_node_codec); | |||||
if (library_ != null) { | |||||
output.WriteRawTag(18); | |||||
output.WriteMessage(Library); | |||||
} | |||||
if (Version != 0) { | |||||
output.WriteRawTag(24); | |||||
output.WriteInt32(Version); | |||||
} | |||||
if (versions_ != null) { | |||||
output.WriteRawTag(34); | |||||
output.WriteMessage(Versions); | |||||
} | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(output); | |||||
} | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public int CalculateSize() { | |||||
int size = 0; | |||||
size += node_.CalculateSize(_repeated_node_codec); | |||||
if (versions_ != null) { | |||||
size += 1 + pb::CodedOutputStream.ComputeMessageSize(Versions); | |||||
} | |||||
if (Version != 0) { | |||||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Version); | |||||
} | |||||
if (library_ != null) { | |||||
size += 1 + pb::CodedOutputStream.ComputeMessageSize(Library); | |||||
} | |||||
if (_unknownFields != null) { | |||||
size += _unknownFields.CalculateSize(); | |||||
} | |||||
return size; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void MergeFrom(GraphDef other) { | |||||
if (other == null) { | |||||
return; | |||||
} | |||||
node_.Add(other.node_); | |||||
if (other.versions_ != null) { | |||||
if (versions_ == null) { | |||||
versions_ = new global::Tensorflow.VersionDef(); | |||||
} | |||||
Versions.MergeFrom(other.Versions); | |||||
} | |||||
if (other.Version != 0) { | |||||
Version = other.Version; | |||||
} | |||||
if (other.library_ != null) { | |||||
if (library_ == null) { | |||||
library_ = new global::Tensorflow.FunctionDefLibrary(); | |||||
} | |||||
Library.MergeFrom(other.Library); | |||||
} | |||||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void MergeFrom(pb::CodedInputStream input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||||
break; | |||||
case 10: { | |||||
node_.AddEntriesFrom(input, _repeated_node_codec); | |||||
break; | |||||
} | |||||
case 18: { | |||||
if (library_ == null) { | |||||
library_ = new global::Tensorflow.FunctionDefLibrary(); | |||||
} | |||||
input.ReadMessage(library_); | |||||
break; | |||||
} | |||||
case 24: { | |||||
Version = input.ReadInt32(); | |||||
break; | |||||
} | |||||
case 34: { | |||||
if (versions_ == null) { | |||||
versions_ = new global::Tensorflow.VersionDef(); | |||||
} | |||||
input.ReadMessage(versions_); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endregion | |||||
} | |||||
#endregion Designer generated code |
@@ -1,12 +1,15 @@ | |||||
### Download compiler from https://github.com/protocolbuffers/protobuf/releases | ### Download compiler from https://github.com/protocolbuffers/protobuf/releases | ||||
```shell | ```shell | ||||
set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\framework | |||||
set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Tensorflow | |||||
set SRC_DIR=D:\Projects\tensorflow-1.12.0\tensorflow\core\framework | |||||
set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf | |||||
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto | |||||
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto | |||||
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto | |||||
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto | |||||
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto | |||||
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto | |||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto | |||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto | |||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto | |||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto | |||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto | |||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto | |||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% versions.proto | |||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% function.proto | |||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% graph.proto | |||||
``` | ``` |
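The generated classes target the Google.Protobuf runtime. A minimal round-trip sketch (the helper class and values are illustrative, not part of the repository):

```csharp
using Google.Protobuf;   // provides ToByteArray()
using Tensorflow;

class ProtoRoundTrip     // hypothetical helper, for illustration only
{
    static void Main()
    {
        var gdef = new GraphDef { Versions = new VersionDef { Producer = 27 } }; // illustrative value
        byte[] bytes = gdef.ToByteArray();
        var parsed = GraphDef.Parser.ParseFrom(bytes);
        System.Console.WriteLine(parsed.Versions.Producer);   // prints 27
    }
}
```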
@@ -0,0 +1,247 @@ | |||||
// <auto-generated> | |||||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||||
// source: versions.proto | |||||
// </auto-generated> | |||||
#pragma warning disable 1591, 0612, 3021 | |||||
#region Designer generated code | |||||
using pb = global::Google.Protobuf; | |||||
using pbc = global::Google.Protobuf.Collections; | |||||
using pbr = global::Google.Protobuf.Reflection; | |||||
using scg = global::System.Collections.Generic; | |||||
namespace Tensorflow { | |||||
/// <summary>Holder for reflection information generated from versions.proto</summary> | |||||
public static partial class VersionsReflection { | |||||
#region Descriptor | |||||
/// <summary>File descriptor for versions.proto</summary> | |||||
public static pbr::FileDescriptor Descriptor { | |||||
get { return descriptor; } | |||||
} | |||||
private static pbr::FileDescriptor descriptor; | |||||
static VersionsReflection() { | |||||
byte[] descriptorData = global::System.Convert.FromBase64String( | |||||
string.Concat( | |||||
"Cg52ZXJzaW9ucy5wcm90bxIKdGVuc29yZmxvdyJLCgpWZXJzaW9uRGVmEhAK", | |||||
"CHByb2R1Y2VyGAEgASgFEhQKDG1pbl9jb25zdW1lchgCIAEoBRIVCg1iYWRf", | |||||
"Y29uc3VtZXJzGAMgAygFQm4KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IO", | |||||
"VmVyc2lvbnNQcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNv", | |||||
"cmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z")); | |||||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||||
new pbr::FileDescriptor[] { }, | |||||
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VersionDef), global::Tensorflow.VersionDef.Parser, new[]{ "Producer", "MinConsumer", "BadConsumers" }, null, null, null) | |||||
})); | |||||
} | |||||
#endregion | |||||
} | |||||
#region Messages | |||||
/// <summary> | |||||
/// Version information for a piece of serialized data | |||||
/// | |||||
/// There are different types of versions for each type of data | |||||
/// (GraphDef, etc.), but they all have the same common shape | |||||
/// described here. | |||||
/// | |||||
/// Each consumer has "consumer" and "min_producer" versions (specified | |||||
/// elsewhere). A consumer is allowed to consume this data if | |||||
/// | |||||
/// producer >= min_producer | |||||
/// consumer >= min_consumer | |||||
/// consumer not in bad_consumers | |||||
/// </summary> | |||||
public sealed partial class VersionDef : pb::IMessage<VersionDef> { | |||||
private static readonly pb::MessageParser<VersionDef> _parser = new pb::MessageParser<VersionDef>(() => new VersionDef()); | |||||
private pb::UnknownFieldSet _unknownFields; | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public static pb::MessageParser<VersionDef> Parser { get { return _parser; } } | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public static pbr::MessageDescriptor Descriptor { | |||||
get { return global::Tensorflow.VersionsReflection.Descriptor.MessageTypes[0]; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||||
get { return Descriptor; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public VersionDef() { | |||||
OnConstruction(); | |||||
} | |||||
partial void OnConstruction(); | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public VersionDef(VersionDef other) : this() { | |||||
producer_ = other.producer_; | |||||
minConsumer_ = other.minConsumer_; | |||||
badConsumers_ = other.badConsumers_.Clone(); | |||||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public VersionDef Clone() { | |||||
return new VersionDef(this); | |||||
} | |||||
/// <summary>Field number for the "producer" field.</summary> | |||||
public const int ProducerFieldNumber = 1; | |||||
private int producer_; | |||||
/// <summary> | |||||
/// The version of the code that produced this data. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public int Producer { | |||||
get { return producer_; } | |||||
set { | |||||
producer_ = value; | |||||
} | |||||
} | |||||
/// <summary>Field number for the "min_consumer" field.</summary> | |||||
public const int MinConsumerFieldNumber = 2; | |||||
private int minConsumer_; | |||||
/// <summary> | |||||
/// Any consumer below this version is not allowed to consume this data. | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public int MinConsumer { | |||||
get { return minConsumer_; } | |||||
set { | |||||
minConsumer_ = value; | |||||
} | |||||
} | |||||
/// <summary>Field number for the "bad_consumers" field.</summary> | |||||
public const int BadConsumersFieldNumber = 3; | |||||
private static readonly pb::FieldCodec<int> _repeated_badConsumers_codec | |||||
= pb::FieldCodec.ForInt32(26); | |||||
private readonly pbc::RepeatedField<int> badConsumers_ = new pbc::RepeatedField<int>(); | |||||
/// <summary> | |||||
/// Specific consumer versions which are disallowed (e.g. due to bugs). | |||||
/// </summary> | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public pbc::RepeatedField<int> BadConsumers { | |||||
get { return badConsumers_; } | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override bool Equals(object other) { | |||||
return Equals(other as VersionDef); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public bool Equals(VersionDef other) { | |||||
if (ReferenceEquals(other, null)) { | |||||
return false; | |||||
} | |||||
if (ReferenceEquals(other, this)) { | |||||
return true; | |||||
} | |||||
if (Producer != other.Producer) return false; | |||||
if (MinConsumer != other.MinConsumer) return false; | |||||
if(!badConsumers_.Equals(other.badConsumers_)) return false; | |||||
return Equals(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override int GetHashCode() { | |||||
int hash = 1; | |||||
if (Producer != 0) hash ^= Producer.GetHashCode(); | |||||
if (MinConsumer != 0) hash ^= MinConsumer.GetHashCode(); | |||||
hash ^= badConsumers_.GetHashCode(); | |||||
if (_unknownFields != null) { | |||||
hash ^= _unknownFields.GetHashCode(); | |||||
} | |||||
return hash; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public override string ToString() { | |||||
return pb::JsonFormatter.ToDiagnosticString(this); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void WriteTo(pb::CodedOutputStream output) { | |||||
if (Producer != 0) { | |||||
output.WriteRawTag(8); | |||||
output.WriteInt32(Producer); | |||||
} | |||||
if (MinConsumer != 0) { | |||||
output.WriteRawTag(16); | |||||
output.WriteInt32(MinConsumer); | |||||
} | |||||
badConsumers_.WriteTo(output, _repeated_badConsumers_codec); | |||||
if (_unknownFields != null) { | |||||
_unknownFields.WriteTo(output); | |||||
} | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public int CalculateSize() { | |||||
int size = 0; | |||||
if (Producer != 0) { | |||||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Producer); | |||||
} | |||||
if (MinConsumer != 0) { | |||||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinConsumer); | |||||
} | |||||
size += badConsumers_.CalculateSize(_repeated_badConsumers_codec); | |||||
if (_unknownFields != null) { | |||||
size += _unknownFields.CalculateSize(); | |||||
} | |||||
return size; | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void MergeFrom(VersionDef other) { | |||||
if (other == null) { | |||||
return; | |||||
} | |||||
if (other.Producer != 0) { | |||||
Producer = other.Producer; | |||||
} | |||||
if (other.MinConsumer != 0) { | |||||
MinConsumer = other.MinConsumer; | |||||
} | |||||
badConsumers_.Add(other.badConsumers_); | |||||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||||
} | |||||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
public void MergeFrom(pb::CodedInputStream input) { | |||||
uint tag; | |||||
while ((tag = input.ReadTag()) != 0) { | |||||
switch(tag) { | |||||
default: | |||||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||||
break; | |||||
case 8: { | |||||
Producer = input.ReadInt32(); | |||||
break; | |||||
} | |||||
case 16: { | |||||
MinConsumer = input.ReadInt32(); | |||||
break; | |||||
} | |||||
case 26: | |||||
case 24: { | |||||
badConsumers_.AddEntriesFrom(input, _repeated_badConsumers_codec); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
#endregion | |||||
} | |||||
#endregion Designer generated code |
@@ -17,7 +17,7 @@ namespace Tensorflow | |||||
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | /// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | ||||
/// struct => struct (TF_Output output) => (TF_Output output) | /// struct => struct (TF_Output output) => (TF_Output output) | ||||
/// struct* => struct (TF_Output* output) => (TF_Output[] output) | /// struct* => struct (TF_Output* output) => (TF_Output[] output) | ||||
/// struct* => ref IntPtr (TF_Input* consumers) => (ref IntPtr handle), if output is struct[] | |||||
/// struct* => struct* (TF_Input* consumers) => (TF_Input* consumers), when the native side fills an array | |||||
/// const char* => string | /// const char* => string | ||||
/// int32_t => int | /// int32_t => int | ||||
/// int64_t* => long[] | /// int64_t* => long[] | ||||
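To make the mapping table concrete, a hedged sketch of how one hypothetical C API signature would be declared under these rules (`TF_ExampleOp` does not exist; it only illustrates the conventions above):

```csharp
// Native (hypothetical): int TF_ExampleOp(TF_Graph* graph, TF_Output output,
//                                         TF_Input* consumers, const char* name,
//                                         int64_t* dims, int32_t num_dims);
[DllImport(TensorFlowLibName)]
public static extern unsafe int TF_ExampleOp(IntPtr graph, TF_Output output,
    TF_Input* consumers, string name, long[] dims, int num_dims);
```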
@@ -83,6 +83,42 @@ namespace TensorFlowNET.UnitTest | |||||
Assert.AreEqual(1, feed_port.Length); | Assert.AreEqual(1, feed_port.Length); | ||||
Assert.AreEqual(add, feed_port[0].oper); | Assert.AreEqual(add, feed_port[0].oper); | ||||
Assert.AreEqual(0, feed_port[0].index); | Assert.AreEqual(0, feed_port[0].index); | ||||
// The scalar const oper also has a consumer. | |||||
Assert.AreEqual(1, three.OutputNumConsumers(0)); | |||||
TF_Input[] three_port = three.OutputConsumers(0, 1); | |||||
Assert.AreEqual(add, three_port[0].oper); | |||||
Assert.AreEqual(1, three_port[0].index); | |||||
// Serialize to GraphDef. | |||||
var graph_def = c_test_util.GetGraphDef(graph); | |||||
// Validate GraphDef is what we expect. | |||||
bool found_placeholder = false; | |||||
bool found_scalar_const = false; | |||||
bool found_add = false; | |||||
foreach (var n in graph_def.Node) | |||||
{ | |||||
if (c_test_util.IsPlaceholder(n)) | |||||
{ | |||||
Assert.IsFalse(found_placeholder); | |||||
found_placeholder = true; | |||||
} | |||||
/*else if (IsScalarConst(n, 3)) | |||||
{ | |||||
Assert.IsFalse(found_scalar_const); | |||||
found_scalar_const = true; | |||||
} | |||||
else if (IsAddN(n, 2)) | |||||
{ | |||||
Assert.IsFalse(found_add); | |||||
found_add = true; | |||||
} | |||||
else | |||||
{ | |||||
ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n); | |||||
}*/ | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -19,7 +19,7 @@ namespace TensorFlowNET.UnitTest | |||||
{ | { | ||||
var handle = c_api.TF_GetAllOpList(); | var handle = c_api.TF_GetAllOpList(); | ||||
var buffer = new Buffer(handle); | var buffer = new Buffer(handle); | ||||
Assert.IsTrue(buffer.Length == buffer.Data.Length); | |||||
Assert.IsTrue(buffer.Length == ((byte[])buffer).Length);
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -39,11 +39,20 @@ namespace TensorFlowNET.UnitTest | |||||
{ | { | ||||
var buffer = new Buffer(); | var buffer = new Buffer(); | ||||
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | ||||
attr_value = AttrValue.Parser.ParseFrom(buffer.Data); | |||||
attr_value = AttrValue.Parser.ParseFrom(buffer); | |||||
buffer.Dispose(); | buffer.Dispose(); | ||||
return s.Code == TF_Code.TF_OK; | return s.Code == TF_Code.TF_OK; | ||||
} | } | ||||
public static GraphDef GetGraphDef(Graph graph) | |||||
{ | |||||
var s = new Status(); | |||||
var buffer = new Buffer(); | |||||
c_api.TF_GraphToGraphDef(graph, buffer, s); | |||||
s.Check(); | |||||
var graph_def = GraphDef.Parser.ParseFrom(buffer);
buffer.Dispose();
return graph_def;
} | |||||
public static bool GetNodeDef(Operation oper, ref NodeDef node_def) | public static bool GetNodeDef(Operation oper, ref NodeDef node_def) | ||||
{ | { | ||||
var s = new Status(); | var s = new Status(); | ||||
@@ -53,6 +62,37 @@ namespace TensorFlowNET.UnitTest | |||||
return s.Code == TF_Code.TF_OK; | return s.Code == TF_Code.TF_OK; | ||||
} | } | ||||
public static bool IsPlaceholder(NodeDef node_def) | |||||
{ | |||||
if (node_def.Op != "Placeholder" || node_def.Name != "feed") | |||||
{ | |||||
return false; | |||||
} | |||||
bool found_dtype = false; | |||||
bool found_shape = false; | |||||
foreach (var attr in node_def.Attr) | |||||
{ | |||||
if (attr.Key == "dtype") | |||||
{ | |||||
if (attr.Value.Type == DataType.DtInt32) | |||||
{ | |||||
found_dtype = true; | |||||
} | |||||
else | |||||
{ | |||||
return false; | |||||
} | |||||
} | |||||
else if (attr.Key == "shape") | |||||
{ | |||||
found_shape = true; | |||||
} | |||||
} | |||||
return found_dtype && found_shape; | |||||
} | |||||
public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) | public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) | ||||
{ | { | ||||
var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | ||||