diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs
index 112afc9c..4605e37a 100644
--- a/src/TensorFlowNET.Core/Buffers/Buffer.cs
+++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs
@@ -39,6 +39,11 @@ namespace Tensorflow
return buffer._handle;
}
+ public static implicit operator byte[](Buffer buffer)
+ {
+ return buffer.Data;
+ }
+
public void Dispose()
{
c_api.TF_DeleteBuffer(_handle);
diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
index 8578f33a..b2e9947b 100644
--- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
@@ -38,6 +38,16 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status);
+ ///
+ /// Write out a serialized representation of `graph` (as a GraphDef protocol
+ /// message) to `output_graph_def` (allocated by TF_NewBuffer()).
+ ///
+ ///
+ ///
+ ///
+ [DllImport(TensorFlowLibName)]
+ public static extern void TF_GraphToGraphDef(IntPtr graph, IntPtr output_graph_def, IntPtr status);
+
///
/// Returns the number of dimensions of the Tensor referenced by `output`
/// in `graph`.
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index 02d29e08..c8a2933f 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -26,15 +26,15 @@ namespace Tensorflow
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status);
public int NumInputs => c_api.TF_OperationNumInputs(_handle);
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));
- public TF_Input[] OutputConsumers(int index, int max_consumers)
+ public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
{
- IntPtr handle = IntPtr.Zero;
int size = Marshal.SizeOf();
- int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), ref handle, max_consumers);
+ var handle = (TF_Input*)Marshal.AllocHGlobal(size);
+ int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers);
var consumers = new TF_Input[num];
for(int i = 0; i < num; i++)
{
- consumers[0] = Marshal.PtrToStructure(handle + i * size);
+ consumers[i] = new TF_Input((*handle).oper + i * size, (*handle).index);
}
return consumers;
diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
index 02839147..0a090cbc 100644
--- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs
+++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
@@ -112,7 +112,7 @@ namespace Tensorflow
///
///
[DllImport(TensorFlowLibName)]
- public static extern int TF_OperationOutputConsumers(TF_Output oper_out, ref IntPtr consumers, int max_consumers);
+ public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input * consumers, int max_consumers);
[DllImport(TensorFlowLibName)]
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out);
diff --git a/src/TensorFlowNET.Core/Protobuf/Function.cs b/src/TensorFlowNET.Core/Protobuf/Function.cs
new file mode 100644
index 00000000..4aac8252
--- /dev/null
+++ b/src/TensorFlowNET.Core/Protobuf/Function.cs
@@ -0,0 +1,604 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: function.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace Tensorflow {
+
+ /// Holder for reflection information generated from function.proto
+ public static partial class FunctionReflection {
+
+ #region Descriptor
+ /// File descriptor for function.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static FunctionReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "Cg5mdW5jdGlvbi5wcm90bxIKdGVuc29yZmxvdxoQYXR0cl92YWx1ZS5wcm90",
+ "bxoObm9kZV9kZWYucHJvdG8aDG9wX2RlZi5wcm90byJqChJGdW5jdGlvbkRl",
+ "ZkxpYnJhcnkSKQoIZnVuY3Rpb24YASADKAsyFy50ZW5zb3JmbG93LkZ1bmN0",
+ "aW9uRGVmEikKCGdyYWRpZW50GAIgAygLMhcudGVuc29yZmxvdy5HcmFkaWVu",
+ "dERlZiKwAgoLRnVuY3Rpb25EZWYSJAoJc2lnbmF0dXJlGAEgASgLMhEudGVu",
+ "c29yZmxvdy5PcERlZhIvCgRhdHRyGAUgAygLMiEudGVuc29yZmxvdy5GdW5j",
+ "dGlvbkRlZi5BdHRyRW50cnkSJQoIbm9kZV9kZWYYAyADKAsyEy50ZW5zb3Jm",
+ "bG93Lk5vZGVEZWYSLQoDcmV0GAQgAygLMiAudGVuc29yZmxvdy5GdW5jdGlv",
+ "bkRlZi5SZXRFbnRyeRpCCglBdHRyRW50cnkSCwoDa2V5GAEgASgJEiQKBXZh",
+ "bHVlGAIgASgLMhUudGVuc29yZmxvdy5BdHRyVmFsdWU6AjgBGioKCFJldEVu",
+ "dHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgCIAEoCToCOAFKBAgCEAMiOwoL",
+ "R3JhZGllbnREZWYSFQoNZnVuY3Rpb25fbmFtZRgBIAEoCRIVCg1ncmFkaWVu",
+ "dF9mdW5jGAIgASgJQm4KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IORnVu",
+ "Y3Rpb25Qcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZs",
+ "b3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z"));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.OpDefReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDefLibrary), global::Tensorflow.FunctionDefLibrary.Parser, new[]{ "Function", "Gradient" }, null, null, null),
+ new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDef), global::Tensorflow.FunctionDef.Parser, new[]{ "Signature", "Attr", "NodeDef", "Ret" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, null, }),
+ new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GradientDef), global::Tensorflow.GradientDef.Parser, new[]{ "FunctionName", "GradientFunc" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ ///
+ /// A library is a set of named functions.
+ ///
+ public sealed partial class FunctionDefLibrary : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FunctionDefLibrary());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public FunctionDefLibrary() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public FunctionDefLibrary(FunctionDefLibrary other) : this() {
+ function_ = other.function_.Clone();
+ gradient_ = other.gradient_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public FunctionDefLibrary Clone() {
+ return new FunctionDefLibrary(this);
+ }
+
+ /// Field number for the "function" field.
+ public const int FunctionFieldNumber = 1;
+ private static readonly pb::FieldCodec _repeated_function_codec
+ = pb::FieldCodec.ForMessage(10, global::Tensorflow.FunctionDef.Parser);
+ private readonly pbc::RepeatedField function_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField Function {
+ get { return function_; }
+ }
+
+ /// Field number for the "gradient" field.
+ public const int GradientFieldNumber = 2;
+ private static readonly pb::FieldCodec _repeated_gradient_codec
+ = pb::FieldCodec.ForMessage(18, global::Tensorflow.GradientDef.Parser);
+ private readonly pbc::RepeatedField gradient_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField Gradient {
+ get { return gradient_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as FunctionDefLibrary);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(FunctionDefLibrary other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if(!function_.Equals(other.function_)) return false;
+ if(!gradient_.Equals(other.gradient_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ hash ^= function_.GetHashCode();
+ hash ^= gradient_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ function_.WriteTo(output, _repeated_function_codec);
+ gradient_.WriteTo(output, _repeated_gradient_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ size += function_.CalculateSize(_repeated_function_codec);
+ size += gradient_.CalculateSize(_repeated_gradient_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(FunctionDefLibrary other) {
+ if (other == null) {
+ return;
+ }
+ function_.Add(other.function_);
+ gradient_.Add(other.gradient_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ function_.AddEntriesFrom(input, _repeated_function_codec);
+ break;
+ }
+ case 18: {
+ gradient_.AddEntriesFrom(input, _repeated_gradient_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ ///
+ /// A function can be instantiated when the runtime can bind every attr
+ /// with a value. When a GraphDef has a call to a function, it must
+ /// have binding for every attr defined in the signature.
+ ///
+ /// TODO(zhifengc):
+ /// * device spec, etc.
+ ///
+ public sealed partial class FunctionDef : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FunctionDef());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[1]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public FunctionDef() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public FunctionDef(FunctionDef other) : this() {
+ signature_ = other.signature_ != null ? other.signature_.Clone() : null;
+ attr_ = other.attr_.Clone();
+ nodeDef_ = other.nodeDef_.Clone();
+ ret_ = other.ret_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public FunctionDef Clone() {
+ return new FunctionDef(this);
+ }
+
+ /// Field number for the "signature" field.
+ public const int SignatureFieldNumber = 1;
+ private global::Tensorflow.OpDef signature_;
+ ///
+ /// The definition of the function's name, arguments, return values,
+ /// attrs etc.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::Tensorflow.OpDef Signature {
+ get { return signature_; }
+ set {
+ signature_ = value;
+ }
+ }
+
+ /// Field number for the "attr" field.
+ public const int AttrFieldNumber = 5;
+ private static readonly pbc::MapField.Codec _map_attr_codec
+ = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 42);
+ private readonly pbc::MapField attr_ = new pbc::MapField();
+ ///
+ /// Attributes specific to this function definition.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::MapField Attr {
+ get { return attr_; }
+ }
+
+ /// Field number for the "node_def" field.
+ public const int NodeDefFieldNumber = 3;
+ private static readonly pb::FieldCodec _repeated_nodeDef_codec
+ = pb::FieldCodec.ForMessage(26, global::Tensorflow.NodeDef.Parser);
+ private readonly pbc::RepeatedField nodeDef_ = new pbc::RepeatedField();
+ ///
+ /// By convention, "op" in node_def is resolved by consulting with a
+ /// user-defined library first. If not resolved, "func" is assumed to
+ /// be a builtin op.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField NodeDef {
+ get { return nodeDef_; }
+ }
+
+ /// Field number for the "ret" field.
+ public const int RetFieldNumber = 4;
+ private static readonly pbc::MapField.Codec _map_ret_codec
+ = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForString(18), 34);
+ private readonly pbc::MapField ret_ = new pbc::MapField();
+ ///
+ /// A mapping from the output arg names from `signature` to the
+ /// outputs from `node_def` that should be returned by the function.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::MapField Ret {
+ get { return ret_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as FunctionDef);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(FunctionDef other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (!object.Equals(Signature, other.Signature)) return false;
+ if (!Attr.Equals(other.Attr)) return false;
+ if(!nodeDef_.Equals(other.nodeDef_)) return false;
+ if (!Ret.Equals(other.Ret)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (signature_ != null) hash ^= Signature.GetHashCode();
+ hash ^= Attr.GetHashCode();
+ hash ^= nodeDef_.GetHashCode();
+ hash ^= Ret.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (signature_ != null) {
+ output.WriteRawTag(10);
+ output.WriteMessage(Signature);
+ }
+ nodeDef_.WriteTo(output, _repeated_nodeDef_codec);
+ ret_.WriteTo(output, _map_ret_codec);
+ attr_.WriteTo(output, _map_attr_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (signature_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(Signature);
+ }
+ size += attr_.CalculateSize(_map_attr_codec);
+ size += nodeDef_.CalculateSize(_repeated_nodeDef_codec);
+ size += ret_.CalculateSize(_map_ret_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(FunctionDef other) {
+ if (other == null) {
+ return;
+ }
+ if (other.signature_ != null) {
+ if (signature_ == null) {
+ signature_ = new global::Tensorflow.OpDef();
+ }
+ Signature.MergeFrom(other.Signature);
+ }
+ attr_.Add(other.attr_);
+ nodeDef_.Add(other.nodeDef_);
+ ret_.Add(other.ret_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ if (signature_ == null) {
+ signature_ = new global::Tensorflow.OpDef();
+ }
+ input.ReadMessage(signature_);
+ break;
+ }
+ case 26: {
+ nodeDef_.AddEntriesFrom(input, _repeated_nodeDef_codec);
+ break;
+ }
+ case 34: {
+ ret_.AddEntriesFrom(input, _map_ret_codec);
+ break;
+ }
+ case 42: {
+ attr_.AddEntriesFrom(input, _map_attr_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ ///
+ /// GradientDef defines the gradient function of a function defined in
+ /// a function library.
+ ///
+ /// A gradient function g (specified by gradient_func) for a function f
+ /// (specified by function_name) must follow the following:
+ ///
+ /// The function 'f' must be a numerical function which takes N inputs
+ /// and produces M outputs. Its gradient function 'g', which is a
+ /// function taking N + M inputs and produces N outputs.
+ ///
+ /// I.e. if we have
+ /// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
+ /// then, g is
+ /// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
+ /// dL/dy1, dL/dy2, ..., dL/dy_M),
+ /// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
+ /// loss function). dL/dx_i is the partial derivative of L with respect
+ /// to x_i.
+ ///
+ public sealed partial class GradientDef : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GradientDef());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[2]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public GradientDef() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public GradientDef(GradientDef other) : this() {
+ functionName_ = other.functionName_;
+ gradientFunc_ = other.gradientFunc_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public GradientDef Clone() {
+ return new GradientDef(this);
+ }
+
+ /// Field number for the "function_name" field.
+ public const int FunctionNameFieldNumber = 1;
+ private string functionName_ = "";
+ ///
+ /// The function name.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string FunctionName {
+ get { return functionName_; }
+ set {
+ functionName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ /// Field number for the "gradient_func" field.
+ public const int GradientFuncFieldNumber = 2;
+ private string gradientFunc_ = "";
+ ///
+ /// The gradient function's name.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public string GradientFunc {
+ get { return gradientFunc_; }
+ set {
+ gradientFunc_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as GradientDef);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(GradientDef other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (FunctionName != other.FunctionName) return false;
+ if (GradientFunc != other.GradientFunc) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (FunctionName.Length != 0) hash ^= FunctionName.GetHashCode();
+ if (GradientFunc.Length != 0) hash ^= GradientFunc.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (FunctionName.Length != 0) {
+ output.WriteRawTag(10);
+ output.WriteString(FunctionName);
+ }
+ if (GradientFunc.Length != 0) {
+ output.WriteRawTag(18);
+ output.WriteString(GradientFunc);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (FunctionName.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(FunctionName);
+ }
+ if (GradientFunc.Length != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeStringSize(GradientFunc);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(GradientDef other) {
+ if (other == null) {
+ return;
+ }
+ if (other.FunctionName.Length != 0) {
+ FunctionName = other.FunctionName;
+ }
+ if (other.GradientFunc.Length != 0) {
+ GradientFunc = other.GradientFunc;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ FunctionName = input.ReadString();
+ break;
+ }
+ case 18: {
+ GradientFunc = input.ReadString();
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/src/TensorFlowNET.Core/Protobuf/Graph.cs b/src/TensorFlowNET.Core/Protobuf/Graph.cs
new file mode 100644
index 00000000..3dce73f1
--- /dev/null
+++ b/src/TensorFlowNET.Core/Protobuf/Graph.cs
@@ -0,0 +1,309 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: graph.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace Tensorflow {
+
+ /// Holder for reflection information generated from graph.proto
+ public static partial class GraphReflection {
+
+ #region Descriptor
+ /// File descriptor for graph.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static GraphReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CgtncmFwaC5wcm90bxIKdGVuc29yZmxvdxoObm9kZV9kZWYucHJvdG8aDmZ1",
+ "bmN0aW9uLnByb3RvGg52ZXJzaW9ucy5wcm90byKdAQoIR3JhcGhEZWYSIQoE",
+ "bm9kZRgBIAMoCzITLnRlbnNvcmZsb3cuTm9kZURlZhIoCgh2ZXJzaW9ucxgE",
+ "IAEoCzIWLnRlbnNvcmZsb3cuVmVyc2lvbkRlZhITCgd2ZXJzaW9uGAMgASgF",
+ "QgIYARIvCgdsaWJyYXJ5GAIgASgLMh4udGVuc29yZmxvdy5GdW5jdGlvbkRl",
+ "ZkxpYnJhcnlCawoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQgtHcmFwaFBy",
+ "b3Rvc1ABWj1naXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5z",
+ "b3JmbG93L2dvL2NvcmUvZnJhbWV3b3Jr+AEBYgZwcm90bzM="));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.FunctionReflection.Descriptor, global::Tensorflow.VersionsReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphDef), global::Tensorflow.GraphDef.Parser, new[]{ "Node", "Versions", "Version", "Library" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ ///
+ /// Represents the graph of operations
+ ///
+ public sealed partial class GraphDef : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphDef());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Tensorflow.GraphReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public GraphDef() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public GraphDef(GraphDef other) : this() {
+ node_ = other.node_.Clone();
+ versions_ = other.versions_ != null ? other.versions_.Clone() : null;
+ version_ = other.version_;
+ library_ = other.library_ != null ? other.library_.Clone() : null;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public GraphDef Clone() {
+ return new GraphDef(this);
+ }
+
+ /// Field number for the "node" field.
+ public const int NodeFieldNumber = 1;
+ private static readonly pb::FieldCodec _repeated_node_codec
+ = pb::FieldCodec.ForMessage(10, global::Tensorflow.NodeDef.Parser);
+ private readonly pbc::RepeatedField node_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField Node {
+ get { return node_; }
+ }
+
+ /// Field number for the "versions" field.
+ public const int VersionsFieldNumber = 4;
+ private global::Tensorflow.VersionDef versions_;
+ ///
+ /// Compatibility versions of the graph. See core/public/version.h for version
+ /// history. The GraphDef version is distinct from the TensorFlow version, and
+ /// each release of TensorFlow will support a range of GraphDef versions.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::Tensorflow.VersionDef Versions {
+ get { return versions_; }
+ set {
+ versions_ = value;
+ }
+ }
+
+ /// Field number for the "version" field.
+ public const int VersionFieldNumber = 3;
+ private int version_;
+ ///
+ /// Deprecated single version field; use versions above instead. Since all
+ /// GraphDef changes before "versions" was introduced were forward
+ /// compatible, this field is entirely ignored.
+ ///
+ [global::System.ObsoleteAttribute]
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int Version {
+ get { return version_; }
+ set {
+ version_ = value;
+ }
+ }
+
+ /// Field number for the "library" field.
+ public const int LibraryFieldNumber = 2;
+ private global::Tensorflow.FunctionDefLibrary library_;
+ ///
+ /// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
+ ///
+ /// "library" provides user-defined functions.
+ ///
+ /// Naming:
+ /// * library.function.name are in a flat namespace.
+ /// NOTE: We may need to change it to be hierarchical to support
+ /// different orgs. E.g.,
+ /// { "/google/nn", { ... }},
+ /// { "/google/vision", { ... }}
+ /// { "/org_foo/module_bar", { ... }}
+ /// map<string, FunctionDefLib> named_lib;
+ /// * If node[i].op is the name of one function in "library",
+ /// node[i] is deemed as a function call. Otherwise, node[i].op
+ /// must be a primitive operation supported by the runtime.
+ ///
+ /// Function call semantics:
+ ///
+ /// * The callee may start execution as soon as some of its inputs
+ /// are ready. The caller may want to use Tuple() mechanism to
+ /// ensure all inputs are ready in the same time.
+ ///
+ /// * The consumer of return values may start executing as soon as
+ /// the return values the consumer depends on are ready. The
+ /// consumer may want to use Tuple() mechanism to ensure the
+ /// consumer does not start until all return values of the callee
+ /// function are ready.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::Tensorflow.FunctionDefLibrary Library {
+ get { return library_; }
+ set {
+ library_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as GraphDef);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(GraphDef other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if(!node_.Equals(other.node_)) return false;
+ if (!object.Equals(Versions, other.Versions)) return false;
+ if (Version != other.Version) return false;
+ if (!object.Equals(Library, other.Library)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ hash ^= node_.GetHashCode();
+ if (versions_ != null) hash ^= Versions.GetHashCode();
+ if (Version != 0) hash ^= Version.GetHashCode();
+ if (library_ != null) hash ^= Library.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ node_.WriteTo(output, _repeated_node_codec);
+ if (library_ != null) {
+ output.WriteRawTag(18);
+ output.WriteMessage(Library);
+ }
+ if (Version != 0) {
+ output.WriteRawTag(24);
+ output.WriteInt32(Version);
+ }
+ if (versions_ != null) {
+ output.WriteRawTag(34);
+ output.WriteMessage(Versions);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ size += node_.CalculateSize(_repeated_node_codec);
+ if (versions_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(Versions);
+ }
+ if (Version != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(Version);
+ }
+ if (library_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(Library);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(GraphDef other) {
+ if (other == null) {
+ return;
+ }
+ node_.Add(other.node_);
+ if (other.versions_ != null) {
+ if (versions_ == null) {
+ versions_ = new global::Tensorflow.VersionDef();
+ }
+ Versions.MergeFrom(other.Versions);
+ }
+ if (other.Version != 0) {
+ Version = other.Version;
+ }
+ if (other.library_ != null) {
+ if (library_ == null) {
+ library_ = new global::Tensorflow.FunctionDefLibrary();
+ }
+ Library.MergeFrom(other.Library);
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ node_.AddEntriesFrom(input, _repeated_node_codec);
+ break;
+ }
+ case 18: {
+ if (library_ == null) {
+ library_ = new global::Tensorflow.FunctionDefLibrary();
+ }
+ input.ReadMessage(library_);
+ break;
+ }
+ case 24: {
+ Version = input.ReadInt32();
+ break;
+ }
+ case 34: {
+ if (versions_ == null) {
+ versions_ = new global::Tensorflow.VersionDef();
+ }
+ input.ReadMessage(versions_);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/src/TensorFlowNET.Core/Protobuf/README.md b/src/TensorFlowNET.Core/Protobuf/README.md
index 4b4cc3d3..c3c34cbe 100644
--- a/src/TensorFlowNET.Core/Protobuf/README.md
+++ b/src/TensorFlowNET.Core/Protobuf/README.md
@@ -1,12 +1,15 @@
### Download compiler from https://github.com/protocolbuffers/protobuf/releases
```shell
-set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\framework
-set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Tensorflow
+set SRC_DIR=D:\Projects\tensorflow-1.12.0\tensorflow\core\framework
+set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf
-.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto
-.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto
-.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto
-.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto
-.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto
-.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto
+protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto
+protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto
+protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto
+protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto
+protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto
+protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto
+protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% versions.proto
+protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% function.proto
+protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% graph.proto
```
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Protobuf/Versions.cs b/src/TensorFlowNET.Core/Protobuf/Versions.cs
new file mode 100644
index 00000000..6e97f1f7
--- /dev/null
+++ b/src/TensorFlowNET.Core/Protobuf/Versions.cs
@@ -0,0 +1,247 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: versions.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace Tensorflow {
+
+ /// Holder for reflection information generated from versions.proto
+ public static partial class VersionsReflection {
+
+ #region Descriptor
+ /// File descriptor for versions.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static VersionsReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "Cg52ZXJzaW9ucy5wcm90bxIKdGVuc29yZmxvdyJLCgpWZXJzaW9uRGVmEhAK",
+ "CHByb2R1Y2VyGAEgASgFEhQKDG1pbl9jb25zdW1lchgCIAEoBRIVCg1iYWRf",
+ "Y29uc3VtZXJzGAMgAygFQm4KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IO",
+ "VmVyc2lvbnNQcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNv",
+ "cmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z"));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VersionDef), global::Tensorflow.VersionDef.Parser, new[]{ "Producer", "MinConsumer", "BadConsumers" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ ///
+ /// Version information for a piece of serialized data
+ ///
+ /// There are different types of versions for each type of data
+ /// (GraphDef, etc.), but they all have the same common shape
+ /// described here.
+ ///
+ /// Each consumer has "consumer" and "min_producer" versions (specified
+ /// elsewhere). A consumer is allowed to consume this data if
+ ///
+ /// producer >= min_producer
+ /// consumer >= min_consumer
+ /// consumer not in bad_consumers
+ ///
+ public sealed partial class VersionDef : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new VersionDef());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Tensorflow.VersionsReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public VersionDef() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public VersionDef(VersionDef other) : this() {
+ producer_ = other.producer_;
+ minConsumer_ = other.minConsumer_;
+ badConsumers_ = other.badConsumers_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public VersionDef Clone() {
+ return new VersionDef(this);
+ }
+
+ /// Field number for the "producer" field.
+ public const int ProducerFieldNumber = 1;
+ private int producer_;
+ ///
+ /// The version of the code that produced this data.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int Producer {
+ get { return producer_; }
+ set {
+ producer_ = value;
+ }
+ }
+
+ /// Field number for the "min_consumer" field.
+ public const int MinConsumerFieldNumber = 2;
+ private int minConsumer_;
+ ///
+ /// Any consumer below this version is not allowed to consume this data.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int MinConsumer {
+ get { return minConsumer_; }
+ set {
+ minConsumer_ = value;
+ }
+ }
+
+ /// Field number for the "bad_consumers" field.
+ public const int BadConsumersFieldNumber = 3;
+ private static readonly pb::FieldCodec _repeated_badConsumers_codec
+ = pb::FieldCodec.ForInt32(26);
+ private readonly pbc::RepeatedField badConsumers_ = new pbc::RepeatedField();
+ ///
+ /// Specific consumer versions which are disallowed (e.g. due to bugs).
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField BadConsumers {
+ get { return badConsumers_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as VersionDef);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(VersionDef other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (Producer != other.Producer) return false;
+ if (MinConsumer != other.MinConsumer) return false;
+ if(!badConsumers_.Equals(other.badConsumers_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (Producer != 0) hash ^= Producer.GetHashCode();
+ if (MinConsumer != 0) hash ^= MinConsumer.GetHashCode();
+ hash ^= badConsumers_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (Producer != 0) {
+ output.WriteRawTag(8);
+ output.WriteInt32(Producer);
+ }
+ if (MinConsumer != 0) {
+ output.WriteRawTag(16);
+ output.WriteInt32(MinConsumer);
+ }
+ badConsumers_.WriteTo(output, _repeated_badConsumers_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (Producer != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(Producer);
+ }
+ if (MinConsumer != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinConsumer);
+ }
+ size += badConsumers_.CalculateSize(_repeated_badConsumers_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(VersionDef other) {
+ if (other == null) {
+ return;
+ }
+ if (other.Producer != 0) {
+ Producer = other.Producer;
+ }
+ if (other.MinConsumer != 0) {
+ MinConsumer = other.MinConsumer;
+ }
+ badConsumers_.Add(other.badConsumers_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 8: {
+ Producer = input.ReadInt32();
+ break;
+ }
+ case 16: {
+ MinConsumer = input.ReadInt32();
+ break;
+ }
+ case 26:
+ case 24: {
+ badConsumers_.AddEntriesFrom(input, _repeated_badConsumers_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs
index b6de4639..0e0316a1 100644
--- a/src/TensorFlowNET.Core/c_api.cs
+++ b/src/TensorFlowNET.Core/c_api.cs
@@ -17,7 +17,7 @@ namespace Tensorflow
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph)
/// struct => struct (TF_Output output) => (TF_Output output)
/// struct* => struct (TF_Output* output) => (TF_Output[] output)
- /// struct* => ref IntPtr (TF_Input* consumers) => (ref IntPtr handle), if output is struct[]
+ /// struct* => struct* for ref
/// const char* => string
/// int32_t => int
/// int64_t* => long[]
diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs
index 63b25f59..3b2fd37c 100644
--- a/test/TensorFlowNET.UnitTest/GraphTest.cs
+++ b/test/TensorFlowNET.UnitTest/GraphTest.cs
@@ -83,6 +83,42 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(1, feed_port.Length);
Assert.AreEqual(add, feed_port[0].oper);
Assert.AreEqual(0, feed_port[0].index);
+
+ // The scalar const oper also has a consumer.
+ Assert.AreEqual(1, three.OutputNumConsumers(0));
+ TF_Input[] three_port = three.OutputConsumers(0, 1);
+ Assert.AreEqual(add, three_port[0].oper);
+ Assert.AreEqual(1, three_port[0].index);
+
+ // Serialize to GraphDef.
+ var graph_def = c_test_util.GetGraphDef(graph);
+
+ // Validate GraphDef is what we expect.
+ bool found_placeholder = false;
+ bool found_scalar_const = false;
+ bool found_add = false;
+ foreach (var n in graph_def.Node)
+ {
+ if (c_test_util.IsPlaceholder(n))
+ {
+ Assert.IsFalse(found_placeholder);
+ found_placeholder = true;
+ }
+ /*else if (IsScalarConst(n, 3))
+ {
+ Assert.IsFalse(found_scalar_const);
+ found_scalar_const = true;
+ }
+ else if (IsAddN(n, 2))
+ {
+ Assert.IsFalse(found_add);
+ found_add = true;
+ }
+ else
+ {
+ ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n);
+ }*/
+ }
}
}
}
diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs
index c45a146e..349d82b3 100644
--- a/test/TensorFlowNET.UnitTest/OperationsTest.cs
+++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs
@@ -19,7 +19,7 @@ namespace TensorFlowNET.UnitTest
{
var handle = c_api.TF_GetAllOpList();
var buffer = new Buffer(handle);
- Assert.IsTrue(buffer.Length == buffer.Data.Length);
+ Assert.IsTrue(buffer.Length == buffer.Length);
}
[TestMethod]
diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs
index f62433ae..079ecee5 100644
--- a/test/TensorFlowNET.UnitTest/c_test_util.cs
+++ b/test/TensorFlowNET.UnitTest/c_test_util.cs
@@ -39,11 +39,20 @@ namespace TensorFlowNET.UnitTest
{
var buffer = new Buffer();
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
- attr_value = AttrValue.Parser.ParseFrom(buffer.Data);
+ attr_value = AttrValue.Parser.ParseFrom(buffer);
buffer.Dispose();
return s.Code == TF_Code.TF_OK;
}
+ public static GraphDef GetGraphDef(Graph graph)
+ {
+ var s = new Status();
+ var buffer = new Buffer();
+ c_api.TF_GraphToGraphDef(graph, buffer, s);
+ s.Check();
+ return GraphDef.Parser.ParseFrom(buffer);
+ }
+
public static bool GetNodeDef(Operation oper, ref NodeDef node_def)
{
var s = new Status();
@@ -53,6 +62,37 @@ namespace TensorFlowNET.UnitTest
return s.Code == TF_Code.TF_OK;
}
+ public static bool IsPlaceholder(NodeDef node_def)
+ {
+ if (node_def.Op != "Placeholder" || node_def.Name != "feed")
+ {
+ return false;
+ }
+
+ bool found_dtype = false;
+ bool found_shape = false;
+ foreach (var attr in node_def.Attr)
+ {
+ if (attr.Key == "dtype")
+ {
+ if (attr.Value.Type == DataType.DtInt32)
+ {
+ found_dtype = true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ else if (attr.Key == "shape")
+ {
+ found_shape = true;
+ }
+ }
+
+ return found_dtype && found_shape;
+ }
+
public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op)
{
var desc = c_api.TF_NewOperation(graph, "Placeholder", name);