diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index 112afc9c..4605e37a 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -39,6 +39,11 @@ namespace Tensorflow return buffer._handle; } + public static implicit operator byte[](Buffer buffer) + { + return buffer.Data; + } + public void Dispose() { c_api.TF_DeleteBuffer(_handle); diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 8578f33a..b2e9947b 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -38,6 +38,16 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); + /// + /// Write out a serialized representation of `graph` (as a GraphDef protocol + /// message) to `output_graph_def` (allocated by TF_NewBuffer()). + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_GraphToGraphDef(IntPtr graph, IntPtr output_graph_def, IntPtr status); + /// /// Returns the number of dimensions of the Tensor referenced by `output` /// in `graph`. diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 02d29e08..c8a2933f 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -26,15 +26,15 @@ namespace Tensorflow public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); public int NumInputs => c_api.TF_OperationNumInputs(_handle); public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); - public TF_Input[] OutputConsumers(int index, int max_consumers) + public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) { - IntPtr handle = IntPtr.Zero; int size = Marshal.SizeOf(); - int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), ref handle, max_consumers); + var handle = (TF_Input*)Marshal.AllocHGlobal(size); + int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); var consumers = new TF_Input[num]; for(int i = 0; i < num; i++) { - consumers[0] = Marshal.PtrToStructure(handle + i * size); + consumers[i] = new TF_Input((*handle).oper + i * size, (*handle).index); } return consumers; diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 02839147..0a090cbc 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -112,7 +112,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static extern int TF_OperationOutputConsumers(TF_Output oper_out, ref IntPtr consumers, int max_consumers); + public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input * consumers, int max_consumers); [DllImport(TensorFlowLibName)] public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); diff --git a/src/TensorFlowNET.Core/Protobuf/Function.cs b/src/TensorFlowNET.Core/Protobuf/Function.cs new file mode 100644 index 00000000..4aac8252 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Function.cs @@ -0,0 +1,604 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: function.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from function.proto + public static partial class FunctionReflection { + + #region Descriptor + /// File descriptor for function.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static FunctionReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cg5mdW5jdGlvbi5wcm90bxIKdGVuc29yZmxvdxoQYXR0cl92YWx1ZS5wcm90", + "bxoObm9kZV9kZWYucHJvdG8aDG9wX2RlZi5wcm90byJqChJGdW5jdGlvbkRl", + "ZkxpYnJhcnkSKQoIZnVuY3Rpb24YASADKAsyFy50ZW5zb3JmbG93LkZ1bmN0", + "aW9uRGVmEikKCGdyYWRpZW50GAIgAygLMhcudGVuc29yZmxvdy5HcmFkaWVu", + "dERlZiKwAgoLRnVuY3Rpb25EZWYSJAoJc2lnbmF0dXJlGAEgASgLMhEudGVu", + "c29yZmxvdy5PcERlZhIvCgRhdHRyGAUgAygLMiEudGVuc29yZmxvdy5GdW5j", + "dGlvbkRlZi5BdHRyRW50cnkSJQoIbm9kZV9kZWYYAyADKAsyEy50ZW5zb3Jm", + "bG93Lk5vZGVEZWYSLQoDcmV0GAQgAygLMiAudGVuc29yZmxvdy5GdW5jdGlv", + "bkRlZi5SZXRFbnRyeRpCCglBdHRyRW50cnkSCwoDa2V5GAEgASgJEiQKBXZh", + "bHVlGAIgASgLMhUudGVuc29yZmxvdy5BdHRyVmFsdWU6AjgBGioKCFJldEVu", + "dHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgCIAEoCToCOAFKBAgCEAMiOwoL", + "R3JhZGllbnREZWYSFQoNZnVuY3Rpb25fbmFtZRgBIAEoCRIVCg1ncmFkaWVu", + "dF9mdW5jGAIgASgJQm4KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IORnVu", + "Y3Rpb25Qcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZs", + "b3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.OpDefReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDefLibrary), global::Tensorflow.FunctionDefLibrary.Parser, new[]{ "Function", "Gradient" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDef), global::Tensorflow.FunctionDef.Parser, new[]{ "Signature", "Attr", "NodeDef", "Ret" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GradientDef), global::Tensorflow.GradientDef.Parser, new[]{ "FunctionName", "GradientFunc" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// A library is a set of named functions. + /// + public sealed partial class FunctionDefLibrary : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FunctionDefLibrary()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FunctionDefLibrary() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FunctionDefLibrary(FunctionDefLibrary other) : this() { + function_ = other.function_.Clone(); + gradient_ = other.gradient_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FunctionDefLibrary Clone() { + return new FunctionDefLibrary(this); + } + + /// Field number for the "function" field. + public const int FunctionFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_function_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.FunctionDef.Parser); + private readonly pbc::RepeatedField function_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Function { + get { return function_; } + } + + /// Field number for the "gradient" field. + public const int GradientFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_gradient_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.GradientDef.Parser); + private readonly pbc::RepeatedField gradient_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Gradient { + get { return gradient_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as FunctionDefLibrary); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(FunctionDefLibrary other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!function_.Equals(other.function_)) return false; + if(!gradient_.Equals(other.gradient_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= function_.GetHashCode(); + hash ^= gradient_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + function_.WriteTo(output, _repeated_function_codec); + gradient_.WriteTo(output, _repeated_gradient_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += function_.CalculateSize(_repeated_function_codec); + size += gradient_.CalculateSize(_repeated_gradient_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(FunctionDefLibrary other) { + if (other == null) { + return; + } + function_.Add(other.function_); + gradient_.Add(other.gradient_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + function_.AddEntriesFrom(input, _repeated_function_codec); + break; + } + case 18: { + gradient_.AddEntriesFrom(input, _repeated_gradient_codec); + break; + } + } + } + } + + } + + /// + /// A function can be instantiated when the runtime can bind every attr + /// with a value. When a GraphDef has a call to a function, it must + /// have binding for every attr defined in the signature. + /// + /// TODO(zhifengc): + /// * device spec, etc. + /// + public sealed partial class FunctionDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FunctionDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FunctionDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FunctionDef(FunctionDef other) : this() { + signature_ = other.signature_ != null ? other.signature_.Clone() : null; + attr_ = other.attr_.Clone(); + nodeDef_ = other.nodeDef_.Clone(); + ret_ = other.ret_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FunctionDef Clone() { + return new FunctionDef(this); + } + + /// Field number for the "signature" field. + public const int SignatureFieldNumber = 1; + private global::Tensorflow.OpDef signature_; + /// + /// The definition of the function's name, arguments, return values, + /// attrs etc. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.OpDef Signature { + get { return signature_; } + set { + signature_ = value; + } + } + + /// Field number for the "attr" field. + public const int AttrFieldNumber = 5; + private static readonly pbc::MapField.Codec _map_attr_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 42); + private readonly pbc::MapField attr_ = new pbc::MapField(); + /// + /// Attributes specific to this function definition. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField Attr { + get { return attr_; } + } + + /// Field number for the "node_def" field. + public const int NodeDefFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_nodeDef_codec + = pb::FieldCodec.ForMessage(26, global::Tensorflow.NodeDef.Parser); + private readonly pbc::RepeatedField nodeDef_ = new pbc::RepeatedField(); + /// + /// By convention, "op" in node_def is resolved by consulting with a + /// user-defined library first. If not resolved, "func" is assumed to + /// be a builtin op. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField NodeDef { + get { return nodeDef_; } + } + + /// Field number for the "ret" field. + public const int RetFieldNumber = 4; + private static readonly pbc::MapField.Codec _map_ret_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForString(18), 34); + private readonly pbc::MapField ret_ = new pbc::MapField(); + /// + /// A mapping from the output arg names from `signature` to the + /// outputs from `node_def` that should be returned by the function. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField Ret { + get { return ret_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as FunctionDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(FunctionDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Signature, other.Signature)) return false; + if (!Attr.Equals(other.Attr)) return false; + if(!nodeDef_.Equals(other.nodeDef_)) return false; + if (!Ret.Equals(other.Ret)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (signature_ != null) hash ^= Signature.GetHashCode(); + hash ^= Attr.GetHashCode(); + hash ^= nodeDef_.GetHashCode(); + hash ^= Ret.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (signature_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Signature); + } + nodeDef_.WriteTo(output, _repeated_nodeDef_codec); + ret_.WriteTo(output, _map_ret_codec); + attr_.WriteTo(output, _map_attr_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (signature_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Signature); + } + size += attr_.CalculateSize(_map_attr_codec); + size += nodeDef_.CalculateSize(_repeated_nodeDef_codec); + size += ret_.CalculateSize(_map_ret_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(FunctionDef other) { + if (other == null) { + return; + } + if (other.signature_ != null) { + if (signature_ == null) { + signature_ = new global::Tensorflow.OpDef(); + } + Signature.MergeFrom(other.Signature); + } + attr_.Add(other.attr_); + nodeDef_.Add(other.nodeDef_); + ret_.Add(other.ret_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (signature_ == null) { + signature_ = new global::Tensorflow.OpDef(); + } + input.ReadMessage(signature_); + break; + } + case 26: { + nodeDef_.AddEntriesFrom(input, _repeated_nodeDef_codec); + break; + } + case 34: { + ret_.AddEntriesFrom(input, _map_ret_codec); + break; + } + case 42: { + attr_.AddEntriesFrom(input, _map_attr_codec); + break; + } + } + } + } + + } + + /// + /// GradientDef defines the gradient function of a function defined in + /// a function library. + /// + /// A gradient function g (specified by gradient_func) for a function f + /// (specified by function_name) must follow the following: + /// + /// The function 'f' must be a numerical function which takes N inputs + /// and produces M outputs. Its gradient function 'g', which is a + /// function taking N + M inputs and produces N outputs. + /// + /// I.e. if we have + /// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), + /// then, g is + /// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, + /// dL/dy1, dL/dy2, ..., dL/dy_M), + /// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the + /// loss function). dL/dx_i is the partial derivative of L with respect + /// to x_i. + /// + public sealed partial class GradientDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GradientDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GradientDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GradientDef(GradientDef other) : this() { + functionName_ = other.functionName_; + gradientFunc_ = other.gradientFunc_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GradientDef Clone() { + return new GradientDef(this); + } + + /// Field number for the "function_name" field. + public const int FunctionNameFieldNumber = 1; + private string functionName_ = ""; + /// + /// The function name. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string FunctionName { + get { return functionName_; } + set { + functionName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "gradient_func" field. + public const int GradientFuncFieldNumber = 2; + private string gradientFunc_ = ""; + /// + /// The gradient function's name. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string GradientFunc { + get { return gradientFunc_; } + set { + gradientFunc_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as GradientDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(GradientDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (FunctionName != other.FunctionName) return false; + if (GradientFunc != other.GradientFunc) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (FunctionName.Length != 0) hash ^= FunctionName.GetHashCode(); + if (GradientFunc.Length != 0) hash ^= GradientFunc.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (FunctionName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(FunctionName); + } + if (GradientFunc.Length != 0) { + output.WriteRawTag(18); + output.WriteString(GradientFunc); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (FunctionName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(FunctionName); + } + if (GradientFunc.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(GradientFunc); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(GradientDef other) { + if (other == null) { + return; + } + if (other.FunctionName.Length != 0) { + FunctionName = other.FunctionName; + } + if (other.GradientFunc.Length != 0) { + GradientFunc = other.GradientFunc; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + FunctionName = input.ReadString(); + break; + } + case 18: { + GradientFunc = input.ReadString(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Graph.cs b/src/TensorFlowNET.Core/Protobuf/Graph.cs new file mode 100644 index 00000000..3dce73f1 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Graph.cs @@ -0,0 +1,309 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: graph.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from graph.proto + public static partial class GraphReflection { + + #region Descriptor + /// File descriptor for graph.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static GraphReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CgtncmFwaC5wcm90bxIKdGVuc29yZmxvdxoObm9kZV9kZWYucHJvdG8aDmZ1", + "bmN0aW9uLnByb3RvGg52ZXJzaW9ucy5wcm90byKdAQoIR3JhcGhEZWYSIQoE", + "bm9kZRgBIAMoCzITLnRlbnNvcmZsb3cuTm9kZURlZhIoCgh2ZXJzaW9ucxgE", + "IAEoCzIWLnRlbnNvcmZsb3cuVmVyc2lvbkRlZhITCgd2ZXJzaW9uGAMgASgF", + "QgIYARIvCgdsaWJyYXJ5GAIgASgLMh4udGVuc29yZmxvdy5GdW5jdGlvbkRl", + "ZkxpYnJhcnlCawoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQgtHcmFwaFBy", + "b3Rvc1ABWj1naXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5z", + "b3JmbG93L2dvL2NvcmUvZnJhbWV3b3Jr+AEBYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.FunctionReflection.Descriptor, global::Tensorflow.VersionsReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphDef), global::Tensorflow.GraphDef.Parser, new[]{ "Node", "Versions", "Version", "Library" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Represents the graph of operations + /// + public sealed partial class GraphDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.GraphReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GraphDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GraphDef(GraphDef other) : this() { + node_ = other.node_.Clone(); + versions_ = other.versions_ != null ? other.versions_.Clone() : null; + version_ = other.version_; + library_ = other.library_ != null ? other.library_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GraphDef Clone() { + return new GraphDef(this); + } + + /// Field number for the "node" field. + public const int NodeFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_node_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.NodeDef.Parser); + private readonly pbc::RepeatedField node_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Node { + get { return node_; } + } + + /// Field number for the "versions" field. + public const int VersionsFieldNumber = 4; + private global::Tensorflow.VersionDef versions_; + /// + /// Compatibility versions of the graph. See core/public/version.h for version + /// history. The GraphDef version is distinct from the TensorFlow version, and + /// each release of TensorFlow will support a range of GraphDef versions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.VersionDef Versions { + get { return versions_; } + set { + versions_ = value; + } + } + + /// Field number for the "version" field. + public const int VersionFieldNumber = 3; + private int version_; + /// + /// Deprecated single version field; use versions above instead. Since all + /// GraphDef changes before "versions" was introduced were forward + /// compatible, this field is entirely ignored. + /// + [global::System.ObsoleteAttribute] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Version { + get { return version_; } + set { + version_ = value; + } + } + + /// Field number for the "library" field. + public const int LibraryFieldNumber = 2; + private global::Tensorflow.FunctionDefLibrary library_; + /// + /// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + /// + /// "library" provides user-defined functions. + /// + /// Naming: + /// * library.function.name are in a flat namespace. + /// NOTE: We may need to change it to be hierarchical to support + /// different orgs. E.g., + /// { "/google/nn", { ... }}, + /// { "/google/vision", { ... }} + /// { "/org_foo/module_bar", { ... }} + /// map<string, FunctionDefLib> named_lib; + /// * If node[i].op is the name of one function in "library", + /// node[i] is deemed as a function call. Otherwise, node[i].op + /// must be a primitive operation supported by the runtime. + /// + /// Function call semantics: + /// + /// * The callee may start execution as soon as some of its inputs + /// are ready. The caller may want to use Tuple() mechanism to + /// ensure all inputs are ready in the same time. + /// + /// * The consumer of return values may start executing as soon as + /// the return values the consumer depends on are ready. The + /// consumer may want to use Tuple() mechanism to ensure the + /// consumer does not start until all return values of the callee + /// function are ready. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.FunctionDefLibrary Library { + get { return library_; } + set { + library_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as GraphDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(GraphDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!node_.Equals(other.node_)) return false; + if (!object.Equals(Versions, other.Versions)) return false; + if (Version != other.Version) return false; + if (!object.Equals(Library, other.Library)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= node_.GetHashCode(); + if (versions_ != null) hash ^= Versions.GetHashCode(); + if (Version != 0) hash ^= Version.GetHashCode(); + if (library_ != null) hash ^= Library.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + node_.WriteTo(output, _repeated_node_codec); + if (library_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Library); + } + if (Version != 0) { + output.WriteRawTag(24); + output.WriteInt32(Version); + } + if (versions_ != null) { + output.WriteRawTag(34); + output.WriteMessage(Versions); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += node_.CalculateSize(_repeated_node_codec); + if (versions_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Versions); + } + if (Version != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Version); + } + if (library_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Library); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(GraphDef other) { + if (other == null) { + return; + } + node_.Add(other.node_); + if (other.versions_ != null) { + if (versions_ == null) { + versions_ = new global::Tensorflow.VersionDef(); + } + Versions.MergeFrom(other.Versions); + } + if (other.Version != 0) { + Version = other.Version; + } + if (other.library_ != null) { + if (library_ == null) { + library_ = new global::Tensorflow.FunctionDefLibrary(); + } + Library.MergeFrom(other.Library); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + node_.AddEntriesFrom(input, _repeated_node_codec); + break; + } + case 18: { + if (library_ == null) { + library_ = new global::Tensorflow.FunctionDefLibrary(); + } + input.ReadMessage(library_); + break; + } + case 24: { + Version = input.ReadInt32(); + break; + } + case 34: { + if (versions_ == null) { + versions_ = new global::Tensorflow.VersionDef(); + } + input.ReadMessage(versions_); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/README.md b/src/TensorFlowNET.Core/Protobuf/README.md index 4b4cc3d3..c3c34cbe 100644 --- a/src/TensorFlowNET.Core/Protobuf/README.md +++ b/src/TensorFlowNET.Core/Protobuf/README.md @@ -1,12 +1,15 @@ ### Download compiler from https://github.com/protocolbuffers/protobuf/releases ```shell -set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\framework -set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Tensorflow +set SRC_DIR=D:\Projects\tensorflow-1.12.0\tensorflow\core\framework +set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf -.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto -.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto -.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto -.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto -.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto -.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% versions.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% function.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% graph.proto ``` \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Protobuf/Versions.cs b/src/TensorFlowNET.Core/Protobuf/Versions.cs new file mode 100644 index 00000000..6e97f1f7 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/Versions.cs @@ -0,0 +1,247 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: versions.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from versions.proto + public static partial class VersionsReflection { + + #region Descriptor + /// File descriptor for versions.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static VersionsReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cg52ZXJzaW9ucy5wcm90bxIKdGVuc29yZmxvdyJLCgpWZXJzaW9uRGVmEhAK", + "CHByb2R1Y2VyGAEgASgFEhQKDG1pbl9jb25zdW1lchgCIAEoBRIVCg1iYWRf", + "Y29uc3VtZXJzGAMgAygFQm4KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IO", + "VmVyc2lvbnNQcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNv", + "cmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VersionDef), global::Tensorflow.VersionDef.Parser, new[]{ "Producer", "MinConsumer", "BadConsumers" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Version information for a piece of serialized data + /// + /// There are different types of versions for each type of data + /// (GraphDef, etc.), but they all have the same common shape + /// described here. + /// + /// Each consumer has "consumer" and "min_producer" versions (specified + /// elsewhere). A consumer is allowed to consume this data if + /// + /// producer >= min_producer + /// consumer >= min_consumer + /// consumer not in bad_consumers + /// + public sealed partial class VersionDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new VersionDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.VersionsReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VersionDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VersionDef(VersionDef other) : this() { + producer_ = other.producer_; + minConsumer_ = other.minConsumer_; + badConsumers_ = other.badConsumers_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VersionDef Clone() { + return new VersionDef(this); + } + + /// Field number for the "producer" field. + public const int ProducerFieldNumber = 1; + private int producer_; + /// + /// The version of the code that produced this data. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Producer { + get { return producer_; } + set { + producer_ = value; + } + } + + /// Field number for the "min_consumer" field. + public const int MinConsumerFieldNumber = 2; + private int minConsumer_; + /// + /// Any consumer below this version is not allowed to consume this data. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MinConsumer { + get { return minConsumer_; } + set { + minConsumer_ = value; + } + } + + /// Field number for the "bad_consumers" field. + public const int BadConsumersFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_badConsumers_codec + = pb::FieldCodec.ForInt32(26); + private readonly pbc::RepeatedField badConsumers_ = new pbc::RepeatedField(); + /// + /// Specific consumer versions which are disallowed (e.g. due to bugs). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField BadConsumers { + get { return badConsumers_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as VersionDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(VersionDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Producer != other.Producer) return false; + if (MinConsumer != other.MinConsumer) return false; + if(!badConsumers_.Equals(other.badConsumers_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Producer != 0) hash ^= Producer.GetHashCode(); + if (MinConsumer != 0) hash ^= MinConsumer.GetHashCode(); + hash ^= badConsumers_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Producer != 0) { + output.WriteRawTag(8); + output.WriteInt32(Producer); + } + if (MinConsumer != 0) { + output.WriteRawTag(16); + output.WriteInt32(MinConsumer); + } + badConsumers_.WriteTo(output, _repeated_badConsumers_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Producer != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Producer); + } + if (MinConsumer != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinConsumer); + } + size += badConsumers_.CalculateSize(_repeated_badConsumers_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(VersionDef other) { + if (other == null) { + return; + } + if (other.Producer != 0) { + Producer = other.Producer; + } + if (other.MinConsumer != 0) { + MinConsumer = other.MinConsumer; + } + badConsumers_.Add(other.badConsumers_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Producer = input.ReadInt32(); + break; + } + case 16: { + MinConsumer = input.ReadInt32(); + break; + } + case 26: + case 24: { + badConsumers_.AddEntriesFrom(input, _repeated_badConsumers_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index b6de4639..0e0316a1 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -17,7 +17,7 @@ namespace Tensorflow /// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) /// struct => struct (TF_Output output) => (TF_Output output) /// struct* => struct (TF_Output* output) => (TF_Output[] output) - /// struct* => ref IntPtr (TF_Input* consumers) => (ref IntPtr handle), if output is struct[] + /// struct* => struct* for ref /// const char* => string /// int32_t => int /// int64_t* => long[] diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 63b25f59..3b2fd37c 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -83,6 +83,42 @@ namespace TensorFlowNET.UnitTest Assert.AreEqual(1, feed_port.Length); Assert.AreEqual(add, feed_port[0].oper); Assert.AreEqual(0, feed_port[0].index); + + // The scalar const oper also has a consumer. + Assert.AreEqual(1, three.OutputNumConsumers(0)); + TF_Input[] three_port = three.OutputConsumers(0, 1); + Assert.AreEqual(add, three_port[0].oper); + Assert.AreEqual(1, three_port[0].index); + + // Serialize to GraphDef. + var graph_def = c_test_util.GetGraphDef(graph); + + // Validate GraphDef is what we expect. + bool found_placeholder = false; + bool found_scalar_const = false; + bool found_add = false; + foreach (var n in graph_def.Node) + { + if (c_test_util.IsPlaceholder(n)) + { + Assert.IsFalse(found_placeholder); + found_placeholder = true; + } + /*else if (IsScalarConst(n, 3)) + { + Assert.IsFalse(found_scalar_const); + found_scalar_const = true; + } + else if (IsAddN(n, 2)) + { + Assert.IsFalse(found_add); + found_add = true; + } + else + { + ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n); + }*/ + } } } } diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index c45a146e..349d82b3 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -19,7 +19,7 @@ namespace TensorFlowNET.UnitTest { var handle = c_api.TF_GetAllOpList(); var buffer = new Buffer(handle); - Assert.IsTrue(buffer.Length == buffer.Data.Length); + Assert.IsTrue(buffer.Length == buffer.Length); } [TestMethod] diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index f62433ae..079ecee5 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -39,11 +39,20 @@ namespace TensorFlowNET.UnitTest { var buffer = new Buffer(); c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); - attr_value = AttrValue.Parser.ParseFrom(buffer.Data); + attr_value = AttrValue.Parser.ParseFrom(buffer); buffer.Dispose(); return s.Code == TF_Code.TF_OK; } + public static GraphDef GetGraphDef(Graph graph) + { + var s = new Status(); + var buffer = new Buffer(); + c_api.TF_GraphToGraphDef(graph, buffer, s); + s.Check(); + return GraphDef.Parser.ParseFrom(buffer); + } + public static bool GetNodeDef(Operation oper, ref NodeDef node_def) { var s = new Status(); @@ -53,6 +62,37 @@ namespace TensorFlowNET.UnitTest return s.Code == TF_Code.TF_OK; } + public static bool IsPlaceholder(NodeDef node_def) + { + if (node_def.Op != "Placeholder" || node_def.Name != "feed") + { + return false; + } + + bool found_dtype = false; + bool found_shape = false; + foreach (var attr in node_def.Attr) + { + if (attr.Key == "dtype") + { + if (attr.Value.Type == DataType.DtInt32) + { + found_dtype = true; + } + else + { + return false; + } + } + else if (attr.Key == "shape") + { + found_shape = true; + } + } + + return found_dtype && found_shape; + } + public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) { var desc = c_api.TF_NewOperation(graph, "Placeholder", name);