From 5a58965e9106f39100acba94242f0120ba785d65 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Thu, 13 Dec 2018 10:51:57 -0600 Subject: [PATCH] Import Google Protocol Buffers #3 --- src/TensorFlowNET.Core/Graph.cs | 2 + src/TensorFlowNET.Core/TF_DataType.cs | 38 - .../TensorFlowNET.Core.csproj | 4 + src/TensorFlowNET.Core/Tensorflow.cs | 3 +- .../Tensorflow/AttrValue.cs | 1027 +++++++++++++++++ src/TensorFlowNET.Core/Tensorflow/NodeDef.cs | 326 ++++++ src/TensorFlowNET.Core/Tensorflow/README.md | 12 + .../Tensorflow/ResourceHandle.cs | 311 +++++ src/TensorFlowNET.Core/Tensorflow/Tensor.cs | 801 +++++++++++++ .../Tensorflow/TensorShape.cs | 397 +++++++ src/TensorFlowNET.Core/Tensorflow/Types.cs | 153 +++ src/TensorFlowNET.Core/c_api.cs | 2 + src/TensorFlowNET.Core/ops.cs | 5 +- 13 files changed, 3040 insertions(+), 41 deletions(-) delete mode 100644 src/TensorFlowNET.Core/TF_DataType.cs create mode 100644 src/TensorFlowNET.Core/Tensorflow/AttrValue.cs create mode 100644 src/TensorFlowNET.Core/Tensorflow/NodeDef.cs create mode 100644 src/TensorFlowNET.Core/Tensorflow/README.md create mode 100644 src/TensorFlowNET.Core/Tensorflow/ResourceHandle.cs create mode 100644 src/TensorFlowNET.Core/Tensorflow/Tensor.cs create mode 100644 src/TensorFlowNET.Core/Tensorflow/TensorShape.cs create mode 100644 src/TensorFlowNET.Core/Tensorflow/Types.cs diff --git a/src/TensorFlowNET.Core/Graph.cs b/src/TensorFlowNET.Core/Graph.cs index 3518ac7a..992cb77b 100644 --- a/src/TensorFlowNET.Core/Graph.cs +++ b/src/TensorFlowNET.Core/Graph.cs @@ -4,6 +4,8 @@ using System.Linq; using System.Runtime.InteropServices; using System.Text; +using TF_DataType = Tensorflow.DataType; + namespace TensorFlowNET.Core { /// diff --git a/src/TensorFlowNET.Core/TF_DataType.cs b/src/TensorFlowNET.Core/TF_DataType.cs deleted file mode 100644 index 912fba82..00000000 --- a/src/TensorFlowNET.Core/TF_DataType.cs +++ /dev/null @@ -1,38 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace TensorFlowNET.Core -{ - /// - /// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. - /// The enum values here are identical to corresponding values in types.proto. - /// - public enum TF_DataType - { - TF_FLOAT = 1, - TF_DOUBLE = 2, - TF_INT32 = 3, // Int32 tensors are always in 'host' memory. - TF_UINT8 = 4, - TF_INT16 = 5, - TF_INT8 = 6, - TF_STRING = 7, - TF_COMPLEX64 = 8, // Single-precision complex - TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility - TF_INT64 = 9, - TF_BOOL = 10, - TF_QINT8 = 11, // Quantized int8 - TF_QUINT8 = 12, // Quantized uint8 - TF_QINT32 = 13, // Quantized int32 - TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops. - TF_QINT16 = 15, // Quantized int16 - TF_QUINT16 = 16, // Quantized uint16 - TF_UINT16 = 17, - TF_COMPLEX128 = 18, // Double-precision complex - TF_HALF = 19, - TF_RESOURCE = 20, - TF_VARIANT = 21, - TF_UINT32 = 22, - TF_UINT64 = 23, - } -} diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index c95a170c..b50fc496 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -9,6 +9,10 @@ DEBUG;TRACE + + + + diff --git a/src/TensorFlowNET.Core/Tensorflow.cs b/src/TensorFlowNET.Core/Tensorflow.cs index e83e8a3a..4a996a46 100644 --- a/src/TensorFlowNET.Core/Tensorflow.cs +++ b/src/TensorFlowNET.Core/Tensorflow.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Runtime.InteropServices; using System.Text; +using TF_DataType = Tensorflow.DataType; namespace TensorFlowNET.Core { @@ -12,7 +13,7 @@ namespace TensorFlowNET.Core public static unsafe Tensor constant(object value) { var g = ops.get_default_graph(); - g.create_op("Const", value, new TF_DataType[] { TF_DataType.TF_DOUBLE }); + g.create_op("Const", value, new TF_DataType[] { TF_DataType.DtDouble }); return new Tensor(); } diff --git a/src/TensorFlowNET.Core/Tensorflow/AttrValue.cs b/src/TensorFlowNET.Core/Tensorflow/AttrValue.cs new file mode 100644 index 00000000..158179a0 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensorflow/AttrValue.cs @@ -0,0 +1,1027 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: attr_value.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from attr_value.proto + public static partial class AttrValueReflection { + + #region Descriptor + /// File descriptor for attr_value.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static AttrValueReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "ChBhdHRyX3ZhbHVlLnByb3RvEgp0ZW5zb3JmbG93Ggx0ZW5zb3IucHJvdG8a", + "EnRlbnNvcl9zaGFwZS5wcm90bxoLdHlwZXMucHJvdG8ipgQKCUF0dHJWYWx1", + "ZRILCgFzGAIgASgMSAASCwoBaRgDIAEoA0gAEgsKAWYYBCABKAJIABILCgFi", + "GAUgASgISAASJAoEdHlwZRgGIAEoDjIULnRlbnNvcmZsb3cuRGF0YVR5cGVI", + "ABItCgVzaGFwZRgHIAEoCzIcLnRlbnNvcmZsb3cuVGVuc29yU2hhcGVQcm90", + "b0gAEikKBnRlbnNvchgIIAEoCzIXLnRlbnNvcmZsb3cuVGVuc29yUHJvdG9I", + "ABIvCgRsaXN0GAEgASgLMh8udGVuc29yZmxvdy5BdHRyVmFsdWUuTGlzdFZh", + "bHVlSAASKAoEZnVuYxgKIAEoCzIYLnRlbnNvcmZsb3cuTmFtZUF0dHJMaXN0", + "SAASFQoLcGxhY2Vob2xkZXIYCSABKAlIABrpAQoJTGlzdFZhbHVlEgkKAXMY", + "AiADKAwSDQoBaRgDIAMoA0ICEAESDQoBZhgEIAMoAkICEAESDQoBYhgFIAMo", + "CEICEAESJgoEdHlwZRgGIAMoDjIULnRlbnNvcmZsb3cuRGF0YVR5cGVCAhAB", + "EisKBXNoYXBlGAcgAygLMhwudGVuc29yZmxvdy5UZW5zb3JTaGFwZVByb3Rv", + "EicKBnRlbnNvchgIIAMoCzIXLnRlbnNvcmZsb3cuVGVuc29yUHJvdG8SJgoE", + "ZnVuYxgJIAMoCzIYLnRlbnNvcmZsb3cuTmFtZUF0dHJMaXN0QgcKBXZhbHVl", + "IpIBCgxOYW1lQXR0ckxpc3QSDAoEbmFtZRgBIAEoCRIwCgRhdHRyGAIgAygL", + "MiIudGVuc29yZmxvdy5OYW1lQXR0ckxpc3QuQXR0ckVudHJ5GkIKCUF0dHJF", + "bnRyeRILCgNrZXkYASABKAkSJAoFdmFsdWUYAiABKAsyFS50ZW5zb3JmbG93", + "LkF0dHJWYWx1ZToCOAFCbwoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQg9B", + "dHRyVmFsdWVQcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNv", + "cmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.TensorReflection.Descriptor, global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.AttrValue), global::Tensorflow.AttrValue.Parser, new[]{ "S", "I", "F", "B", "Type", "Shape", "Tensor", "List", "Func", "Placeholder" }, new[]{ "Value" }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.AttrValue.Types.ListValue), global::Tensorflow.AttrValue.Types.ListValue.Parser, new[]{ "S", "I", "F", "B", "Type", "Shape", "Tensor", "Func" }, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.NameAttrList), global::Tensorflow.NameAttrList.Parser, new[]{ "Name", "Attr" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, }) + })); + } + #endregion + + } + #region Messages + /// + /// Protocol buffer representing the value for an attr used to configure an Op. + /// Comment indicates the corresponding attr type. Only the field matching the + /// attr type may be filled. + /// + public sealed partial class AttrValue : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AttrValue()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AttrValue() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AttrValue(AttrValue other) : this() { + switch (other.ValueCase) { + case ValueOneofCase.S: + S = other.S; + break; + case ValueOneofCase.I: + I = other.I; + break; + case ValueOneofCase.F: + F = other.F; + break; + case ValueOneofCase.B: + B = other.B; + break; + case ValueOneofCase.Type: + Type = other.Type; + break; + case ValueOneofCase.Shape: + Shape = other.Shape.Clone(); + break; + case ValueOneofCase.Tensor: + Tensor = other.Tensor.Clone(); + break; + case ValueOneofCase.List: + List = other.List.Clone(); + break; + case ValueOneofCase.Func: + Func = other.Func.Clone(); + break; + case ValueOneofCase.Placeholder: + Placeholder = other.Placeholder; + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AttrValue Clone() { + return new AttrValue(this); + } + + /// Field number for the "s" field. + public const int SFieldNumber = 2; + /// + /// "string" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pb::ByteString S { + get { return valueCase_ == ValueOneofCase.S ? (pb::ByteString) value_ : pb::ByteString.Empty; } + set { + value_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + valueCase_ = ValueOneofCase.S; + } + } + + /// Field number for the "i" field. + public const int IFieldNumber = 3; + /// + /// "int" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public long I { + get { return valueCase_ == ValueOneofCase.I ? (long) value_ : 0L; } + set { + value_ = value; + valueCase_ = ValueOneofCase.I; + } + } + + /// Field number for the "f" field. + public const int FFieldNumber = 4; + /// + /// "float" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float F { + get { return valueCase_ == ValueOneofCase.F ? (float) value_ : 0F; } + set { + value_ = value; + valueCase_ = ValueOneofCase.F; + } + } + + /// Field number for the "b" field. + public const int BFieldNumber = 5; + /// + /// "bool" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool B { + get { return valueCase_ == ValueOneofCase.B ? (bool) value_ : false; } + set { + value_ = value; + valueCase_ = ValueOneofCase.B; + } + } + + /// Field number for the "type" field. + public const int TypeFieldNumber = 6; + /// + /// "type" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.DataType Type { + get { return valueCase_ == ValueOneofCase.Type ? (global::Tensorflow.DataType) value_ : 0; } + set { + value_ = value; + valueCase_ = ValueOneofCase.Type; + } + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 7; + /// + /// "shape" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.TensorShapeProto Shape { + get { return valueCase_ == ValueOneofCase.Shape ? (global::Tensorflow.TensorShapeProto) value_ : null; } + set { + value_ = value; + valueCase_ = value == null ? ValueOneofCase.None : ValueOneofCase.Shape; + } + } + + /// Field number for the "tensor" field. + public const int TensorFieldNumber = 8; + /// + /// "tensor" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.TensorProto Tensor { + get { return valueCase_ == ValueOneofCase.Tensor ? (global::Tensorflow.TensorProto) value_ : null; } + set { + value_ = value; + valueCase_ = value == null ? ValueOneofCase.None : ValueOneofCase.Tensor; + } + } + + /// Field number for the "list" field. + public const int ListFieldNumber = 1; + /// + /// any "list(...)" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.AttrValue.Types.ListValue List { + get { return valueCase_ == ValueOneofCase.List ? (global::Tensorflow.AttrValue.Types.ListValue) value_ : null; } + set { + value_ = value; + valueCase_ = value == null ? ValueOneofCase.None : ValueOneofCase.List; + } + } + + /// Field number for the "func" field. + public const int FuncFieldNumber = 10; + /// + /// "func" represents a function. func.name is a function's name or + /// a primitive op's name. func.attr.first is the name of an attr + /// defined for that function. func.attr.second is the value for + /// that attr in the instantiation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.NameAttrList Func { + get { return valueCase_ == ValueOneofCase.Func ? (global::Tensorflow.NameAttrList) value_ : null; } + set { + value_ = value; + valueCase_ = value == null ? ValueOneofCase.None : ValueOneofCase.Func; + } + } + + /// Field number for the "placeholder" field. + public const int PlaceholderFieldNumber = 9; + /// + /// This is a placeholder only used in nodes defined inside a + /// function. It indicates the attr value will be supplied when + /// the function is instantiated. For example, let us suppose a + /// node "N" in function "FN". "N" has an attr "A" with value + /// placeholder = "foo". When FN is instantiated with attr "foo" + /// set to "bar", the instantiated node N's attr A will have been + /// given the value "bar". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Placeholder { + get { return valueCase_ == ValueOneofCase.Placeholder ? (string) value_ : ""; } + set { + value_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + valueCase_ = ValueOneofCase.Placeholder; + } + } + + private object value_; + /// Enum of possible cases for the "value" oneof. + public enum ValueOneofCase { + None = 0, + S = 2, + I = 3, + F = 4, + B = 5, + Type = 6, + Shape = 7, + Tensor = 8, + List = 1, + Func = 10, + Placeholder = 9, + } + private ValueOneofCase valueCase_ = ValueOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ValueOneofCase ValueCase { + get { return valueCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearValue() { + valueCase_ = ValueOneofCase.None; + value_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as AttrValue); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AttrValue other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (S != other.S) return false; + if (I != other.I) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(F, other.F)) return false; + if (B != other.B) return false; + if (Type != other.Type) return false; + if (!object.Equals(Shape, other.Shape)) return false; + if (!object.Equals(Tensor, other.Tensor)) return false; + if (!object.Equals(List, other.List)) return false; + if (!object.Equals(Func, other.Func)) return false; + if (Placeholder != other.Placeholder) return false; + if (ValueCase != other.ValueCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (valueCase_ == ValueOneofCase.S) hash ^= S.GetHashCode(); + if (valueCase_ == ValueOneofCase.I) hash ^= I.GetHashCode(); + if (valueCase_ == ValueOneofCase.F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(F); + if (valueCase_ == ValueOneofCase.B) hash ^= B.GetHashCode(); + if (valueCase_ == ValueOneofCase.Type) hash ^= Type.GetHashCode(); + if (valueCase_ == ValueOneofCase.Shape) hash ^= Shape.GetHashCode(); + if (valueCase_ == ValueOneofCase.Tensor) hash ^= Tensor.GetHashCode(); + if (valueCase_ == ValueOneofCase.List) hash ^= List.GetHashCode(); + if (valueCase_ == ValueOneofCase.Func) hash ^= Func.GetHashCode(); + if (valueCase_ == ValueOneofCase.Placeholder) hash ^= Placeholder.GetHashCode(); + hash ^= (int) valueCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (valueCase_ == ValueOneofCase.List) { + output.WriteRawTag(10); + output.WriteMessage(List); + } + if (valueCase_ == ValueOneofCase.S) { + output.WriteRawTag(18); + output.WriteBytes(S); + } + if (valueCase_ == ValueOneofCase.I) { + output.WriteRawTag(24); + output.WriteInt64(I); + } + if (valueCase_ == ValueOneofCase.F) { + output.WriteRawTag(37); + output.WriteFloat(F); + } + if (valueCase_ == ValueOneofCase.B) { + output.WriteRawTag(40); + output.WriteBool(B); + } + if (valueCase_ == ValueOneofCase.Type) { + output.WriteRawTag(48); + output.WriteEnum((int) Type); + } + if (valueCase_ == ValueOneofCase.Shape) { + output.WriteRawTag(58); + output.WriteMessage(Shape); + } + if (valueCase_ == ValueOneofCase.Tensor) { + output.WriteRawTag(66); + output.WriteMessage(Tensor); + } + if (valueCase_ == ValueOneofCase.Placeholder) { + output.WriteRawTag(74); + output.WriteString(Placeholder); + } + if (valueCase_ == ValueOneofCase.Func) { + output.WriteRawTag(82); + output.WriteMessage(Func); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (valueCase_ == ValueOneofCase.S) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(S); + } + if (valueCase_ == ValueOneofCase.I) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(I); + } + if (valueCase_ == ValueOneofCase.F) { + size += 1 + 4; + } + if (valueCase_ == ValueOneofCase.B) { + size += 1 + 1; + } + if (valueCase_ == ValueOneofCase.Type) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Type); + } + if (valueCase_ == ValueOneofCase.Shape) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape); + } + if (valueCase_ == ValueOneofCase.Tensor) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Tensor); + } + if (valueCase_ == ValueOneofCase.List) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(List); + } + if (valueCase_ == ValueOneofCase.Func) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Func); + } + if (valueCase_ == ValueOneofCase.Placeholder) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Placeholder); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AttrValue other) { + if (other == null) { + return; + } + switch (other.ValueCase) { + case ValueOneofCase.S: + S = other.S; + break; + case ValueOneofCase.I: + I = other.I; + break; + case ValueOneofCase.F: + F = other.F; + break; + case ValueOneofCase.B: + B = other.B; + break; + case ValueOneofCase.Type: + Type = other.Type; + break; + case ValueOneofCase.Shape: + if (Shape == null) { + Shape = new global::Tensorflow.TensorShapeProto(); + } + Shape.MergeFrom(other.Shape); + break; + case ValueOneofCase.Tensor: + if (Tensor == null) { + Tensor = new global::Tensorflow.TensorProto(); + } + Tensor.MergeFrom(other.Tensor); + break; + case ValueOneofCase.List: + if (List == null) { + List = new global::Tensorflow.AttrValue.Types.ListValue(); + } + List.MergeFrom(other.List); + break; + case ValueOneofCase.Func: + if (Func == null) { + Func = new global::Tensorflow.NameAttrList(); + } + Func.MergeFrom(other.Func); + break; + case ValueOneofCase.Placeholder: + Placeholder = other.Placeholder; + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.AttrValue.Types.ListValue subBuilder = new global::Tensorflow.AttrValue.Types.ListValue(); + if (valueCase_ == ValueOneofCase.List) { + subBuilder.MergeFrom(List); + } + input.ReadMessage(subBuilder); + List = subBuilder; + break; + } + case 18: { + S = input.ReadBytes(); + break; + } + case 24: { + I = input.ReadInt64(); + break; + } + case 37: { + F = input.ReadFloat(); + break; + } + case 40: { + B = input.ReadBool(); + break; + } + case 48: { + value_ = input.ReadEnum(); + valueCase_ = ValueOneofCase.Type; + break; + } + case 58: { + global::Tensorflow.TensorShapeProto subBuilder = new global::Tensorflow.TensorShapeProto(); + if (valueCase_ == ValueOneofCase.Shape) { + subBuilder.MergeFrom(Shape); + } + input.ReadMessage(subBuilder); + Shape = subBuilder; + break; + } + case 66: { + global::Tensorflow.TensorProto subBuilder = new global::Tensorflow.TensorProto(); + if (valueCase_ == ValueOneofCase.Tensor) { + subBuilder.MergeFrom(Tensor); + } + input.ReadMessage(subBuilder); + Tensor = subBuilder; + break; + } + case 74: { + Placeholder = input.ReadString(); + break; + } + case 82: { + global::Tensorflow.NameAttrList subBuilder = new global::Tensorflow.NameAttrList(); + if (valueCase_ == ValueOneofCase.Func) { + subBuilder.MergeFrom(Func); + } + input.ReadMessage(subBuilder); + Func = subBuilder; + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the AttrValue message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + /// + /// LINT.IfChange + /// + public sealed partial class ListValue : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ListValue()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.AttrValue.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListValue() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListValue(ListValue other) : this() { + s_ = other.s_.Clone(); + i_ = other.i_.Clone(); + f_ = other.f_.Clone(); + b_ = other.b_.Clone(); + type_ = other.type_.Clone(); + shape_ = other.shape_.Clone(); + tensor_ = other.tensor_.Clone(); + func_ = other.func_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ListValue Clone() { + return new ListValue(this); + } + + /// Field number for the "s" field. + public const int SFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_s_codec + = pb::FieldCodec.ForBytes(18); + private readonly pbc::RepeatedField s_ = new pbc::RepeatedField(); + /// + /// "list(string)" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField S { + get { return s_; } + } + + /// Field number for the "i" field. + public const int IFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_i_codec + = pb::FieldCodec.ForInt64(26); + private readonly pbc::RepeatedField i_ = new pbc::RepeatedField(); + /// + /// "list(int)" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField I { + get { return i_; } + } + + /// Field number for the "f" field. + public const int FFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_f_codec + = pb::FieldCodec.ForFloat(34); + private readonly pbc::RepeatedField f_ = new pbc::RepeatedField(); + /// + /// "list(float)" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField F { + get { return f_; } + } + + /// Field number for the "b" field. + public const int BFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_b_codec + = pb::FieldCodec.ForBool(42); + private readonly pbc::RepeatedField b_ = new pbc::RepeatedField(); + /// + /// "list(bool)" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField B { + get { return b_; } + } + + /// Field number for the "type" field. + public const int TypeFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_type_codec + = pb::FieldCodec.ForEnum(50, x => (int) x, x => (global::Tensorflow.DataType) x); + private readonly pbc::RepeatedField type_ = new pbc::RepeatedField(); + /// + /// "list(type)" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Type { + get { return type_; } + } + + /// Field number for the "shape" field. + public const int ShapeFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_shape_codec + = pb::FieldCodec.ForMessage(58, global::Tensorflow.TensorShapeProto.Parser); + private readonly pbc::RepeatedField shape_ = new pbc::RepeatedField(); + /// + /// "list(shape)" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Shape { + get { return shape_; } + } + + /// Field number for the "tensor" field. + public const int TensorFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_tensor_codec + = pb::FieldCodec.ForMessage(66, global::Tensorflow.TensorProto.Parser); + private readonly pbc::RepeatedField tensor_ = new pbc::RepeatedField(); + /// + /// "list(tensor)" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Tensor { + get { return tensor_; } + } + + /// Field number for the "func" field. + public const int FuncFieldNumber = 9; + private static readonly pb::FieldCodec _repeated_func_codec + = pb::FieldCodec.ForMessage(74, global::Tensorflow.NameAttrList.Parser); + private readonly pbc::RepeatedField func_ = new pbc::RepeatedField(); + /// + /// "list(attr)" + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Func { + get { return func_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ListValue); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ListValue other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!s_.Equals(other.s_)) return false; + if(!i_.Equals(other.i_)) return false; + if(!f_.Equals(other.f_)) return false; + if(!b_.Equals(other.b_)) return false; + if(!type_.Equals(other.type_)) return false; + if(!shape_.Equals(other.shape_)) return false; + if(!tensor_.Equals(other.tensor_)) return false; + if(!func_.Equals(other.func_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= s_.GetHashCode(); + hash ^= i_.GetHashCode(); + hash ^= f_.GetHashCode(); + hash ^= b_.GetHashCode(); + hash ^= type_.GetHashCode(); + hash ^= shape_.GetHashCode(); + hash ^= tensor_.GetHashCode(); + hash ^= func_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + s_.WriteTo(output, _repeated_s_codec); + i_.WriteTo(output, _repeated_i_codec); + f_.WriteTo(output, _repeated_f_codec); + b_.WriteTo(output, _repeated_b_codec); + type_.WriteTo(output, _repeated_type_codec); + shape_.WriteTo(output, _repeated_shape_codec); + tensor_.WriteTo(output, _repeated_tensor_codec); + func_.WriteTo(output, _repeated_func_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += s_.CalculateSize(_repeated_s_codec); + size += i_.CalculateSize(_repeated_i_codec); + size += f_.CalculateSize(_repeated_f_codec); + size += b_.CalculateSize(_repeated_b_codec); + size += type_.CalculateSize(_repeated_type_codec); + size += shape_.CalculateSize(_repeated_shape_codec); + size += tensor_.CalculateSize(_repeated_tensor_codec); + size += func_.CalculateSize(_repeated_func_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ListValue other) { + if (other == null) { + return; + } + s_.Add(other.s_); + i_.Add(other.i_); + f_.Add(other.f_); + b_.Add(other.b_); + type_.Add(other.type_); + shape_.Add(other.shape_); + tensor_.Add(other.tensor_); + func_.Add(other.func_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 18: { + s_.AddEntriesFrom(input, _repeated_s_codec); + break; + } + case 26: + case 24: { + i_.AddEntriesFrom(input, _repeated_i_codec); + break; + } + case 34: + case 37: { + f_.AddEntriesFrom(input, _repeated_f_codec); + break; + } + case 42: + case 40: { + b_.AddEntriesFrom(input, _repeated_b_codec); + break; + } + case 50: + case 48: { + type_.AddEntriesFrom(input, _repeated_type_codec); + break; + } + case 58: { + shape_.AddEntriesFrom(input, _repeated_shape_codec); + break; + } + case 66: { + tensor_.AddEntriesFrom(input, _repeated_tensor_codec); + break; + } + case 74: { + func_.AddEntriesFrom(input, _repeated_func_codec); + break; + } + } + } + } + + } + + } + #endregion + + } + + /// + /// A list of attr names and their values. The whole list is attached + /// with a string name. E.g., MatMul[T=float]. + /// + public sealed partial class NameAttrList : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NameAttrList()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NameAttrList() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NameAttrList(NameAttrList other) : this() { + name_ = other.name_; + attr_ = other.attr_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NameAttrList Clone() { + return new NameAttrList(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "attr" field. + public const int AttrFieldNumber = 2; + private static readonly pbc::MapField.Codec _map_attr_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 18); + private readonly pbc::MapField attr_ = new pbc::MapField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField Attr { + get { return attr_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as NameAttrList); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(NameAttrList other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (!Attr.Equals(other.Attr)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + hash ^= Attr.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + attr_.WriteTo(output, _map_attr_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + size += attr_.CalculateSize(_map_attr_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(NameAttrList other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + attr_.Add(other.attr_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + attr_.AddEntriesFrom(input, _map_attr_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Tensorflow/NodeDef.cs b/src/TensorFlowNET.Core/Tensorflow/NodeDef.cs new file mode 100644 index 00000000..af40fd62 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensorflow/NodeDef.cs @@ -0,0 +1,326 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: node_def.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from node_def.proto + public static partial class NodeDefReflection { + + #region Descriptor + /// File descriptor for node_def.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static NodeDefReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cg5ub2RlX2RlZi5wcm90bxIKdGVuc29yZmxvdxoQYXR0cl92YWx1ZS5wcm90", + "byKzAQoHTm9kZURlZhIMCgRuYW1lGAEgASgJEgoKAm9wGAIgASgJEg0KBWlu", + "cHV0GAMgAygJEg4KBmRldmljZRgEIAEoCRIrCgRhdHRyGAUgAygLMh0udGVu", + "c29yZmxvdy5Ob2RlRGVmLkF0dHJFbnRyeRpCCglBdHRyRW50cnkSCwoDa2V5", + "GAEgASgJEiQKBXZhbHVlGAIgASgLMhUudGVuc29yZmxvdy5BdHRyVmFsdWU6", + "AjgBQmkKGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IJTm9kZVByb3RvUAFa", + "PWdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cv", + "Z28vY29yZS9mcmFtZXdvcmv4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.NodeDef), global::Tensorflow.NodeDef.Parser, new[]{ "Name", "Op", "Input", "Device", "Attr" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, }) + })); + } + #endregion + + } + #region Messages + public sealed partial class NodeDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NodeDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.NodeDefReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NodeDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NodeDef(NodeDef other) : this() { + name_ = other.name_; + op_ = other.op_; + input_ = other.input_.Clone(); + device_ = other.device_; + attr_ = other.attr_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NodeDef Clone() { + return new NodeDef(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// The name given to this operator. Used for naming inputs, + /// logging, visualization, etc. Unique within a single GraphDef. + /// Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "op" field. + public const int OpFieldNumber = 2; + private string op_ = ""; + /// + /// The operation name. There may be custom parameters in attrs. + /// Op names starting with an underscore are reserved for internal use. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Op { + get { return op_; } + set { + op_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "input" field. + public const int InputFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_input_codec + = pb::FieldCodec.ForString(26); + private readonly pbc::RepeatedField input_ = new pbc::RepeatedField(); + /// + /// Each input is "node:src_output" with "node" being a string name and + /// "src_output" indicating which output tensor to use from "node". If + /// "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + /// may optionally be followed by control inputs that have the format + /// "^node". + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Input { + get { return input_; } + } + + /// Field number for the "device" field. + public const int DeviceFieldNumber = 4; + private string device_ = ""; + /// + /// A (possibly partial) specification for the device on which this + /// node should be placed. + /// The expected syntax for this string is as follows: + /// + /// DEVICE_SPEC ::= PARTIAL_SPEC + /// + /// PARTIAL_SPEC ::= ("/" CONSTRAINT) * + /// CONSTRAINT ::= ("job:" JOB_NAME) + /// | ("replica:" [1-9][0-9]*) + /// | ("task:" [1-9][0-9]*) + /// | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) + /// + /// Valid values for this string include: + /// * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) + /// * "/job:worker/device:GPU:3" (partial specification) + /// * "" (no specification) + /// + /// If the constraints do not resolve to a single device (or if this + /// field is empty or not present), the runtime will attempt to + /// choose a device automatically. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Device { + get { return device_; } + set { + device_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "attr" field. + public const int AttrFieldNumber = 5; + private static readonly pbc::MapField.Codec _map_attr_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 42); + private readonly pbc::MapField attr_ = new pbc::MapField(); + /// + /// Operation-specific graph-construction-time configuration. + /// Note that this should include all attrs defined in the + /// corresponding OpDef, including those with a value matching + /// the default -- this allows the default to change and makes + /// NodeDefs easier to interpret on their own. However, if + /// an attr with a default is not specified in this list, the + /// default will be used. + /// The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + /// one of the names from the corresponding OpDef's attr field). + /// The values must have a type matching the corresponding OpDef + /// attr's type field. + /// TODO(josh11b): Add some examples here showing best practices. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField Attr { + get { return attr_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as NodeDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(NodeDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (Op != other.Op) return false; + if(!input_.Equals(other.input_)) return false; + if (Device != other.Device) return false; + if (!Attr.Equals(other.Attr)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (Op.Length != 0) hash ^= Op.GetHashCode(); + hash ^= input_.GetHashCode(); + if (Device.Length != 0) hash ^= Device.GetHashCode(); + hash ^= Attr.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Op.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Op); + } + input_.WriteTo(output, _repeated_input_codec); + if (Device.Length != 0) { + output.WriteRawTag(34); + output.WriteString(Device); + } + attr_.WriteTo(output, _map_attr_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (Op.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Op); + } + size += input_.CalculateSize(_repeated_input_codec); + if (Device.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Device); + } + size += attr_.CalculateSize(_map_attr_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(NodeDef other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.Op.Length != 0) { + Op = other.Op; + } + input_.Add(other.input_); + if (other.Device.Length != 0) { + Device = other.Device; + } + attr_.Add(other.attr_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + Op = input.ReadString(); + break; + } + case 26: { + input_.AddEntriesFrom(input, _repeated_input_codec); + break; + } + case 34: { + Device = input.ReadString(); + break; + } + case 42: { + attr_.AddEntriesFrom(input, _map_attr_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Tensorflow/README.md b/src/TensorFlowNET.Core/Tensorflow/README.md new file mode 100644 index 00000000..4b4cc3d3 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensorflow/README.md @@ -0,0 +1,12 @@ +### Download compiler from https://github.com/protocolbuffers/protobuf/releases +```shell +set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\framework +set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Tensorflow + +.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto +.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto +.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto +.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto +.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto +.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto +``` \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensorflow/ResourceHandle.cs b/src/TensorFlowNET.Core/Tensorflow/ResourceHandle.cs new file mode 100644 index 00000000..b9c3033c --- /dev/null +++ b/src/TensorFlowNET.Core/Tensorflow/ResourceHandle.cs @@ -0,0 +1,311 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: resource_handle.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from resource_handle.proto + public static partial class ResourceHandleReflection { + + #region Descriptor + /// File descriptor for resource_handle.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ResourceHandleReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "ChVyZXNvdXJjZV9oYW5kbGUucHJvdG8SCnRlbnNvcmZsb3cicgoTUmVzb3Vy", + "Y2VIYW5kbGVQcm90bxIOCgZkZXZpY2UYASABKAkSEQoJY29udGFpbmVyGAIg", + "ASgJEgwKBG5hbWUYAyABKAkSEQoJaGFzaF9jb2RlGAQgASgEEhcKD21heWJl", + "X3R5cGVfbmFtZRgFIAEoCUJuChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtC", + "DlJlc291cmNlSGFuZGxlUAFaPWdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5z", + "b3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9mcmFtZXdvcmv4AQFiBnByb3Rv", + "Mw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ResourceHandleProto), global::Tensorflow.ResourceHandleProto.Parser, new[]{ "Device", "Container", "Name", "HashCode", "MaybeTypeName" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Protocol buffer representing a handle to a tensorflow resource. Handles are + /// not valid across executions, but can be serialized back and forth from within + /// a single run. + /// + public sealed partial class ResourceHandleProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ResourceHandleProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ResourceHandleReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ResourceHandleProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ResourceHandleProto(ResourceHandleProto other) : this() { + device_ = other.device_; + container_ = other.container_; + name_ = other.name_; + hashCode_ = other.hashCode_; + maybeTypeName_ = other.maybeTypeName_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ResourceHandleProto Clone() { + return new ResourceHandleProto(this); + } + + /// Field number for the "device" field. + public const int DeviceFieldNumber = 1; + private string device_ = ""; + /// + /// Unique name for the device containing the resource. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Device { + get { return device_; } + set { + device_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "container" field. + public const int ContainerFieldNumber = 2; + private string container_ = ""; + /// + /// Container in which this resource is placed. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Container { + get { return container_; } + set { + container_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 3; + private string name_ = ""; + /// + /// Unique name of this resource. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "hash_code" field. + public const int HashCodeFieldNumber = 4; + private ulong hashCode_; + /// + /// Hash code for the type of the resource. Is only valid in the same device + /// and in the same execution. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ulong HashCode { + get { return hashCode_; } + set { + hashCode_ = value; + } + } + + /// Field number for the "maybe_type_name" field. + public const int MaybeTypeNameFieldNumber = 5; + private string maybeTypeName_ = ""; + /// + /// For debug-only, the name of the type pointed to by this handle, if + /// available. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string MaybeTypeName { + get { return maybeTypeName_; } + set { + maybeTypeName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ResourceHandleProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ResourceHandleProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Device != other.Device) return false; + if (Container != other.Container) return false; + if (Name != other.Name) return false; + if (HashCode != other.HashCode) return false; + if (MaybeTypeName != other.MaybeTypeName) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Device.Length != 0) hash ^= Device.GetHashCode(); + if (Container.Length != 0) hash ^= Container.GetHashCode(); + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (HashCode != 0UL) hash ^= HashCode.GetHashCode(); + if (MaybeTypeName.Length != 0) hash ^= MaybeTypeName.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Device.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Device); + } + if (Container.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Container); + } + if (Name.Length != 0) { + output.WriteRawTag(26); + output.WriteString(Name); + } + if (HashCode != 0UL) { + output.WriteRawTag(32); + output.WriteUInt64(HashCode); + } + if (MaybeTypeName.Length != 0) { + output.WriteRawTag(42); + output.WriteString(MaybeTypeName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Device.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Device); + } + if (Container.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Container); + } + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (HashCode != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(HashCode); + } + if (MaybeTypeName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MaybeTypeName); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ResourceHandleProto other) { + if (other == null) { + return; + } + if (other.Device.Length != 0) { + Device = other.Device; + } + if (other.Container.Length != 0) { + Container = other.Container; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.HashCode != 0UL) { + HashCode = other.HashCode; + } + if (other.MaybeTypeName.Length != 0) { + MaybeTypeName = other.MaybeTypeName; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Device = input.ReadString(); + break; + } + case 18: { + Container = input.ReadString(); + break; + } + case 26: { + Name = input.ReadString(); + break; + } + case 32: { + HashCode = input.ReadUInt64(); + break; + } + case 42: { + MaybeTypeName = input.ReadString(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Tensorflow/Tensor.cs b/src/TensorFlowNET.Core/Tensorflow/Tensor.cs new file mode 100644 index 00000000..1f333300 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensorflow/Tensor.cs @@ -0,0 +1,801 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensor.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensor.proto + public static partial class TensorReflection { + + #region Descriptor + /// File descriptor for tensor.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static TensorReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cgx0ZW5zb3IucHJvdG8SCnRlbnNvcmZsb3caFXJlc291cmNlX2hhbmRsZS5w", + "cm90bxoSdGVuc29yX3NoYXBlLnByb3RvGgt0eXBlcy5wcm90byKMBAoLVGVu", + "c29yUHJvdG8SIwoFZHR5cGUYASABKA4yFC50ZW5zb3JmbG93LkRhdGFUeXBl", + "EjIKDHRlbnNvcl9zaGFwZRgCIAEoCzIcLnRlbnNvcmZsb3cuVGVuc29yU2hh", + "cGVQcm90bxIWCg52ZXJzaW9uX251bWJlchgDIAEoBRIWCg50ZW5zb3JfY29u", + "dGVudBgEIAEoDBIUCghoYWxmX3ZhbBgNIAMoBUICEAESFQoJZmxvYXRfdmFs", + "GAUgAygCQgIQARIWCgpkb3VibGVfdmFsGAYgAygBQgIQARITCgdpbnRfdmFs", + "GAcgAygFQgIQARISCgpzdHJpbmdfdmFsGAggAygMEhgKDHNjb21wbGV4X3Zh", + "bBgJIAMoAkICEAESFQoJaW50NjRfdmFsGAogAygDQgIQARIUCghib29sX3Zh", + "bBgLIAMoCEICEAESGAoMZGNvbXBsZXhfdmFsGAwgAygBQgIQARI8ChNyZXNv", + "dXJjZV9oYW5kbGVfdmFsGA4gAygLMh8udGVuc29yZmxvdy5SZXNvdXJjZUhh", + "bmRsZVByb3RvEjcKC3ZhcmlhbnRfdmFsGA8gAygLMiIudGVuc29yZmxvdy5W", + "YXJpYW50VGVuc29yRGF0YVByb3RvEhYKCnVpbnQzMl92YWwYECADKA1CAhAB", + "EhYKCnVpbnQ2NF92YWwYESADKARCAhABImcKFlZhcmlhbnRUZW5zb3JEYXRh", + "UHJvdG8SEQoJdHlwZV9uYW1lGAEgASgJEhAKCG1ldGFkYXRhGAIgASgMEigK", + "B3RlbnNvcnMYAyADKAsyFy50ZW5zb3JmbG93LlRlbnNvclByb3RvQmwKGG9y", + "Zy50ZW5zb3JmbG93LmZyYW1ld29ya0IMVGVuc29yUHJvdG9zUAFaPWdpdGh1", + "Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29y", + "ZS9mcmFtZXdvcmv4AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.ResourceHandleReflection.Descriptor, global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorProto), global::Tensorflow.TensorProto.Parser, new[]{ "Dtype", "TensorShape", "VersionNumber", "TensorContent", "HalfVal", "FloatVal", "DoubleVal", "IntVal", "StringVal", "ScomplexVal", "Int64Val", "BoolVal", "DcomplexVal", "ResourceHandleVal", "VariantVal", "Uint32Val", "Uint64Val" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VariantTensorDataProto), global::Tensorflow.VariantTensorDataProto.Parser, new[]{ "TypeName", "Metadata", "Tensors" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Protocol buffer representing a tensor. + /// + public sealed partial class TensorProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TensorProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TensorReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TensorProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TensorProto(TensorProto other) : this() { + dtype_ = other.dtype_; + tensorShape_ = other.tensorShape_ != null ? other.tensorShape_.Clone() : null; + versionNumber_ = other.versionNumber_; + tensorContent_ = other.tensorContent_; + halfVal_ = other.halfVal_.Clone(); + floatVal_ = other.floatVal_.Clone(); + doubleVal_ = other.doubleVal_.Clone(); + intVal_ = other.intVal_.Clone(); + stringVal_ = other.stringVal_.Clone(); + scomplexVal_ = other.scomplexVal_.Clone(); + int64Val_ = other.int64Val_.Clone(); + boolVal_ = other.boolVal_.Clone(); + dcomplexVal_ = other.dcomplexVal_.Clone(); + resourceHandleVal_ = other.resourceHandleVal_.Clone(); + variantVal_ = other.variantVal_.Clone(); + uint32Val_ = other.uint32Val_.Clone(); + uint64Val_ = other.uint64Val_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TensorProto Clone() { + return new TensorProto(this); + } + + /// Field number for the "dtype" field. + public const int DtypeFieldNumber = 1; + private global::Tensorflow.DataType dtype_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.DataType Dtype { + get { return dtype_; } + set { + dtype_ = value; + } + } + + /// Field number for the "tensor_shape" field. + public const int TensorShapeFieldNumber = 2; + private global::Tensorflow.TensorShapeProto tensorShape_; + /// + /// Shape of the tensor. TODO(touts): sort out the 0-rank issues. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.TensorShapeProto TensorShape { + get { return tensorShape_; } + set { + tensorShape_ = value; + } + } + + /// Field number for the "version_number" field. + public const int VersionNumberFieldNumber = 3; + private int versionNumber_; + /// + /// Version number. + /// + /// In version 0, if the "repeated xxx" representations contain only one + /// element, that element is repeated to fill the shape. This makes it easy + /// to represent a constant Tensor with a single value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int VersionNumber { + get { return versionNumber_; } + set { + versionNumber_ = value; + } + } + + /// Field number for the "tensor_content" field. + public const int TensorContentFieldNumber = 4; + private pb::ByteString tensorContent_ = pb::ByteString.Empty; + /// + /// Serialized raw tensor content from either Tensor::AsProtoTensorContent or + /// memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + /// can be used for all tensor types. The purpose of this representation is to + /// reduce serialization overhead during RPC call by avoiding serialization of + /// many repeated small items. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pb::ByteString TensorContent { + get { return tensorContent_; } + set { + tensorContent_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "half_val" field. + public const int HalfValFieldNumber = 13; + private static readonly pb::FieldCodec _repeated_halfVal_codec + = pb::FieldCodec.ForInt32(106); + private readonly pbc::RepeatedField halfVal_ = new pbc::RepeatedField(); + /// + /// DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + /// have some pointless zero padding for each value here. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField HalfVal { + get { return halfVal_; } + } + + /// Field number for the "float_val" field. + public const int FloatValFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_floatVal_codec + = pb::FieldCodec.ForFloat(42); + private readonly pbc::RepeatedField floatVal_ = new pbc::RepeatedField(); + /// + /// DT_FLOAT. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField FloatVal { + get { return floatVal_; } + } + + /// Field number for the "double_val" field. + public const int DoubleValFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_doubleVal_codec + = pb::FieldCodec.ForDouble(50); + private readonly pbc::RepeatedField doubleVal_ = new pbc::RepeatedField(); + /// + /// DT_DOUBLE. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField DoubleVal { + get { return doubleVal_; } + } + + /// Field number for the "int_val" field. + public const int IntValFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_intVal_codec + = pb::FieldCodec.ForInt32(58); + private readonly pbc::RepeatedField intVal_ = new pbc::RepeatedField(); + /// + /// DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField IntVal { + get { return intVal_; } + } + + /// Field number for the "string_val" field. + public const int StringValFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_stringVal_codec + = pb::FieldCodec.ForBytes(66); + private readonly pbc::RepeatedField stringVal_ = new pbc::RepeatedField(); + /// + /// DT_STRING + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField StringVal { + get { return stringVal_; } + } + + /// Field number for the "scomplex_val" field. + public const int ScomplexValFieldNumber = 9; + private static readonly pb::FieldCodec _repeated_scomplexVal_codec + = pb::FieldCodec.ForFloat(74); + private readonly pbc::RepeatedField scomplexVal_ = new pbc::RepeatedField(); + /// + /// DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + /// and imaginary parts of i-th single precision complex. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField ScomplexVal { + get { return scomplexVal_; } + } + + /// Field number for the "int64_val" field. + public const int Int64ValFieldNumber = 10; + private static readonly pb::FieldCodec _repeated_int64Val_codec + = pb::FieldCodec.ForInt64(82); + private readonly pbc::RepeatedField int64Val_ = new pbc::RepeatedField(); + /// + /// DT_INT64 + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Int64Val { + get { return int64Val_; } + } + + /// Field number for the "bool_val" field. + public const int BoolValFieldNumber = 11; + private static readonly pb::FieldCodec _repeated_boolVal_codec + = pb::FieldCodec.ForBool(90); + private readonly pbc::RepeatedField boolVal_ = new pbc::RepeatedField(); + /// + /// DT_BOOL + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField BoolVal { + get { return boolVal_; } + } + + /// Field number for the "dcomplex_val" field. + public const int DcomplexValFieldNumber = 12; + private static readonly pb::FieldCodec _repeated_dcomplexVal_codec + = pb::FieldCodec.ForDouble(98); + private readonly pbc::RepeatedField dcomplexVal_ = new pbc::RepeatedField(); + /// + /// DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + /// and imaginary parts of i-th double precision complex. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField DcomplexVal { + get { return dcomplexVal_; } + } + + /// Field number for the "resource_handle_val" field. + public const int ResourceHandleValFieldNumber = 14; + private static readonly pb::FieldCodec _repeated_resourceHandleVal_codec + = pb::FieldCodec.ForMessage(114, global::Tensorflow.ResourceHandleProto.Parser); + private readonly pbc::RepeatedField resourceHandleVal_ = new pbc::RepeatedField(); + /// + /// DT_RESOURCE + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField ResourceHandleVal { + get { return resourceHandleVal_; } + } + + /// Field number for the "variant_val" field. + public const int VariantValFieldNumber = 15; + private static readonly pb::FieldCodec _repeated_variantVal_codec + = pb::FieldCodec.ForMessage(122, global::Tensorflow.VariantTensorDataProto.Parser); + private readonly pbc::RepeatedField variantVal_ = new pbc::RepeatedField(); + /// + /// DT_VARIANT + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField VariantVal { + get { return variantVal_; } + } + + /// Field number for the "uint32_val" field. + public const int Uint32ValFieldNumber = 16; + private static readonly pb::FieldCodec _repeated_uint32Val_codec + = pb::FieldCodec.ForUInt32(130); + private readonly pbc::RepeatedField uint32Val_ = new pbc::RepeatedField(); + /// + /// DT_UINT32 + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Uint32Val { + get { return uint32Val_; } + } + + /// Field number for the "uint64_val" field. + public const int Uint64ValFieldNumber = 17; + private static readonly pb::FieldCodec _repeated_uint64Val_codec + = pb::FieldCodec.ForUInt64(138); + private readonly pbc::RepeatedField uint64Val_ = new pbc::RepeatedField(); + /// + /// DT_UINT64 + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Uint64Val { + get { return uint64Val_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as TensorProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(TensorProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Dtype != other.Dtype) return false; + if (!object.Equals(TensorShape, other.TensorShape)) return false; + if (VersionNumber != other.VersionNumber) return false; + if (TensorContent != other.TensorContent) return false; + if(!halfVal_.Equals(other.halfVal_)) return false; + if(!floatVal_.Equals(other.floatVal_)) return false; + if(!doubleVal_.Equals(other.doubleVal_)) return false; + if(!intVal_.Equals(other.intVal_)) return false; + if(!stringVal_.Equals(other.stringVal_)) return false; + if(!scomplexVal_.Equals(other.scomplexVal_)) return false; + if(!int64Val_.Equals(other.int64Val_)) return false; + if(!boolVal_.Equals(other.boolVal_)) return false; + if(!dcomplexVal_.Equals(other.dcomplexVal_)) return false; + if(!resourceHandleVal_.Equals(other.resourceHandleVal_)) return false; + if(!variantVal_.Equals(other.variantVal_)) return false; + if(!uint32Val_.Equals(other.uint32Val_)) return false; + if(!uint64Val_.Equals(other.uint64Val_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Dtype != 0) hash ^= Dtype.GetHashCode(); + if (tensorShape_ != null) hash ^= TensorShape.GetHashCode(); + if (VersionNumber != 0) hash ^= VersionNumber.GetHashCode(); + if (TensorContent.Length != 0) hash ^= TensorContent.GetHashCode(); + hash ^= halfVal_.GetHashCode(); + hash ^= floatVal_.GetHashCode(); + hash ^= doubleVal_.GetHashCode(); + hash ^= intVal_.GetHashCode(); + hash ^= stringVal_.GetHashCode(); + hash ^= scomplexVal_.GetHashCode(); + hash ^= int64Val_.GetHashCode(); + hash ^= boolVal_.GetHashCode(); + hash ^= dcomplexVal_.GetHashCode(); + hash ^= resourceHandleVal_.GetHashCode(); + hash ^= variantVal_.GetHashCode(); + hash ^= uint32Val_.GetHashCode(); + hash ^= uint64Val_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Dtype != 0) { + output.WriteRawTag(8); + output.WriteEnum((int) Dtype); + } + if (tensorShape_ != null) { + output.WriteRawTag(18); + output.WriteMessage(TensorShape); + } + if (VersionNumber != 0) { + output.WriteRawTag(24); + output.WriteInt32(VersionNumber); + } + if (TensorContent.Length != 0) { + output.WriteRawTag(34); + output.WriteBytes(TensorContent); + } + floatVal_.WriteTo(output, _repeated_floatVal_codec); + doubleVal_.WriteTo(output, _repeated_doubleVal_codec); + intVal_.WriteTo(output, _repeated_intVal_codec); + stringVal_.WriteTo(output, _repeated_stringVal_codec); + scomplexVal_.WriteTo(output, _repeated_scomplexVal_codec); + int64Val_.WriteTo(output, _repeated_int64Val_codec); + boolVal_.WriteTo(output, _repeated_boolVal_codec); + dcomplexVal_.WriteTo(output, _repeated_dcomplexVal_codec); + halfVal_.WriteTo(output, _repeated_halfVal_codec); + resourceHandleVal_.WriteTo(output, _repeated_resourceHandleVal_codec); + variantVal_.WriteTo(output, _repeated_variantVal_codec); + uint32Val_.WriteTo(output, _repeated_uint32Val_codec); + uint64Val_.WriteTo(output, _repeated_uint64Val_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Dtype != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); + } + if (tensorShape_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TensorShape); + } + if (VersionNumber != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(VersionNumber); + } + if (TensorContent.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(TensorContent); + } + size += halfVal_.CalculateSize(_repeated_halfVal_codec); + size += floatVal_.CalculateSize(_repeated_floatVal_codec); + size += doubleVal_.CalculateSize(_repeated_doubleVal_codec); + size += intVal_.CalculateSize(_repeated_intVal_codec); + size += stringVal_.CalculateSize(_repeated_stringVal_codec); + size += scomplexVal_.CalculateSize(_repeated_scomplexVal_codec); + size += int64Val_.CalculateSize(_repeated_int64Val_codec); + size += boolVal_.CalculateSize(_repeated_boolVal_codec); + size += dcomplexVal_.CalculateSize(_repeated_dcomplexVal_codec); + size += resourceHandleVal_.CalculateSize(_repeated_resourceHandleVal_codec); + size += variantVal_.CalculateSize(_repeated_variantVal_codec); + size += uint32Val_.CalculateSize(_repeated_uint32Val_codec); + size += uint64Val_.CalculateSize(_repeated_uint64Val_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(TensorProto other) { + if (other == null) { + return; + } + if (other.Dtype != 0) { + Dtype = other.Dtype; + } + if (other.tensorShape_ != null) { + if (tensorShape_ == null) { + tensorShape_ = new global::Tensorflow.TensorShapeProto(); + } + TensorShape.MergeFrom(other.TensorShape); + } + if (other.VersionNumber != 0) { + VersionNumber = other.VersionNumber; + } + if (other.TensorContent.Length != 0) { + TensorContent = other.TensorContent; + } + halfVal_.Add(other.halfVal_); + floatVal_.Add(other.floatVal_); + doubleVal_.Add(other.doubleVal_); + intVal_.Add(other.intVal_); + stringVal_.Add(other.stringVal_); + scomplexVal_.Add(other.scomplexVal_); + int64Val_.Add(other.int64Val_); + boolVal_.Add(other.boolVal_); + dcomplexVal_.Add(other.dcomplexVal_); + resourceHandleVal_.Add(other.resourceHandleVal_); + variantVal_.Add(other.variantVal_); + uint32Val_.Add(other.uint32Val_); + uint64Val_.Add(other.uint64Val_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + dtype_ = (global::Tensorflow.DataType) input.ReadEnum(); + break; + } + case 18: { + if (tensorShape_ == null) { + tensorShape_ = new global::Tensorflow.TensorShapeProto(); + } + input.ReadMessage(tensorShape_); + break; + } + case 24: { + VersionNumber = input.ReadInt32(); + break; + } + case 34: { + TensorContent = input.ReadBytes(); + break; + } + case 42: + case 45: { + floatVal_.AddEntriesFrom(input, _repeated_floatVal_codec); + break; + } + case 50: + case 49: { + doubleVal_.AddEntriesFrom(input, _repeated_doubleVal_codec); + break; + } + case 58: + case 56: { + intVal_.AddEntriesFrom(input, _repeated_intVal_codec); + break; + } + case 66: { + stringVal_.AddEntriesFrom(input, _repeated_stringVal_codec); + break; + } + case 74: + case 77: { + scomplexVal_.AddEntriesFrom(input, _repeated_scomplexVal_codec); + break; + } + case 82: + case 80: { + int64Val_.AddEntriesFrom(input, _repeated_int64Val_codec); + break; + } + case 90: + case 88: { + boolVal_.AddEntriesFrom(input, _repeated_boolVal_codec); + break; + } + case 98: + case 97: { + dcomplexVal_.AddEntriesFrom(input, _repeated_dcomplexVal_codec); + break; + } + case 106: + case 104: { + halfVal_.AddEntriesFrom(input, _repeated_halfVal_codec); + break; + } + case 114: { + resourceHandleVal_.AddEntriesFrom(input, _repeated_resourceHandleVal_codec); + break; + } + case 122: { + variantVal_.AddEntriesFrom(input, _repeated_variantVal_codec); + break; + } + case 130: + case 128: { + uint32Val_.AddEntriesFrom(input, _repeated_uint32Val_codec); + break; + } + case 138: + case 136: { + uint64Val_.AddEntriesFrom(input, _repeated_uint64Val_codec); + break; + } + } + } + } + + } + + /// + /// Protocol buffer representing the serialization format of DT_VARIANT tensors. + /// + public sealed partial class VariantTensorDataProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new VariantTensorDataProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TensorReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VariantTensorDataProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VariantTensorDataProto(VariantTensorDataProto other) : this() { + typeName_ = other.typeName_; + metadata_ = other.metadata_; + tensors_ = other.tensors_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VariantTensorDataProto Clone() { + return new VariantTensorDataProto(this); + } + + /// Field number for the "type_name" field. + public const int TypeNameFieldNumber = 1; + private string typeName_ = ""; + /// + /// Name of the type of objects being serialized. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string TypeName { + get { return typeName_; } + set { + typeName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "metadata" field. + public const int MetadataFieldNumber = 2; + private pb::ByteString metadata_ = pb::ByteString.Empty; + /// + /// Portions of the object that are not Tensors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pb::ByteString Metadata { + get { return metadata_; } + set { + metadata_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "tensors" field. + public const int TensorsFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_tensors_codec + = pb::FieldCodec.ForMessage(26, global::Tensorflow.TensorProto.Parser); + private readonly pbc::RepeatedField tensors_ = new pbc::RepeatedField(); + /// + /// Tensors contained within objects being serialized. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Tensors { + get { return tensors_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as VariantTensorDataProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(VariantTensorDataProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (TypeName != other.TypeName) return false; + if (Metadata != other.Metadata) return false; + if(!tensors_.Equals(other.tensors_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (TypeName.Length != 0) hash ^= TypeName.GetHashCode(); + if (Metadata.Length != 0) hash ^= Metadata.GetHashCode(); + hash ^= tensors_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (TypeName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(TypeName); + } + if (Metadata.Length != 0) { + output.WriteRawTag(18); + output.WriteBytes(Metadata); + } + tensors_.WriteTo(output, _repeated_tensors_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (TypeName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(TypeName); + } + if (Metadata.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeBytesSize(Metadata); + } + size += tensors_.CalculateSize(_repeated_tensors_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(VariantTensorDataProto other) { + if (other == null) { + return; + } + if (other.TypeName.Length != 0) { + TypeName = other.TypeName; + } + if (other.Metadata.Length != 0) { + Metadata = other.Metadata; + } + tensors_.Add(other.tensors_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + TypeName = input.ReadString(); + break; + } + case 18: { + Metadata = input.ReadBytes(); + break; + } + case 26: { + tensors_.AddEntriesFrom(input, _repeated_tensors_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Tensorflow/TensorShape.cs b/src/TensorFlowNET.Core/Tensorflow/TensorShape.cs new file mode 100644 index 00000000..0a891dd0 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensorflow/TensorShape.cs @@ -0,0 +1,397 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensor_shape.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensor_shape.proto + public static partial class TensorShapeReflection { + + #region Descriptor + /// File descriptor for tensor_shape.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static TensorShapeReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "ChJ0ZW5zb3Jfc2hhcGUucHJvdG8SCnRlbnNvcmZsb3ciegoQVGVuc29yU2hh", + "cGVQcm90bxItCgNkaW0YAiADKAsyIC50ZW5zb3JmbG93LlRlbnNvclNoYXBl", + "UHJvdG8uRGltEhQKDHVua25vd25fcmFuaxgDIAEoCBohCgNEaW0SDAoEc2l6", + "ZRgBIAEoAxIMCgRuYW1lGAIgASgJQnEKGG9yZy50ZW5zb3JmbG93LmZyYW1l", + "d29ya0IRVGVuc29yU2hhcGVQcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3Jm", + "bG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gB", + "AWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorShapeProto), global::Tensorflow.TensorShapeProto.Parser, new[]{ "Dim", "UnknownRank" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorShapeProto.Types.Dim), global::Tensorflow.TensorShapeProto.Types.Dim.Parser, new[]{ "Size", "Name" }, null, null, null)}) + })); + } + #endregion + + } + #region Messages + /// + /// Dimensions of a tensor. + /// + public sealed partial class TensorShapeProto : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TensorShapeProto()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TensorShapeReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TensorShapeProto() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TensorShapeProto(TensorShapeProto other) : this() { + dim_ = other.dim_.Clone(); + unknownRank_ = other.unknownRank_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TensorShapeProto Clone() { + return new TensorShapeProto(this); + } + + /// Field number for the "dim" field. + public const int DimFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_dim_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.TensorShapeProto.Types.Dim.Parser); + private readonly pbc::RepeatedField dim_ = new pbc::RepeatedField(); + /// + /// Dimensions of the tensor, such as {"input", 30}, {"output", 40} + /// for a 30 x 40 2D tensor. If an entry has size -1, this + /// corresponds to a dimension of unknown size. The names are + /// optional. + /// + /// The order of entries in "dim" matters: It indicates the layout of the + /// values in the tensor in-memory representation. + /// + /// The first entry in "dim" is the outermost dimension used to layout the + /// values, the last entry is the innermost dimension. This matches the + /// in-memory layout of RowMajor Eigen tensors. + /// + /// If "dim.size()" > 0, "unknown_rank" must be false. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Dim { + get { return dim_; } + } + + /// Field number for the "unknown_rank" field. + public const int UnknownRankFieldNumber = 3; + private bool unknownRank_; + /// + /// If true, the number of dimensions in the shape is unknown. + /// + /// If true, "dim.size()" must be 0. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UnknownRank { + get { return unknownRank_; } + set { + unknownRank_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as TensorShapeProto); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(TensorShapeProto other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!dim_.Equals(other.dim_)) return false; + if (UnknownRank != other.UnknownRank) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= dim_.GetHashCode(); + if (UnknownRank != false) hash ^= UnknownRank.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + dim_.WriteTo(output, _repeated_dim_codec); + if (UnknownRank != false) { + output.WriteRawTag(24); + output.WriteBool(UnknownRank); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += dim_.CalculateSize(_repeated_dim_codec); + if (UnknownRank != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(TensorShapeProto other) { + if (other == null) { + return; + } + dim_.Add(other.dim_); + if (other.UnknownRank != false) { + UnknownRank = other.UnknownRank; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 18: { + dim_.AddEntriesFrom(input, _repeated_dim_codec); + break; + } + case 24: { + UnknownRank = input.ReadBool(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the TensorShapeProto message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + /// + /// One dimension of the tensor. + /// + public sealed partial class Dim : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Dim()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.TensorShapeProto.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Dim() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Dim(Dim other) : this() { + size_ = other.size_; + name_ = other.name_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Dim Clone() { + return new Dim(this); + } + + /// Field number for the "size" field. + public const int SizeFieldNumber = 1; + private long size_; + /// + /// Size of the tensor in that dimension. + /// This value must be >= -1, but values of -1 are reserved for "unknown" + /// shapes (values of -1 mean "unknown" dimension). Certain wrappers + /// that work with TensorShapeProto may fail at runtime when deserializing + /// a TensorShapeProto containing a dim value of -1. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public long Size { + get { return size_; } + set { + size_ = value; + } + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 2; + private string name_ = ""; + /// + /// Optional name of the tensor dimension. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as Dim); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Dim other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Size != other.Size) return false; + if (Name != other.Name) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Size != 0L) hash ^= Size.GetHashCode(); + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Size != 0L) { + output.WriteRawTag(8); + output.WriteInt64(Size); + } + if (Name.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Name); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Size != 0L) { + size += 1 + pb::CodedOutputStream.ComputeInt64Size(Size); + } + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Dim other) { + if (other == null) { + return; + } + if (other.Size != 0L) { + Size = other.Size; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Size = input.ReadInt64(); + break; + } + case 18: { + Name = input.ReadString(); + break; + } + } + } + } + + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Tensorflow/Types.cs b/src/TensorFlowNET.Core/Tensorflow/Types.cs new file mode 100644 index 00000000..887ff322 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensorflow/Types.cs @@ -0,0 +1,153 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: types.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from types.proto + public static partial class TypesReflection { + + #region Descriptor + /// File descriptor for types.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static TypesReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cgt0eXBlcy5wcm90bxIKdGVuc29yZmxvdyqqBgoIRGF0YVR5cGUSDgoKRFRf", + "SU5WQUxJRBAAEgwKCERUX0ZMT0FUEAESDQoJRFRfRE9VQkxFEAISDAoIRFRf", + "SU5UMzIQAxIMCghEVF9VSU5UOBAEEgwKCERUX0lOVDE2EAUSCwoHRFRfSU5U", + "OBAGEg0KCURUX1NUUklORxAHEhAKDERUX0NPTVBMRVg2NBAIEgwKCERUX0lO", + "VDY0EAkSCwoHRFRfQk9PTBAKEgwKCERUX1FJTlQ4EAsSDQoJRFRfUVVJTlQ4", + "EAwSDQoJRFRfUUlOVDMyEA0SDwoLRFRfQkZMT0FUMTYQDhINCglEVF9RSU5U", + "MTYQDxIOCgpEVF9RVUlOVDE2EBASDQoJRFRfVUlOVDE2EBESEQoNRFRfQ09N", + "UExFWDEyOBASEgsKB0RUX0hBTEYQExIPCgtEVF9SRVNPVVJDRRAUEg4KCkRU", + "X1ZBUklBTlQQFRINCglEVF9VSU5UMzIQFhINCglEVF9VSU5UNjQQFxIQCgxE", + "VF9GTE9BVF9SRUYQZRIRCg1EVF9ET1VCTEVfUkVGEGYSEAoMRFRfSU5UMzJf", + "UkVGEGcSEAoMRFRfVUlOVDhfUkVGEGgSEAoMRFRfSU5UMTZfUkVGEGkSDwoL", + "RFRfSU5UOF9SRUYQahIRCg1EVF9TVFJJTkdfUkVGEGsSFAoQRFRfQ09NUExF", + "WDY0X1JFRhBsEhAKDERUX0lOVDY0X1JFRhBtEg8KC0RUX0JPT0xfUkVGEG4S", + "EAoMRFRfUUlOVDhfUkVGEG8SEQoNRFRfUVVJTlQ4X1JFRhBwEhEKDURUX1FJ", + "TlQzMl9SRUYQcRITCg9EVF9CRkxPQVQxNl9SRUYQchIRCg1EVF9RSU5UMTZf", + "UkVGEHMSEgoORFRfUVVJTlQxNl9SRUYQdBIRCg1EVF9VSU5UMTZfUkVGEHUS", + "FQoRRFRfQ09NUExFWDEyOF9SRUYQdhIPCgtEVF9IQUxGX1JFRhB3EhMKD0RU", + "X1JFU09VUkNFX1JFRhB4EhIKDkRUX1ZBUklBTlRfUkVGEHkSEQoNRFRfVUlO", + "VDMyX1JFRhB6EhEKDURUX1VJTlQ2NF9SRUYQe0JrChhvcmcudGVuc29yZmxv", + "dy5mcmFtZXdvcmtCC1R5cGVzUHJvdG9zUAFaPWdpdGh1Yi5jb20vdGVuc29y", + "Zmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9mcmFtZXdvcmv4", + "AQFiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.DataType), }, null)); + } + #endregion + + } + #region Enums + /// + /// LINT.IfChange + /// + public enum DataType { + /// + /// Not a legal value for DataType. Used to indicate a DataType field + /// has not been set. + /// + [pbr::OriginalName("DT_INVALID")] DtInvalid = 0, + /// + /// Data types that all computation devices are expected to be + /// capable to support. + /// + [pbr::OriginalName("DT_FLOAT")] DtFloat = 1, + [pbr::OriginalName("DT_DOUBLE")] DtDouble = 2, + [pbr::OriginalName("DT_INT32")] DtInt32 = 3, + [pbr::OriginalName("DT_UINT8")] DtUint8 = 4, + [pbr::OriginalName("DT_INT16")] DtInt16 = 5, + [pbr::OriginalName("DT_INT8")] DtInt8 = 6, + [pbr::OriginalName("DT_STRING")] DtString = 7, + /// + /// Single-precision complex + /// + [pbr::OriginalName("DT_COMPLEX64")] DtComplex64 = 8, + [pbr::OriginalName("DT_INT64")] DtInt64 = 9, + [pbr::OriginalName("DT_BOOL")] DtBool = 10, + /// + /// Quantized int8 + /// + [pbr::OriginalName("DT_QINT8")] DtQint8 = 11, + /// + /// Quantized uint8 + /// + [pbr::OriginalName("DT_QUINT8")] DtQuint8 = 12, + /// + /// Quantized int32 + /// + [pbr::OriginalName("DT_QINT32")] DtQint32 = 13, + /// + /// Float32 truncated to 16 bits. Only for cast ops. + /// + [pbr::OriginalName("DT_BFLOAT16")] DtBfloat16 = 14, + /// + /// Quantized int16 + /// + [pbr::OriginalName("DT_QINT16")] DtQint16 = 15, + /// + /// Quantized uint16 + /// + [pbr::OriginalName("DT_QUINT16")] DtQuint16 = 16, + [pbr::OriginalName("DT_UINT16")] DtUint16 = 17, + /// + /// Double-precision complex + /// + [pbr::OriginalName("DT_COMPLEX128")] DtComplex128 = 18, + [pbr::OriginalName("DT_HALF")] DtHalf = 19, + [pbr::OriginalName("DT_RESOURCE")] DtResource = 20, + /// + /// Arbitrary C++ data types + /// + [pbr::OriginalName("DT_VARIANT")] DtVariant = 21, + [pbr::OriginalName("DT_UINT32")] DtUint32 = 22, + [pbr::OriginalName("DT_UINT64")] DtUint64 = 23, + /// + /// Do not use! These are only for parameters. Every enum above + /// should have a corresponding value below (verified by types_test). + /// + [pbr::OriginalName("DT_FLOAT_REF")] DtFloatRef = 101, + [pbr::OriginalName("DT_DOUBLE_REF")] DtDoubleRef = 102, + [pbr::OriginalName("DT_INT32_REF")] DtInt32Ref = 103, + [pbr::OriginalName("DT_UINT8_REF")] DtUint8Ref = 104, + [pbr::OriginalName("DT_INT16_REF")] DtInt16Ref = 105, + [pbr::OriginalName("DT_INT8_REF")] DtInt8Ref = 106, + [pbr::OriginalName("DT_STRING_REF")] DtStringRef = 107, + [pbr::OriginalName("DT_COMPLEX64_REF")] DtComplex64Ref = 108, + [pbr::OriginalName("DT_INT64_REF")] DtInt64Ref = 109, + [pbr::OriginalName("DT_BOOL_REF")] DtBoolRef = 110, + [pbr::OriginalName("DT_QINT8_REF")] DtQint8Ref = 111, + [pbr::OriginalName("DT_QUINT8_REF")] DtQuint8Ref = 112, + [pbr::OriginalName("DT_QINT32_REF")] DtQint32Ref = 113, + [pbr::OriginalName("DT_BFLOAT16_REF")] DtBfloat16Ref = 114, + [pbr::OriginalName("DT_QINT16_REF")] DtQint16Ref = 115, + [pbr::OriginalName("DT_QUINT16_REF")] DtQuint16Ref = 116, + [pbr::OriginalName("DT_UINT16_REF")] DtUint16Ref = 117, + [pbr::OriginalName("DT_COMPLEX128_REF")] DtComplex128Ref = 118, + [pbr::OriginalName("DT_HALF_REF")] DtHalfRef = 119, + [pbr::OriginalName("DT_RESOURCE_REF")] DtResourceRef = 120, + [pbr::OriginalName("DT_VARIANT_REF")] DtVariantRef = 121, + [pbr::OriginalName("DT_UINT32_REF")] DtUint32Ref = 122, + [pbr::OriginalName("DT_UINT64_REF")] DtUint64Ref = 123, + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index b3303ba4..f979c827 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -10,6 +10,8 @@ using TF_Operation = System.IntPtr; using TF_Status = System.IntPtr; using TF_Tensor = System.IntPtr; +using TF_DataType = Tensorflow.DataType; + using static TensorFlowNET.Core.Tensorflow; namespace TensorFlowNET.Core diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 252fdbf6..879e571f 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -4,6 +4,7 @@ using System.Runtime.InteropServices; using System.Text; using System.Threading; using tf = TensorFlowNET.Core.Tensorflow; +using TF_DataType = Tensorflow.DataType; namespace TensorFlowNET.Core { @@ -26,8 +27,8 @@ namespace TensorFlowNET.Core case double value: var v = (double*)Marshal.AllocHGlobal(sizeof(double)); *v = value; - tensor = c_api.TF_NewTensor(TF_DataType.TF_DOUBLE, 0, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); - c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.TF_DOUBLE); + tensor = c_api.TF_NewTensor(TF_DataType.DtDouble, 0, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); + c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.DtDouble); break; }