diff --git a/README.md b/README.md index 4d7d81d2..8fa2127d 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ TensorFlow.NET is a member project of SciSharp stack. ### How to use ```cs -using tf = TensorFlowNET.Core.Tensorflow; +using TensorFlowNET.Core; namespace TensorFlowNET.Examples { diff --git a/src/TensorFlowNET.Core/Buffer.cs b/src/TensorFlowNET.Core/Buffer.cs new file mode 100644 index 00000000..26965dbb --- /dev/null +++ b/src/TensorFlowNET.Core/Buffer.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.Core +{ + public class Buffer + { + private IntPtr _handle; + public IntPtr Handle => _handle; + //public TF_Buffer buffer => Marshal.PtrToStructure(_handle); + + public unsafe Buffer() + { + _handle = Marshal.AllocHGlobal(sizeof(TF_Buffer)); + } + + public byte[] GetBuffer() + { + var buffer = Marshal.PtrToStructure(_handle); + + var data = Marshal.AllocHGlobal(buffer.length); + //var bytes = c_api.TF_GetBuffer(buffer.data); + + return null; + } + } +} diff --git a/src/TensorFlowNET.Core/Graph.cs b/src/TensorFlowNET.Core/Graph.cs index 216db8f6..013ce8a2 100644 --- a/src/TensorFlowNET.Core/Graph.cs +++ b/src/TensorFlowNET.Core/Graph.cs @@ -32,7 +32,9 @@ namespace TensorFlowNET.Core _names_in_use = new Dictionary(); } - public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, Dictionary attrs = null, string name = "Const") + public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes, + TF_DataType[] input_types = null, string name = "", + Dictionary attrs = null, OpDef op_def = null) { if (String.IsNullOrEmpty(name)) { diff --git a/src/TensorFlowNET.Core/OpDefLibrary.cs b/src/TensorFlowNET.Core/OpDefLibrary.cs new file mode 100644 index 00000000..192d6ef1 --- /dev/null +++ b/src/TensorFlowNET.Core/OpDefLibrary.cs @@ -0,0 +1,95 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Runtime.InteropServices; +using System.Text; +using Tensorflow; +using static Tensorflow.OpDef.Types; + +namespace TensorFlowNET.Core +{ + public class OpDefLibrary + { + public Dictionary _ops = new Dictionary(); + + public void add_op_list(OpList op_list) + { + foreach(var op_def in op_list.Op) + { + add_op(op_def); + } + } + + public void add_op(OpDef op_def) + { + _ops[op_def.Name] = op_def; + } + + public unsafe Operation _apply_op_helper(string op_type_name, string name = "", DataType? dtype = null, TensorShape shape = null) + { + var op_def = _ops[op_type_name]; + + var status = new Status(); + var buffer = new Buffer(); + + var g = ops.get_default_graph(); + + if (String.IsNullOrEmpty(name)) + { + name = op_type_name; + } + + foreach(var attr_def in op_def.Attr) + { + if (attr_def.Type != "type") continue; + var key = attr_def.Name; + } + + foreach(var input_arg in op_def.InputArg) + { + + } + + var attr_protos = new Dictionary(); + foreach (var attr_def in op_def.Attr) + { + var key = attr_def.Name; + var attr_value = new AttrValue(); + + switch (attr_def.Type) + { + case "type": + attr_value.Type = dtype.Value; + break; + case "shape": + attr_value.Shape = new TensorShapeProto(); + break; + } + + attr_protos[key] = attr_value; + } + + var output_types = new List(); + + foreach (var arg in op_def.OutputArg) + { + if (!String.IsNullOrEmpty(arg.NumberAttr)) + { + + } + else if (!String.IsNullOrEmpty(arg.TypeAttr)) + { + output_types.Add(attr_protos[arg.TypeAttr].Type); + } + } + + var op = g.create_op(op_type_name, null, output_types.ToArray(), + name: "Placeholder_1/", + input_types: new DataType[] { }, + attrs: null, + op_def: null); + + return op; + } + } +} diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 3b550231..a5b40af6 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -25,6 +25,9 @@ PreserveNewest + + PreserveNewest + diff --git a/src/TensorFlowNET.Core/Tensorflow/OpDef.cs b/src/TensorFlowNET.Core/Tensorflow/OpDef.cs new file mode 100644 index 00000000..737e97e5 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensorflow/OpDef.cs @@ -0,0 +1,1485 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: op_def.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from op_def.proto + public static partial class OpDefReflection { + + #region Descriptor + /// File descriptor for op_def.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static OpDefReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CgxvcF9kZWYucHJvdG8SCnRlbnNvcmZsb3caEGF0dHJfdmFsdWUucHJvdG8a", + "C3R5cGVzLnByb3RvIrgFCgVPcERlZhIMCgRuYW1lGAEgASgJEisKCWlucHV0", + "X2FyZxgCIAMoCzIYLnRlbnNvcmZsb3cuT3BEZWYuQXJnRGVmEiwKCm91dHB1", + "dF9hcmcYAyADKAsyGC50ZW5zb3JmbG93Lk9wRGVmLkFyZ0RlZhInCgRhdHRy", + "GAQgAygLMhkudGVuc29yZmxvdy5PcERlZi5BdHRyRGVmEi4KC2RlcHJlY2F0", + "aW9uGAggASgLMhkudGVuc29yZmxvdy5PcERlcHJlY2F0aW9uEg8KB3N1bW1h", + "cnkYBSABKAkSEwoLZGVzY3JpcHRpb24YBiABKAkSFgoOaXNfY29tbXV0YXRp", + "dmUYEiABKAgSFAoMaXNfYWdncmVnYXRlGBAgASgIEhMKC2lzX3N0YXRlZnVs", + "GBEgASgIEiIKGmFsbG93c191bmluaXRpYWxpemVkX2lucHV0GBMgASgIGp8B", + "CgZBcmdEZWYSDAoEbmFtZRgBIAEoCRITCgtkZXNjcmlwdGlvbhgCIAEoCRIi", + "CgR0eXBlGAMgASgOMhQudGVuc29yZmxvdy5EYXRhVHlwZRIRCgl0eXBlX2F0", + "dHIYBCABKAkSEwoLbnVtYmVyX2F0dHIYBSABKAkSFgoOdHlwZV9saXN0X2F0", + "dHIYBiABKAkSDgoGaXNfcmVmGBAgASgIGr0BCgdBdHRyRGVmEgwKBG5hbWUY", + "ASABKAkSDAoEdHlwZRgCIAEoCRIsCg1kZWZhdWx0X3ZhbHVlGAMgASgLMhUu", + "dGVuc29yZmxvdy5BdHRyVmFsdWUSEwoLZGVzY3JpcHRpb24YBCABKAkSEwoL", + "aGFzX21pbmltdW0YBSABKAgSDwoHbWluaW11bRgGIAEoAxItCg5hbGxvd2Vk", + "X3ZhbHVlcxgHIAEoCzIVLnRlbnNvcmZsb3cuQXR0clZhbHVlIjUKDU9wRGVw", + "cmVjYXRpb24SDwoHdmVyc2lvbhgBIAEoBRITCgtleHBsYW5hdGlvbhgCIAEo", + "CSInCgZPcExpc3QSHQoCb3AYASADKAsyES50ZW5zb3JmbG93Lk9wRGVmQmsK", + "GG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0ILT3BEZWZQcm90b3NQAVo9Z2l0", + "aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9j", + "b3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef), global::Tensorflow.OpDef.Parser, new[]{ "Name", "InputArg", "OutputArg", "Attr", "Deprecation", "Summary", "Description", "IsCommutative", "IsAggregate", "IsStateful", "AllowsUninitializedInput" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef.Types.ArgDef), global::Tensorflow.OpDef.Types.ArgDef.Parser, new[]{ "Name", "Description", "Type", "TypeAttr", "NumberAttr", "TypeListAttr", "IsRef" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef.Types.AttrDef), global::Tensorflow.OpDef.Types.AttrDef.Parser, new[]{ "Name", "Type", "DefaultValue", "Description", "HasMinimum", "Minimum", "AllowedValues" }, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDeprecation), global::Tensorflow.OpDeprecation.Parser, new[]{ "Version", "Explanation" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpList), global::Tensorflow.OpList.Parser, new[]{ "Op" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Defines an operation. A NodeDef in a GraphDef specifies an Op by + /// using the "op" field which should match the name of a OpDef. + /// LINT.IfChange + /// + public sealed partial class OpDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OpDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.OpDefReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public OpDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public OpDef(OpDef other) : this() { + name_ = other.name_; + inputArg_ = other.inputArg_.Clone(); + outputArg_ = other.outputArg_.Clone(); + attr_ = other.attr_.Clone(); + deprecation_ = other.deprecation_ != null ? other.deprecation_.Clone() : null; + summary_ = other.summary_; + description_ = other.description_; + isCommutative_ = other.isCommutative_; + isAggregate_ = other.isAggregate_; + isStateful_ = other.isStateful_; + allowsUninitializedInput_ = other.allowsUninitializedInput_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public OpDef Clone() { + return new OpDef(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// Op names starting with an underscore are reserved for internal use. + /// Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "input_arg" field. + public const int InputArgFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_inputArg_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.OpDef.Types.ArgDef.Parser); + private readonly pbc::RepeatedField inputArg_ = new pbc::RepeatedField(); + /// + /// Description of the input(s). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField InputArg { + get { return inputArg_; } + } + + /// Field number for the "output_arg" field. + public const int OutputArgFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_outputArg_codec + = pb::FieldCodec.ForMessage(26, global::Tensorflow.OpDef.Types.ArgDef.Parser); + private readonly pbc::RepeatedField outputArg_ = new pbc::RepeatedField(); + /// + /// Description of the output(s). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField OutputArg { + get { return outputArg_; } + } + + /// Field number for the "attr" field. + public const int AttrFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_attr_codec + = pb::FieldCodec.ForMessage(34, global::Tensorflow.OpDef.Types.AttrDef.Parser); + private readonly pbc::RepeatedField attr_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Attr { + get { return attr_; } + } + + /// Field number for the "deprecation" field. + public const int DeprecationFieldNumber = 8; + private global::Tensorflow.OpDeprecation deprecation_; + /// + /// Optional deprecation based on GraphDef versions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.OpDeprecation Deprecation { + get { return deprecation_; } + set { + deprecation_ = value; + } + } + + /// Field number for the "summary" field. + public const int SummaryFieldNumber = 5; + private string summary_ = ""; + /// + /// One-line human-readable description of what the Op does. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Summary { + get { return summary_; } + set { + summary_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "description" field. + public const int DescriptionFieldNumber = 6; + private string description_ = ""; + /// + /// Additional, longer human-readable description of what the Op does. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Description { + get { return description_; } + set { + description_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "is_commutative" field. + public const int IsCommutativeFieldNumber = 18; + private bool isCommutative_; + /// + /// True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool IsCommutative { + get { return isCommutative_; } + set { + isCommutative_ = value; + } + } + + /// Field number for the "is_aggregate" field. + public const int IsAggregateFieldNumber = 16; + private bool isAggregate_; + /// + /// If is_aggregate is true, then this operation accepts N >= 2 + /// inputs and produces 1 output all of the same type. Should be + /// associative and commutative, and produce output with the same + /// shape as the input. The optimizer may replace an aggregate op + /// taking input from multiple devices with a tree of aggregate ops + /// that aggregate locally within each device (and possibly within + /// groups of nearby devices) before communicating. + /// TODO(josh11b): Implement that optimization. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool IsAggregate { + get { return isAggregate_; } + set { + isAggregate_ = value; + } + } + + /// Field number for the "is_stateful" field. + public const int IsStatefulFieldNumber = 17; + private bool isStateful_; + /// + /// Ops are marked as stateful if their behavior depends on some state beyond + /// their input tensors (e.g. variable reading op) or if they have + /// a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops + /// must always produce the same output for the same input and have + /// no side-effects. + /// + /// By default Ops may be moved between devices. Stateful ops should + /// either not be moved, or should only be moved if that state can also + /// be moved (e.g. via some sort of save / restore). + /// Stateful ops are guaranteed to never be optimized away by Common + /// Subexpression Elimination (CSE). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool IsStateful { + get { return isStateful_; } + set { + isStateful_ = value; + } + } + + /// Field number for the "allows_uninitialized_input" field. + public const int AllowsUninitializedInputFieldNumber = 19; + private bool allowsUninitializedInput_; + /// + /// By default, all inputs to an Op must be initialized Tensors. Ops + /// that may initialize tensors for the first time should set this + /// field to true, to allow the Op to take an uninitialized Tensor as + /// input. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool AllowsUninitializedInput { + get { return allowsUninitializedInput_; } + set { + allowsUninitializedInput_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as OpDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(OpDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if(!inputArg_.Equals(other.inputArg_)) return false; + if(!outputArg_.Equals(other.outputArg_)) return false; + if(!attr_.Equals(other.attr_)) return false; + if (!object.Equals(Deprecation, other.Deprecation)) return false; + if (Summary != other.Summary) return false; + if (Description != other.Description) return false; + if (IsCommutative != other.IsCommutative) return false; + if (IsAggregate != other.IsAggregate) return false; + if (IsStateful != other.IsStateful) return false; + if (AllowsUninitializedInput != other.AllowsUninitializedInput) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + hash ^= inputArg_.GetHashCode(); + hash ^= outputArg_.GetHashCode(); + hash ^= attr_.GetHashCode(); + if (deprecation_ != null) hash ^= Deprecation.GetHashCode(); + if (Summary.Length != 0) hash ^= Summary.GetHashCode(); + if (Description.Length != 0) hash ^= Description.GetHashCode(); + if (IsCommutative != false) hash ^= IsCommutative.GetHashCode(); + if (IsAggregate != false) hash ^= IsAggregate.GetHashCode(); + if (IsStateful != false) hash ^= IsStateful.GetHashCode(); + if (AllowsUninitializedInput != false) hash ^= AllowsUninitializedInput.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + inputArg_.WriteTo(output, _repeated_inputArg_codec); + outputArg_.WriteTo(output, _repeated_outputArg_codec); + attr_.WriteTo(output, _repeated_attr_codec); + if (Summary.Length != 0) { + output.WriteRawTag(42); + output.WriteString(Summary); + } + if (Description.Length != 0) { + output.WriteRawTag(50); + output.WriteString(Description); + } + if (deprecation_ != null) { + output.WriteRawTag(66); + output.WriteMessage(Deprecation); + } + if (IsAggregate != false) { + output.WriteRawTag(128, 1); + output.WriteBool(IsAggregate); + } + if (IsStateful != false) { + output.WriteRawTag(136, 1); + output.WriteBool(IsStateful); + } + if (IsCommutative != false) { + output.WriteRawTag(144, 1); + output.WriteBool(IsCommutative); + } + if (AllowsUninitializedInput != false) { + output.WriteRawTag(152, 1); + output.WriteBool(AllowsUninitializedInput); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + size += inputArg_.CalculateSize(_repeated_inputArg_codec); + size += outputArg_.CalculateSize(_repeated_outputArg_codec); + size += attr_.CalculateSize(_repeated_attr_codec); + if (deprecation_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Deprecation); + } + if (Summary.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Summary); + } + if (Description.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Description); + } + if (IsCommutative != false) { + size += 2 + 1; + } + if (IsAggregate != false) { + size += 2 + 1; + } + if (IsStateful != false) { + size += 2 + 1; + } + if (AllowsUninitializedInput != false) { + size += 2 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(OpDef other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + inputArg_.Add(other.inputArg_); + outputArg_.Add(other.outputArg_); + attr_.Add(other.attr_); + if (other.deprecation_ != null) { + if (deprecation_ == null) { + deprecation_ = new global::Tensorflow.OpDeprecation(); + } + Deprecation.MergeFrom(other.Deprecation); + } + if (other.Summary.Length != 0) { + Summary = other.Summary; + } + if (other.Description.Length != 0) { + Description = other.Description; + } + if (other.IsCommutative != false) { + IsCommutative = other.IsCommutative; + } + if (other.IsAggregate != false) { + IsAggregate = other.IsAggregate; + } + if (other.IsStateful != false) { + IsStateful = other.IsStateful; + } + if (other.AllowsUninitializedInput != false) { + AllowsUninitializedInput = other.AllowsUninitializedInput; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + inputArg_.AddEntriesFrom(input, _repeated_inputArg_codec); + break; + } + case 26: { + outputArg_.AddEntriesFrom(input, _repeated_outputArg_codec); + break; + } + case 34: { + attr_.AddEntriesFrom(input, _repeated_attr_codec); + break; + } + case 42: { + Summary = input.ReadString(); + break; + } + case 50: { + Description = input.ReadString(); + break; + } + case 66: { + if (deprecation_ == null) { + deprecation_ = new global::Tensorflow.OpDeprecation(); + } + input.ReadMessage(deprecation_); + break; + } + case 128: { + IsAggregate = input.ReadBool(); + break; + } + case 136: { + IsStateful = input.ReadBool(); + break; + } + case 144: { + IsCommutative = input.ReadBool(); + break; + } + case 152: { + AllowsUninitializedInput = input.ReadBool(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the OpDef message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + /// + /// For describing inputs and outputs. + /// + public sealed partial class ArgDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ArgDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.OpDef.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ArgDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ArgDef(ArgDef other) : this() { + name_ = other.name_; + description_ = other.description_; + type_ = other.type_; + typeAttr_ = other.typeAttr_; + numberAttr_ = other.numberAttr_; + typeListAttr_ = other.typeListAttr_; + isRef_ = other.isRef_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ArgDef Clone() { + return new ArgDef(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "description" field. + public const int DescriptionFieldNumber = 2; + private string description_ = ""; + /// + /// Human readable description. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Description { + get { return description_; } + set { + description_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "type" field. + public const int TypeFieldNumber = 3; + private global::Tensorflow.DataType type_ = 0; + /// + /// Describes the type of one or more tensors that are accepted/produced + /// by this input/output arg. The only legal combinations are: + /// * For a single tensor: either the "type" field is set or the + /// "type_attr" field is set to the name of an attr with type "type". + /// * For a sequence of tensors with the same type: the "number_attr" + /// field will be set to the name of an attr with type "int", and + /// either the "type" or "type_attr" field will be set as for + /// single tensors. + /// * For a sequence of tensors, the "type_list_attr" field will be set + /// to the name of an attr with type "list(type)". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.DataType Type { + get { return type_; } + set { + type_ = value; + } + } + + /// Field number for the "type_attr" field. + public const int TypeAttrFieldNumber = 4; + private string typeAttr_ = ""; + /// + /// if specified, attr must have type "type" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TypeAttr { + get { return typeAttr_; } + set { + typeAttr_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "number_attr" field. + public const int NumberAttrFieldNumber = 5; + private string numberAttr_ = ""; + /// + /// if specified, attr must have type "int" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string NumberAttr { + get { return numberAttr_; } + set { + numberAttr_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "type_list_attr" field. + public const int TypeListAttrFieldNumber = 6; + private string typeListAttr_ = ""; + /// + /// If specified, attr must have type "list(type)", and none of + /// type, type_attr, and number_attr may be specified. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TypeListAttr { + get { return typeListAttr_; } + set { + typeListAttr_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "is_ref" field. + public const int IsRefFieldNumber = 16; + private bool isRef_; + /// + /// For inputs: if true, the inputs are required to be refs. + /// By default, inputs can be either refs or non-refs. + /// For outputs: if true, outputs are refs, otherwise they are not. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool IsRef { + get { return isRef_; } + set { + isRef_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ArgDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ArgDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (Description != other.Description) return false; + if (Type != other.Type) return false; + if (TypeAttr != other.TypeAttr) return false; + if (NumberAttr != other.NumberAttr) return false; + if (TypeListAttr != other.TypeListAttr) return false; + if (IsRef != other.IsRef) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (Description.Length != 0) hash ^= Description.GetHashCode(); + if (Type != 0) hash ^= Type.GetHashCode(); + if (TypeAttr.Length != 0) hash ^= TypeAttr.GetHashCode(); + if (NumberAttr.Length != 0) hash ^= NumberAttr.GetHashCode(); + if (TypeListAttr.Length != 0) hash ^= TypeListAttr.GetHashCode(); + if (IsRef != false) hash ^= IsRef.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Description.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Description); + } + if (Type != 0) { + output.WriteRawTag(24); + output.WriteEnum((int) Type); + } + if (TypeAttr.Length != 0) { + output.WriteRawTag(34); + output.WriteString(TypeAttr); + } + if (NumberAttr.Length != 0) { + output.WriteRawTag(42); + output.WriteString(NumberAttr); + } + if (TypeListAttr.Length != 0) { + output.WriteRawTag(50); + output.WriteString(TypeListAttr); + } + if (IsRef != false) { + output.WriteRawTag(128, 1); + output.WriteBool(IsRef); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (Description.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Description); + } + if (Type != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Type); + } + if (TypeAttr.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TypeAttr); + } + if (NumberAttr.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(NumberAttr); + } + if (TypeListAttr.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TypeListAttr); + } + if (IsRef != false) { + size += 2 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ArgDef other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.Description.Length != 0) { + Description = other.Description; + } + if (other.Type != 0) { + Type = other.Type; + } + if (other.TypeAttr.Length != 0) { + TypeAttr = other.TypeAttr; + } + if (other.NumberAttr.Length != 0) { + NumberAttr = other.NumberAttr; + } + if (other.TypeListAttr.Length != 0) { + TypeListAttr = other.TypeListAttr; + } + if (other.IsRef != false) { + IsRef = other.IsRef; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + Description = input.ReadString(); + break; + } + case 24: { + type_ = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 34: { + TypeAttr = input.ReadString(); + break; + } + case 42: { + NumberAttr = input.ReadString(); + break; + } + case 50: { + TypeListAttr = input.ReadString(); + break; + } + case 128: { + IsRef = input.ReadBool(); + break; + } + } + } + } + + } + + /// + /// Description of the graph-construction-time configuration of this + /// Op. That is to say, this describes the attr fields that will + /// be specified in the NodeDef. + /// + public sealed partial class AttrDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AttrDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.OpDef.Descriptor.NestedTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AttrDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AttrDef(AttrDef other) : this() { + name_ = other.name_; + type_ = other.type_; + defaultValue_ = other.defaultValue_ != null ? other.defaultValue_.Clone() : null; + description_ = other.description_; + hasMinimum_ = other.hasMinimum_; + minimum_ = other.minimum_; + allowedValues_ = other.allowedValues_ != null ? other.allowedValues_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AttrDef Clone() { + return new AttrDef(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// A descriptive name for the argument. May be used, e.g. by the + /// Python client, as a keyword argument name, and so should match + /// the regexp "[a-z][a-z0-9_]+". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "type" field. + public const int TypeFieldNumber = 2; + private string type_ = ""; + /// + /// One of the type names from attr_value.proto ("string", "list(string)", + /// "int", etc.). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Type { + get { return type_; } + set { + type_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "default_value" field. + public const int DefaultValueFieldNumber = 3; + private global::Tensorflow.AttrValue defaultValue_; + /// + /// A reasonable default for this attribute if the user does not supply + /// a value. If not specified, the user must supply a value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.AttrValue DefaultValue { + get { return defaultValue_; } + set { + defaultValue_ = value; + } + } + + /// Field number for the "description" field. + public const int DescriptionFieldNumber = 4; + private string description_ = ""; + /// + /// Human-readable description. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Description { + get { return description_; } + set { + description_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "has_minimum" field. + public const int HasMinimumFieldNumber = 5; + private bool hasMinimum_; + /// + /// For type == "int", this is a minimum value. For "list(___)" + /// types, this is the minimum length. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool HasMinimum { + get { return hasMinimum_; } + set { + hasMinimum_ = value; + } + } + + /// Field number for the "minimum" field. + public const int MinimumFieldNumber = 6; + private long minimum_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public long Minimum { + get { return minimum_; } + set { + minimum_ = value; + } + } + + /// Field number for the "allowed_values" field. + public const int AllowedValuesFieldNumber = 7; + private global::Tensorflow.AttrValue allowedValues_; + /// + /// The set of allowed values. Has type that is the "list" version + /// of the "type" field above (uses the "list" field of AttrValue). + /// If type == "type" or "list(type)" above, then the "type" field + /// of "allowed_values.list" has the set of allowed DataTypes. + /// If type == "string" or "list(string)", then the "s" field of + /// "allowed_values.list" has the set of allowed strings. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.AttrValue AllowedValues { + get { return allowedValues_; } + set { + allowedValues_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as AttrDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AttrDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (Type != other.Type) return false; + if (!object.Equals(DefaultValue, other.DefaultValue)) return false; + if (Description != other.Description) return false; + if (HasMinimum != other.HasMinimum) return false; + if (Minimum != other.Minimum) return false; + if (!object.Equals(AllowedValues, other.AllowedValues)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (Type.Length != 0) hash ^= Type.GetHashCode(); + if (defaultValue_ != null) hash ^= DefaultValue.GetHashCode(); + if (Description.Length != 0) hash ^= Description.GetHashCode(); + if (HasMinimum != false) hash ^= HasMinimum.GetHashCode(); + if (Minimum != 0L) hash ^= Minimum.GetHashCode(); + if (allowedValues_ != null) hash ^= AllowedValues.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Type.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Type); + } + if (defaultValue_ != null) { + output.WriteRawTag(26); + output.WriteMessage(DefaultValue); + } + if (Description.Length != 0) { + output.WriteRawTag(34); + output.WriteString(Description); + } + if (HasMinimum != false) { + output.WriteRawTag(40); + output.WriteBool(HasMinimum); + } + if (Minimum != 0L) { + output.WriteRawTag(48); + output.WriteInt64(Minimum); + } + if (allowedValues_ != null) { + output.WriteRawTag(58); + output.WriteMessage(AllowedValues); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (Type.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Type); + } + if (defaultValue_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(DefaultValue); + } + if (Description.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Description); + } + if (HasMinimum != false) { + size += 1 + 1; + } + if (Minimum != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Minimum); + } + if (allowedValues_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(AllowedValues); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AttrDef other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.Type.Length != 0) { + Type = other.Type; + } + if (other.defaultValue_ != null) { + if (defaultValue_ == null) { + defaultValue_ = new global::Tensorflow.AttrValue(); + } + DefaultValue.MergeFrom(other.DefaultValue); + } + if (other.Description.Length != 0) { + Description = other.Description; + } + if (other.HasMinimum != false) { + HasMinimum = other.HasMinimum; + } + if (other.Minimum != 0L) { + Minimum = other.Minimum; + } + if (other.allowedValues_ != null) { + if (allowedValues_ == null) { + allowedValues_ = new global::Tensorflow.AttrValue(); + } + AllowedValues.MergeFrom(other.AllowedValues); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + Type = input.ReadString(); + break; + } + case 26: { + if (defaultValue_ == null) { + defaultValue_ = new global::Tensorflow.AttrValue(); + } + input.ReadMessage(defaultValue_); + break; + } + case 34: { + Description = input.ReadString(); + break; + } + case 40: { + HasMinimum = input.ReadBool(); + break; + } + case 48: { + Minimum = input.ReadInt64(); + break; + } + case 58: { + if (allowedValues_ == null) { + allowedValues_ = new global::Tensorflow.AttrValue(); + } + input.ReadMessage(allowedValues_); + break; + } + } + } + } + + } + + } + #endregion + + } + + /// + /// Information about version-dependent deprecation of an op + /// + public sealed partial class OpDeprecation : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OpDeprecation()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.OpDefReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public OpDeprecation() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public OpDeprecation(OpDeprecation other) : this() { + version_ = other.version_; + explanation_ = other.explanation_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public OpDeprecation Clone() { + return new OpDeprecation(this); + } + + /// Field number for the "version" field. + public const int VersionFieldNumber = 1; + private int version_; + /// + /// First GraphDef version at which the op is disallowed. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Version { + get { return version_; } + set { + version_ = value; + } + } + + /// Field number for the "explanation" field. + public const int ExplanationFieldNumber = 2; + private string explanation_ = ""; + /// + /// Explanation of why it was deprecated and what to use instead. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Explanation { + get { return explanation_; } + set { + explanation_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as OpDeprecation); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(OpDeprecation other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Version != other.Version) return false; + if (Explanation != other.Explanation) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Version != 0) hash ^= Version.GetHashCode(); + if (Explanation.Length != 0) hash ^= Explanation.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Version != 0) { + output.WriteRawTag(8); + output.WriteInt32(Version); + } + if (Explanation.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Explanation); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Version != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Version); + } + if (Explanation.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Explanation); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(OpDeprecation other) { + if (other == null) { + return; + } + if (other.Version != 0) { + Version = other.Version; + } + if (other.Explanation.Length != 0) { + Explanation = other.Explanation; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Version = input.ReadInt32(); + break; + } + case 18: { + Explanation = input.ReadString(); + break; + } + } + } + } + + } + + /// + /// A collection of OpDefs + /// + public sealed partial class OpList : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new OpList()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.OpDefReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public OpList() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public OpList(OpList other) : this() { + op_ = other.op_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public OpList Clone() { + return new OpList(this); + } + + /// Field number for the "op" field. + public const int OpFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_op_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.OpDef.Parser); + private readonly pbc::RepeatedField op_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Op { + get { return op_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as OpList); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(OpList other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!op_.Equals(other.op_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= op_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + op_.WriteTo(output, _repeated_op_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += op_.CalculateSize(_repeated_op_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(OpList other) { + if (other == null) { + return; + } + op_.Add(other.op_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + op_.AddEntriesFrom(input, _repeated_op_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Tensorflow/TF_Buffer.cs b/src/TensorFlowNET.Core/Tensorflow/TF_Buffer.cs index 13e184c6..90fc98db 100644 --- a/src/TensorFlowNET.Core/Tensorflow/TF_Buffer.cs +++ b/src/TensorFlowNET.Core/Tensorflow/TF_Buffer.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.Runtime.InteropServices; using System.Text; -using size_t = System.IntPtr; namespace Tensorflow { @@ -10,7 +9,7 @@ namespace Tensorflow public struct TF_Buffer { public IntPtr data; - public size_t length; + public int length; public IntPtr data_deallocator; } } diff --git a/src/TensorFlowNET.Core/Tensorflow/op_list_proto_bytes.bin b/src/TensorFlowNET.Core/Tensorflow/op_list_proto_bytes.bin new file mode 100644 index 00000000..62d31e67 Binary files /dev/null and b/src/TensorFlowNET.Core/Tensorflow/op_list_proto_bytes.bin differ diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index f60d2007..347d27cd 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -27,9 +27,15 @@ namespace TensorFlowNET.Core [DllImport(TensorFlowLibName)] public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status); + [DllImport(TensorFlowLibName)] + public static extern string TF_GetBuffer(IntPtr buffer); + [DllImport(TensorFlowLibName)] public static extern unsafe TF_Code TF_GetCode(TF_Status s); + [DllImport(TensorFlowLibName)] + public static extern void TF_GraphGetOpDef(TF_Graph graph, string op_name, IntPtr output_op_def, TF_Status status); + [DllImport(TensorFlowLibName)] public static extern unsafe string TF_Message(TF_Status s); diff --git a/src/TensorFlowNET.Core/gen_array_ops.cs b/src/TensorFlowNET.Core/gen_array_ops.cs index df131a8a..abd34659 100644 --- a/src/TensorFlowNET.Core/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/gen_array_ops.cs @@ -1,10 +1,31 @@ using System; using System.Collections.Generic; +using System.IO; using System.Text; +using Tensorflow; namespace TensorFlowNET.Core { public static class gen_array_ops { + public static OpDefLibrary _op_def_lib => _InitOpDefLibrary(); + + public static Tensor placeholder(DataType dtype, TensorShape shape = null) + { + var op = _op_def_lib._apply_op_helper("Placeholder", dtype: dtype, shape: shape); + + return null; + } + + private static OpDefLibrary _InitOpDefLibrary() + { + // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); + var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_bytes.bin"); + var op_list = OpList.Parser.ParseFrom(bytes); + var op_def_lib = new OpDefLibrary(); + op_def_lib.add_op_list(op_list); + + return op_def_lib; + } } } diff --git a/src/TensorFlowNET.Core/tensor_util.cs b/src/TensorFlowNET.Core/tensor_util.cs index cc5aaa87..cad0eb6c 100644 --- a/src/TensorFlowNET.Core/tensor_util.cs +++ b/src/TensorFlowNET.Core/tensor_util.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Text; using Tensorflow; -using np = NumSharp.Core.NumPy; using tensor_pb2 = Tensorflow; namespace TensorFlowNET.Core diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 43fc140c..83bbfc29 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -10,14 +10,13 @@ namespace TensorFlowNET.Core { public static class tf { - public static Type float32 = typeof(float); + public static DataType float32 = DataType.DtFloat; public delegate void Deallocator(IntPtr data, IntPtr size, IntPtr deallocatorData); - public static unsafe Tensor placeholder(Type dtype, TensorShape shape = null) + public static unsafe Tensor placeholder(DataType dtype, TensorShape shape = null) { - - return null; + return gen_array_ops.placeholder(dtype, shape); } public static unsafe Tensor constant(object value)