@@ -39,6 +39,11 @@ namespace Tensorflow | |||
return buffer._handle; | |||
} | |||
public static implicit operator byte[](Buffer buffer) | |||
{ | |||
return buffer.Data; | |||
} | |||
public void Dispose() | |||
{ | |||
c_api.TF_DeleteBuffer(_handle); | |||
@@ -38,6 +38,16 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | |||
/// <summary> | |||
/// Write out a serialized representation of `graph` (as a GraphDef protocol | |||
/// message) to `output_graph_def` (allocated by TF_NewBuffer()). | |||
/// </summary> | |||
/// <param name="graph"></param> | |||
/// <param name="output_graph_def"></param> | |||
/// <param name="status"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphToGraphDef(IntPtr graph, IntPtr output_graph_def, IntPtr status); | |||
/// <summary> | |||
/// Returns the number of dimensions of the Tensor referenced by `output` | |||
/// in `graph`. | |||
@@ -26,15 +26,15 @@ namespace Tensorflow | |||
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); | |||
public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | |||
public TF_Input[] OutputConsumers(int index, int max_consumers) | |||
public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | |||
{ | |||
IntPtr handle = IntPtr.Zero; | |||
int size = Marshal.SizeOf<TF_Input>(); | |||
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), ref handle, max_consumers); | |||
var handle = (TF_Input*)Marshal.AllocHGlobal(size); | |||
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | |||
var consumers = new TF_Input[num]; | |||
for(int i = 0; i < num; i++) | |||
{ | |||
consumers[0] = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||
consumers[i] = new TF_Input((*handle).oper + i * size, (*handle).index); | |||
} | |||
return consumers; | |||
@@ -112,7 +112,7 @@ namespace Tensorflow | |||
/// <param name="max_consumers"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_OperationOutputConsumers(TF_Output oper_out, ref IntPtr consumers, int max_consumers); | |||
public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input * consumers, int max_consumers); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | |||
@@ -0,0 +1,604 @@ | |||
// <auto-generated> | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: function.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
using pbc = global::Google.Protobuf.Collections; | |||
using pbr = global::Google.Protobuf.Reflection; | |||
using scg = global::System.Collections.Generic; | |||
namespace Tensorflow { | |||
/// <summary>Holder for reflection information generated from function.proto</summary> | |||
public static partial class FunctionReflection { | |||
#region Descriptor | |||
/// <summary>File descriptor for function.proto</summary> | |||
public static pbr::FileDescriptor Descriptor { | |||
get { return descriptor; } | |||
} | |||
private static pbr::FileDescriptor descriptor; | |||
static FunctionReflection() { | |||
byte[] descriptorData = global::System.Convert.FromBase64String( | |||
string.Concat( | |||
"Cg5mdW5jdGlvbi5wcm90bxIKdGVuc29yZmxvdxoQYXR0cl92YWx1ZS5wcm90", | |||
"bxoObm9kZV9kZWYucHJvdG8aDG9wX2RlZi5wcm90byJqChJGdW5jdGlvbkRl", | |||
"ZkxpYnJhcnkSKQoIZnVuY3Rpb24YASADKAsyFy50ZW5zb3JmbG93LkZ1bmN0", | |||
"aW9uRGVmEikKCGdyYWRpZW50GAIgAygLMhcudGVuc29yZmxvdy5HcmFkaWVu", | |||
"dERlZiKwAgoLRnVuY3Rpb25EZWYSJAoJc2lnbmF0dXJlGAEgASgLMhEudGVu", | |||
"c29yZmxvdy5PcERlZhIvCgRhdHRyGAUgAygLMiEudGVuc29yZmxvdy5GdW5j", | |||
"dGlvbkRlZi5BdHRyRW50cnkSJQoIbm9kZV9kZWYYAyADKAsyEy50ZW5zb3Jm", | |||
"bG93Lk5vZGVEZWYSLQoDcmV0GAQgAygLMiAudGVuc29yZmxvdy5GdW5jdGlv", | |||
"bkRlZi5SZXRFbnRyeRpCCglBdHRyRW50cnkSCwoDa2V5GAEgASgJEiQKBXZh", | |||
"bHVlGAIgASgLMhUudGVuc29yZmxvdy5BdHRyVmFsdWU6AjgBGioKCFJldEVu", | |||
"dHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgCIAEoCToCOAFKBAgCEAMiOwoL", | |||
"R3JhZGllbnREZWYSFQoNZnVuY3Rpb25fbmFtZRgBIAEoCRIVCg1ncmFkaWVu", | |||
"dF9mdW5jGAIgASgJQm4KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IORnVu", | |||
"Y3Rpb25Qcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZs", | |||
"b3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z")); | |||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.OpDefReflection.Descriptor, }, | |||
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDefLibrary), global::Tensorflow.FunctionDefLibrary.Parser, new[]{ "Function", "Gradient" }, null, null, null), | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDef), global::Tensorflow.FunctionDef.Parser, new[]{ "Signature", "Attr", "NodeDef", "Ret" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, null, }), | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GradientDef), global::Tensorflow.GradientDef.Parser, new[]{ "FunctionName", "GradientFunc" }, null, null, null) | |||
})); | |||
} | |||
#endregion | |||
} | |||
#region Messages | |||
/// <summary> | |||
/// A library is a set of named functions. | |||
/// </summary> | |||
public sealed partial class FunctionDefLibrary : pb::IMessage<FunctionDefLibrary> { | |||
private static readonly pb::MessageParser<FunctionDefLibrary> _parser = new pb::MessageParser<FunctionDefLibrary>(() => new FunctionDefLibrary()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pb::MessageParser<FunctionDefLibrary> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public FunctionDefLibrary() { | |||
OnConstruction(); | |||
} | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public FunctionDefLibrary(FunctionDefLibrary other) : this() { | |||
function_ = other.function_.Clone(); | |||
gradient_ = other.gradient_.Clone(); | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public FunctionDefLibrary Clone() { | |||
return new FunctionDefLibrary(this); | |||
} | |||
/// <summary>Field number for the "function" field.</summary> | |||
public const int FunctionFieldNumber = 1; | |||
private static readonly pb::FieldCodec<global::Tensorflow.FunctionDef> _repeated_function_codec | |||
= pb::FieldCodec.ForMessage(10, global::Tensorflow.FunctionDef.Parser); | |||
private readonly pbc::RepeatedField<global::Tensorflow.FunctionDef> function_ = new pbc::RepeatedField<global::Tensorflow.FunctionDef>(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public pbc::RepeatedField<global::Tensorflow.FunctionDef> Function { | |||
get { return function_; } | |||
} | |||
/// <summary>Field number for the "gradient" field.</summary> | |||
public const int GradientFieldNumber = 2; | |||
private static readonly pb::FieldCodec<global::Tensorflow.GradientDef> _repeated_gradient_codec | |||
= pb::FieldCodec.ForMessage(18, global::Tensorflow.GradientDef.Parser); | |||
private readonly pbc::RepeatedField<global::Tensorflow.GradientDef> gradient_ = new pbc::RepeatedField<global::Tensorflow.GradientDef>(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public pbc::RepeatedField<global::Tensorflow.GradientDef> Gradient { | |||
get { return gradient_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override bool Equals(object other) { | |||
return Equals(other as FunctionDefLibrary); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public bool Equals(FunctionDefLibrary other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
} | |||
if (ReferenceEquals(other, this)) { | |||
return true; | |||
} | |||
if(!function_.Equals(other.function_)) return false; | |||
if(!gradient_.Equals(other.gradient_)) return false; | |||
return Equals(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
hash ^= function_.GetHashCode(); | |||
hash ^= gradient_.GetHashCode(); | |||
if (_unknownFields != null) { | |||
hash ^= _unknownFields.GetHashCode(); | |||
} | |||
return hash; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
function_.WriteTo(output, _repeated_function_codec); | |||
gradient_.WriteTo(output, _repeated_gradient_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public int CalculateSize() { | |||
int size = 0; | |||
size += function_.CalculateSize(_repeated_function_codec); | |||
size += gradient_.CalculateSize(_repeated_gradient_codec); | |||
if (_unknownFields != null) { | |||
size += _unknownFields.CalculateSize(); | |||
} | |||
return size; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(FunctionDefLibrary other) { | |||
if (other == null) { | |||
return; | |||
} | |||
function_.Add(other.function_); | |||
gradient_.Add(other.gradient_); | |||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
break; | |||
case 10: { | |||
function_.AddEntriesFrom(input, _repeated_function_codec); | |||
break; | |||
} | |||
case 18: { | |||
gradient_.AddEntriesFrom(input, _repeated_gradient_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
/// <summary> | |||
/// A function can be instantiated when the runtime can bind every attr | |||
/// with a value. When a GraphDef has a call to a function, it must | |||
/// have binding for every attr defined in the signature. | |||
/// | |||
/// TODO(zhifengc): | |||
/// * device spec, etc. | |||
/// </summary> | |||
public sealed partial class FunctionDef : pb::IMessage<FunctionDef> { | |||
private static readonly pb::MessageParser<FunctionDef> _parser = new pb::MessageParser<FunctionDef>(() => new FunctionDef()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pb::MessageParser<FunctionDef> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[1]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public FunctionDef() { | |||
OnConstruction(); | |||
} | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public FunctionDef(FunctionDef other) : this() { | |||
signature_ = other.signature_ != null ? other.signature_.Clone() : null; | |||
attr_ = other.attr_.Clone(); | |||
nodeDef_ = other.nodeDef_.Clone(); | |||
ret_ = other.ret_.Clone(); | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public FunctionDef Clone() { | |||
return new FunctionDef(this); | |||
} | |||
/// <summary>Field number for the "signature" field.</summary> | |||
public const int SignatureFieldNumber = 1; | |||
private global::Tensorflow.OpDef signature_; | |||
/// <summary> | |||
/// The definition of the function's name, arguments, return values, | |||
/// attrs etc. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public global::Tensorflow.OpDef Signature { | |||
get { return signature_; } | |||
set { | |||
signature_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "attr" field.</summary> | |||
public const int AttrFieldNumber = 5; | |||
private static readonly pbc::MapField<string, global::Tensorflow.AttrValue>.Codec _map_attr_codec | |||
= new pbc::MapField<string, global::Tensorflow.AttrValue>.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 42); | |||
private readonly pbc::MapField<string, global::Tensorflow.AttrValue> attr_ = new pbc::MapField<string, global::Tensorflow.AttrValue>(); | |||
/// <summary> | |||
/// Attributes specific to this function definition. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public pbc::MapField<string, global::Tensorflow.AttrValue> Attr { | |||
get { return attr_; } | |||
} | |||
/// <summary>Field number for the "node_def" field.</summary> | |||
public const int NodeDefFieldNumber = 3; | |||
private static readonly pb::FieldCodec<global::Tensorflow.NodeDef> _repeated_nodeDef_codec | |||
= pb::FieldCodec.ForMessage(26, global::Tensorflow.NodeDef.Parser); | |||
private readonly pbc::RepeatedField<global::Tensorflow.NodeDef> nodeDef_ = new pbc::RepeatedField<global::Tensorflow.NodeDef>(); | |||
/// <summary> | |||
/// By convention, "op" in node_def is resolved by consulting with a | |||
/// user-defined library first. If not resolved, "func" is assumed to | |||
/// be a builtin op. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public pbc::RepeatedField<global::Tensorflow.NodeDef> NodeDef { | |||
get { return nodeDef_; } | |||
} | |||
/// <summary>Field number for the "ret" field.</summary> | |||
public const int RetFieldNumber = 4; | |||
private static readonly pbc::MapField<string, string>.Codec _map_ret_codec | |||
= new pbc::MapField<string, string>.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForString(18), 34); | |||
private readonly pbc::MapField<string, string> ret_ = new pbc::MapField<string, string>(); | |||
/// <summary> | |||
/// A mapping from the output arg names from `signature` to the | |||
/// outputs from `node_def` that should be returned by the function. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public pbc::MapField<string, string> Ret { | |||
get { return ret_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override bool Equals(object other) { | |||
return Equals(other as FunctionDef); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public bool Equals(FunctionDef other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
} | |||
if (ReferenceEquals(other, this)) { | |||
return true; | |||
} | |||
if (!object.Equals(Signature, other.Signature)) return false; | |||
if (!Attr.Equals(other.Attr)) return false; | |||
if(!nodeDef_.Equals(other.nodeDef_)) return false; | |||
if (!Ret.Equals(other.Ret)) return false; | |||
return Equals(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (signature_ != null) hash ^= Signature.GetHashCode(); | |||
hash ^= Attr.GetHashCode(); | |||
hash ^= nodeDef_.GetHashCode(); | |||
hash ^= Ret.GetHashCode(); | |||
if (_unknownFields != null) { | |||
hash ^= _unknownFields.GetHashCode(); | |||
} | |||
return hash; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
if (signature_ != null) { | |||
output.WriteRawTag(10); | |||
output.WriteMessage(Signature); | |||
} | |||
nodeDef_.WriteTo(output, _repeated_nodeDef_codec); | |||
ret_.WriteTo(output, _map_ret_codec); | |||
attr_.WriteTo(output, _map_attr_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (signature_ != null) { | |||
size += 1 + pb::CodedOutputStream.ComputeMessageSize(Signature); | |||
} | |||
size += attr_.CalculateSize(_map_attr_codec); | |||
size += nodeDef_.CalculateSize(_repeated_nodeDef_codec); | |||
size += ret_.CalculateSize(_map_ret_codec); | |||
if (_unknownFields != null) { | |||
size += _unknownFields.CalculateSize(); | |||
} | |||
return size; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(FunctionDef other) { | |||
if (other == null) { | |||
return; | |||
} | |||
if (other.signature_ != null) { | |||
if (signature_ == null) { | |||
signature_ = new global::Tensorflow.OpDef(); | |||
} | |||
Signature.MergeFrom(other.Signature); | |||
} | |||
attr_.Add(other.attr_); | |||
nodeDef_.Add(other.nodeDef_); | |||
ret_.Add(other.ret_); | |||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
break; | |||
case 10: { | |||
if (signature_ == null) { | |||
signature_ = new global::Tensorflow.OpDef(); | |||
} | |||
input.ReadMessage(signature_); | |||
break; | |||
} | |||
case 26: { | |||
nodeDef_.AddEntriesFrom(input, _repeated_nodeDef_codec); | |||
break; | |||
} | |||
case 34: { | |||
ret_.AddEntriesFrom(input, _map_ret_codec); | |||
break; | |||
} | |||
case 42: { | |||
attr_.AddEntriesFrom(input, _map_attr_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
/// <summary> | |||
/// GradientDef defines the gradient function of a function defined in | |||
/// a function library. | |||
/// | |||
/// A gradient function g (specified by gradient_func) for a function f | |||
/// (specified by function_name) must follow the following: | |||
/// | |||
/// The function 'f' must be a numerical function which takes N inputs | |||
/// and produces M outputs. Its gradient function 'g', which is a | |||
/// function taking N + M inputs and produces N outputs. | |||
/// | |||
/// I.e. if we have | |||
/// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), | |||
/// then, g is | |||
/// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, | |||
/// dL/dy1, dL/dy2, ..., dL/dy_M), | |||
/// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the | |||
/// loss function). dL/dx_i is the partial derivative of L with respect | |||
/// to x_i. | |||
/// </summary> | |||
public sealed partial class GradientDef : pb::IMessage<GradientDef> { | |||
private static readonly pb::MessageParser<GradientDef> _parser = new pb::MessageParser<GradientDef>(() => new GradientDef()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pb::MessageParser<GradientDef> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[2]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public GradientDef() { | |||
OnConstruction(); | |||
} | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public GradientDef(GradientDef other) : this() { | |||
functionName_ = other.functionName_; | |||
gradientFunc_ = other.gradientFunc_; | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public GradientDef Clone() { | |||
return new GradientDef(this); | |||
} | |||
/// <summary>Field number for the "function_name" field.</summary> | |||
public const int FunctionNameFieldNumber = 1; | |||
private string functionName_ = ""; | |||
/// <summary> | |||
/// The function name. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public string FunctionName { | |||
get { return functionName_; } | |||
set { | |||
functionName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
} | |||
} | |||
/// <summary>Field number for the "gradient_func" field.</summary> | |||
public const int GradientFuncFieldNumber = 2; | |||
private string gradientFunc_ = ""; | |||
/// <summary> | |||
/// The gradient function's name. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public string GradientFunc { | |||
get { return gradientFunc_; } | |||
set { | |||
gradientFunc_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
} | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override bool Equals(object other) { | |||
return Equals(other as GradientDef); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public bool Equals(GradientDef other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
} | |||
if (ReferenceEquals(other, this)) { | |||
return true; | |||
} | |||
if (FunctionName != other.FunctionName) return false; | |||
if (GradientFunc != other.GradientFunc) return false; | |||
return Equals(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (FunctionName.Length != 0) hash ^= FunctionName.GetHashCode(); | |||
if (GradientFunc.Length != 0) hash ^= GradientFunc.GetHashCode(); | |||
if (_unknownFields != null) { | |||
hash ^= _unknownFields.GetHashCode(); | |||
} | |||
return hash; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
if (FunctionName.Length != 0) { | |||
output.WriteRawTag(10); | |||
output.WriteString(FunctionName); | |||
} | |||
if (GradientFunc.Length != 0) { | |||
output.WriteRawTag(18); | |||
output.WriteString(GradientFunc); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (FunctionName.Length != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeStringSize(FunctionName); | |||
} | |||
if (GradientFunc.Length != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeStringSize(GradientFunc); | |||
} | |||
if (_unknownFields != null) { | |||
size += _unknownFields.CalculateSize(); | |||
} | |||
return size; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(GradientDef other) { | |||
if (other == null) { | |||
return; | |||
} | |||
if (other.FunctionName.Length != 0) { | |||
FunctionName = other.FunctionName; | |||
} | |||
if (other.GradientFunc.Length != 0) { | |||
GradientFunc = other.GradientFunc; | |||
} | |||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
break; | |||
case 10: { | |||
FunctionName = input.ReadString(); | |||
break; | |||
} | |||
case 18: { | |||
GradientFunc = input.ReadString(); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
#endregion | |||
} | |||
#endregion Designer generated code |
@@ -0,0 +1,309 @@ | |||
// <auto-generated> | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: graph.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
using pbc = global::Google.Protobuf.Collections; | |||
using pbr = global::Google.Protobuf.Reflection; | |||
using scg = global::System.Collections.Generic; | |||
namespace Tensorflow { | |||
/// <summary>Holder for reflection information generated from graph.proto</summary> | |||
public static partial class GraphReflection { | |||
#region Descriptor | |||
/// <summary>File descriptor for graph.proto</summary> | |||
public static pbr::FileDescriptor Descriptor { | |||
get { return descriptor; } | |||
} | |||
private static pbr::FileDescriptor descriptor; | |||
static GraphReflection() { | |||
byte[] descriptorData = global::System.Convert.FromBase64String( | |||
string.Concat( | |||
"CgtncmFwaC5wcm90bxIKdGVuc29yZmxvdxoObm9kZV9kZWYucHJvdG8aDmZ1", | |||
"bmN0aW9uLnByb3RvGg52ZXJzaW9ucy5wcm90byKdAQoIR3JhcGhEZWYSIQoE", | |||
"bm9kZRgBIAMoCzITLnRlbnNvcmZsb3cuTm9kZURlZhIoCgh2ZXJzaW9ucxgE", | |||
"IAEoCzIWLnRlbnNvcmZsb3cuVmVyc2lvbkRlZhITCgd2ZXJzaW9uGAMgASgF", | |||
"QgIYARIvCgdsaWJyYXJ5GAIgASgLMh4udGVuc29yZmxvdy5GdW5jdGlvbkRl", | |||
"ZkxpYnJhcnlCawoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQgtHcmFwaFBy", | |||
"b3Rvc1ABWj1naXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5z", | |||
"b3JmbG93L2dvL2NvcmUvZnJhbWV3b3Jr+AEBYgZwcm90bzM=")); | |||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
new pbr::FileDescriptor[] { global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.FunctionReflection.Descriptor, global::Tensorflow.VersionsReflection.Descriptor, }, | |||
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphDef), global::Tensorflow.GraphDef.Parser, new[]{ "Node", "Versions", "Version", "Library" }, null, null, null) | |||
})); | |||
} | |||
#endregion | |||
} | |||
#region Messages | |||
/// <summary> | |||
/// Represents the graph of operations | |||
/// </summary> | |||
public sealed partial class GraphDef : pb::IMessage<GraphDef> { | |||
private static readonly pb::MessageParser<GraphDef> _parser = new pb::MessageParser<GraphDef>(() => new GraphDef()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pb::MessageParser<GraphDef> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.GraphReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public GraphDef() { | |||
OnConstruction(); | |||
} | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public GraphDef(GraphDef other) : this() { | |||
node_ = other.node_.Clone(); | |||
versions_ = other.versions_ != null ? other.versions_.Clone() : null; | |||
version_ = other.version_; | |||
library_ = other.library_ != null ? other.library_.Clone() : null; | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public GraphDef Clone() { | |||
return new GraphDef(this); | |||
} | |||
/// <summary>Field number for the "node" field.</summary> | |||
public const int NodeFieldNumber = 1; | |||
private static readonly pb::FieldCodec<global::Tensorflow.NodeDef> _repeated_node_codec | |||
= pb::FieldCodec.ForMessage(10, global::Tensorflow.NodeDef.Parser); | |||
private readonly pbc::RepeatedField<global::Tensorflow.NodeDef> node_ = new pbc::RepeatedField<global::Tensorflow.NodeDef>(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public pbc::RepeatedField<global::Tensorflow.NodeDef> Node { | |||
get { return node_; } | |||
} | |||
/// <summary>Field number for the "versions" field.</summary> | |||
public const int VersionsFieldNumber = 4; | |||
private global::Tensorflow.VersionDef versions_; | |||
/// <summary> | |||
/// Compatibility versions of the graph. See core/public/version.h for version | |||
/// history. The GraphDef version is distinct from the TensorFlow version, and | |||
/// each release of TensorFlow will support a range of GraphDef versions. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public global::Tensorflow.VersionDef Versions { | |||
get { return versions_; } | |||
set { | |||
versions_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "version" field.</summary> | |||
public const int VersionFieldNumber = 3; | |||
private int version_; | |||
/// <summary> | |||
/// Deprecated single version field; use versions above instead. Since all | |||
/// GraphDef changes before "versions" was introduced were forward | |||
/// compatible, this field is entirely ignored. | |||
/// </summary> | |||
[global::System.ObsoleteAttribute] | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public int Version { | |||
get { return version_; } | |||
set { | |||
version_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "library" field.</summary> | |||
public const int LibraryFieldNumber = 2; | |||
private global::Tensorflow.FunctionDefLibrary library_; | |||
/// <summary> | |||
/// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. | |||
/// | |||
/// "library" provides user-defined functions. | |||
/// | |||
/// Naming: | |||
/// * library.function.name are in a flat namespace. | |||
/// NOTE: We may need to change it to be hierarchical to support | |||
/// different orgs. E.g., | |||
/// { "/google/nn", { ... }}, | |||
/// { "/google/vision", { ... }} | |||
/// { "/org_foo/module_bar", { ... }} | |||
/// map<string, FunctionDefLib> named_lib; | |||
/// * If node[i].op is the name of one function in "library", | |||
/// node[i] is deemed as a function call. Otherwise, node[i].op | |||
/// must be a primitive operation supported by the runtime. | |||
/// | |||
/// Function call semantics: | |||
/// | |||
/// * The callee may start execution as soon as some of its inputs | |||
/// are ready. The caller may want to use Tuple() mechanism to | |||
/// ensure all inputs are ready in the same time. | |||
/// | |||
/// * The consumer of return values may start executing as soon as | |||
/// the return values the consumer depends on are ready. The | |||
/// consumer may want to use Tuple() mechanism to ensure the | |||
/// consumer does not start until all return values of the callee | |||
/// function are ready. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public global::Tensorflow.FunctionDefLibrary Library { | |||
get { return library_; } | |||
set { | |||
library_ = value; | |||
} | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override bool Equals(object other) { | |||
return Equals(other as GraphDef); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public bool Equals(GraphDef other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
} | |||
if (ReferenceEquals(other, this)) { | |||
return true; | |||
} | |||
if(!node_.Equals(other.node_)) return false; | |||
if (!object.Equals(Versions, other.Versions)) return false; | |||
if (Version != other.Version) return false; | |||
if (!object.Equals(Library, other.Library)) return false; | |||
return Equals(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
hash ^= node_.GetHashCode(); | |||
if (versions_ != null) hash ^= Versions.GetHashCode(); | |||
if (Version != 0) hash ^= Version.GetHashCode(); | |||
if (library_ != null) hash ^= Library.GetHashCode(); | |||
if (_unknownFields != null) { | |||
hash ^= _unknownFields.GetHashCode(); | |||
} | |||
return hash; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
node_.WriteTo(output, _repeated_node_codec); | |||
if (library_ != null) { | |||
output.WriteRawTag(18); | |||
output.WriteMessage(Library); | |||
} | |||
if (Version != 0) { | |||
output.WriteRawTag(24); | |||
output.WriteInt32(Version); | |||
} | |||
if (versions_ != null) { | |||
output.WriteRawTag(34); | |||
output.WriteMessage(Versions); | |||
} | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public int CalculateSize() { | |||
int size = 0; | |||
size += node_.CalculateSize(_repeated_node_codec); | |||
if (versions_ != null) { | |||
size += 1 + pb::CodedOutputStream.ComputeMessageSize(Versions); | |||
} | |||
if (Version != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Version); | |||
} | |||
if (library_ != null) { | |||
size += 1 + pb::CodedOutputStream.ComputeMessageSize(Library); | |||
} | |||
if (_unknownFields != null) { | |||
size += _unknownFields.CalculateSize(); | |||
} | |||
return size; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(GraphDef other) { | |||
if (other == null) { | |||
return; | |||
} | |||
node_.Add(other.node_); | |||
if (other.versions_ != null) { | |||
if (versions_ == null) { | |||
versions_ = new global::Tensorflow.VersionDef(); | |||
} | |||
Versions.MergeFrom(other.Versions); | |||
} | |||
if (other.Version != 0) { | |||
Version = other.Version; | |||
} | |||
if (other.library_ != null) { | |||
if (library_ == null) { | |||
library_ = new global::Tensorflow.FunctionDefLibrary(); | |||
} | |||
Library.MergeFrom(other.Library); | |||
} | |||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
break; | |||
case 10: { | |||
node_.AddEntriesFrom(input, _repeated_node_codec); | |||
break; | |||
} | |||
case 18: { | |||
if (library_ == null) { | |||
library_ = new global::Tensorflow.FunctionDefLibrary(); | |||
} | |||
input.ReadMessage(library_); | |||
break; | |||
} | |||
case 24: { | |||
Version = input.ReadInt32(); | |||
break; | |||
} | |||
case 34: { | |||
if (versions_ == null) { | |||
versions_ = new global::Tensorflow.VersionDef(); | |||
} | |||
input.ReadMessage(versions_); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
#endregion | |||
} | |||
#endregion Designer generated code |
@@ -1,12 +1,15 @@ | |||
### Download compiler from https://github.com/protocolbuffers/protobuf/releases | |||
```shell | |||
set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\framework | |||
set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Tensorflow | |||
set SRC_DIR=D:\Projects\tensorflow-1.12.0\tensorflow\core\framework | |||
set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf | |||
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto | |||
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto | |||
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto | |||
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto | |||
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto | |||
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% versions.proto | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% function.proto | |||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% graph.proto | |||
``` |
@@ -0,0 +1,247 @@ | |||
// <auto-generated> | |||
// Generated by the protocol buffer compiler. DO NOT EDIT! | |||
// source: versions.proto | |||
// </auto-generated> | |||
#pragma warning disable 1591, 0612, 3021 | |||
#region Designer generated code | |||
using pb = global::Google.Protobuf; | |||
using pbc = global::Google.Protobuf.Collections; | |||
using pbr = global::Google.Protobuf.Reflection; | |||
using scg = global::System.Collections.Generic; | |||
namespace Tensorflow { | |||
/// <summary>Holder for reflection information generated from versions.proto</summary> | |||
public static partial class VersionsReflection { | |||
#region Descriptor | |||
/// <summary>File descriptor for versions.proto</summary> | |||
public static pbr::FileDescriptor Descriptor { | |||
get { return descriptor; } | |||
} | |||
private static pbr::FileDescriptor descriptor; | |||
static VersionsReflection() { | |||
byte[] descriptorData = global::System.Convert.FromBase64String( | |||
string.Concat( | |||
"Cg52ZXJzaW9ucy5wcm90bxIKdGVuc29yZmxvdyJLCgpWZXJzaW9uRGVmEhAK", | |||
"CHByb2R1Y2VyGAEgASgFEhQKDG1pbl9jb25zdW1lchgCIAEoBRIVCg1iYWRf", | |||
"Y29uc3VtZXJzGAMgAygFQm4KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IO", | |||
"VmVyc2lvbnNQcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNv", | |||
"cmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z")); | |||
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
new pbr::FileDescriptor[] { }, | |||
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VersionDef), global::Tensorflow.VersionDef.Parser, new[]{ "Producer", "MinConsumer", "BadConsumers" }, null, null, null) | |||
})); | |||
} | |||
#endregion | |||
} | |||
#region Messages | |||
/// <summary> | |||
/// Version information for a piece of serialized data | |||
/// | |||
/// There are different types of versions for each type of data | |||
/// (GraphDef, etc.), but they all have the same common shape | |||
/// described here. | |||
/// | |||
/// Each consumer has "consumer" and "min_producer" versions (specified | |||
/// elsewhere). A consumer is allowed to consume this data if | |||
/// | |||
/// producer >= min_producer | |||
/// consumer >= min_consumer | |||
/// consumer not in bad_consumers | |||
/// </summary> | |||
public sealed partial class VersionDef : pb::IMessage<VersionDef> { | |||
private static readonly pb::MessageParser<VersionDef> _parser = new pb::MessageParser<VersionDef>(() => new VersionDef()); | |||
private pb::UnknownFieldSet _unknownFields; | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pb::MessageParser<VersionDef> Parser { get { return _parser; } } | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public static pbr::MessageDescriptor Descriptor { | |||
get { return global::Tensorflow.VersionsReflection.Descriptor.MessageTypes[0]; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
get { return Descriptor; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public VersionDef() { | |||
OnConstruction(); | |||
} | |||
partial void OnConstruction(); | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public VersionDef(VersionDef other) : this() { | |||
producer_ = other.producer_; | |||
minConsumer_ = other.minConsumer_; | |||
badConsumers_ = other.badConsumers_.Clone(); | |||
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public VersionDef Clone() { | |||
return new VersionDef(this); | |||
} | |||
/// <summary>Field number for the "producer" field.</summary> | |||
public const int ProducerFieldNumber = 1; | |||
private int producer_; | |||
/// <summary> | |||
/// The version of the code that produced this data. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public int Producer { | |||
get { return producer_; } | |||
set { | |||
producer_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "min_consumer" field.</summary> | |||
public const int MinConsumerFieldNumber = 2; | |||
private int minConsumer_; | |||
/// <summary> | |||
/// Any consumer below this version is not allowed to consume this data. | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public int MinConsumer { | |||
get { return minConsumer_; } | |||
set { | |||
minConsumer_ = value; | |||
} | |||
} | |||
/// <summary>Field number for the "bad_consumers" field.</summary> | |||
public const int BadConsumersFieldNumber = 3; | |||
private static readonly pb::FieldCodec<int> _repeated_badConsumers_codec | |||
= pb::FieldCodec.ForInt32(26); | |||
private readonly pbc::RepeatedField<int> badConsumers_ = new pbc::RepeatedField<int>(); | |||
/// <summary> | |||
/// Specific consumer versions which are disallowed (e.g. due to bugs). | |||
/// </summary> | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public pbc::RepeatedField<int> BadConsumers { | |||
get { return badConsumers_; } | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override bool Equals(object other) { | |||
return Equals(other as VersionDef); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public bool Equals(VersionDef other) { | |||
if (ReferenceEquals(other, null)) { | |||
return false; | |||
} | |||
if (ReferenceEquals(other, this)) { | |||
return true; | |||
} | |||
if (Producer != other.Producer) return false; | |||
if (MinConsumer != other.MinConsumer) return false; | |||
if(!badConsumers_.Equals(other.badConsumers_)) return false; | |||
return Equals(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override int GetHashCode() { | |||
int hash = 1; | |||
if (Producer != 0) hash ^= Producer.GetHashCode(); | |||
if (MinConsumer != 0) hash ^= MinConsumer.GetHashCode(); | |||
hash ^= badConsumers_.GetHashCode(); | |||
if (_unknownFields != null) { | |||
hash ^= _unknownFields.GetHashCode(); | |||
} | |||
return hash; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public override string ToString() { | |||
return pb::JsonFormatter.ToDiagnosticString(this); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void WriteTo(pb::CodedOutputStream output) { | |||
if (Producer != 0) { | |||
output.WriteRawTag(8); | |||
output.WriteInt32(Producer); | |||
} | |||
if (MinConsumer != 0) { | |||
output.WriteRawTag(16); | |||
output.WriteInt32(MinConsumer); | |||
} | |||
badConsumers_.WriteTo(output, _repeated_badConsumers_codec); | |||
if (_unknownFields != null) { | |||
_unknownFields.WriteTo(output); | |||
} | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public int CalculateSize() { | |||
int size = 0; | |||
if (Producer != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Producer); | |||
} | |||
if (MinConsumer != 0) { | |||
size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinConsumer); | |||
} | |||
size += badConsumers_.CalculateSize(_repeated_badConsumers_codec); | |||
if (_unknownFields != null) { | |||
size += _unknownFields.CalculateSize(); | |||
} | |||
return size; | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(VersionDef other) { | |||
if (other == null) { | |||
return; | |||
} | |||
if (other.Producer != 0) { | |||
Producer = other.Producer; | |||
} | |||
if (other.MinConsumer != 0) { | |||
MinConsumer = other.MinConsumer; | |||
} | |||
badConsumers_.Add(other.badConsumers_); | |||
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
} | |||
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
public void MergeFrom(pb::CodedInputStream input) { | |||
uint tag; | |||
while ((tag = input.ReadTag()) != 0) { | |||
switch(tag) { | |||
default: | |||
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
break; | |||
case 8: { | |||
Producer = input.ReadInt32(); | |||
break; | |||
} | |||
case 16: { | |||
MinConsumer = input.ReadInt32(); | |||
break; | |||
} | |||
case 26: | |||
case 24: { | |||
badConsumers_.AddEntriesFrom(input, _repeated_badConsumers_codec); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
#endregion | |||
} | |||
#endregion Designer generated code |
@@ -17,7 +17,7 @@ namespace Tensorflow | |||
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | |||
/// struct => struct (TF_Output output) => (TF_Output output) | |||
/// struct* => struct (TF_Output* output) => (TF_Output[] output) | |||
/// struct* => ref IntPtr (TF_Input* consumers) => (ref IntPtr handle), if output is struct[] | |||
/// struct* => struct* for ref | |||
/// const char* => string | |||
/// int32_t => int | |||
/// int64_t* => long[] | |||
@@ -83,6 +83,42 @@ namespace TensorFlowNET.UnitTest | |||
Assert.AreEqual(1, feed_port.Length); | |||
Assert.AreEqual(add, feed_port[0].oper); | |||
Assert.AreEqual(0, feed_port[0].index); | |||
// The scalar const oper also has a consumer. | |||
Assert.AreEqual(1, three.OutputNumConsumers(0)); | |||
TF_Input[] three_port = three.OutputConsumers(0, 1); | |||
Assert.AreEqual(add, three_port[0].oper); | |||
Assert.AreEqual(1, three_port[0].index); | |||
// Serialize to GraphDef. | |||
var graph_def = c_test_util.GetGraphDef(graph); | |||
// Validate GraphDef is what we expect. | |||
bool found_placeholder = false; | |||
bool found_scalar_const = false; | |||
bool found_add = false; | |||
foreach (var n in graph_def.Node) | |||
{ | |||
if (c_test_util.IsPlaceholder(n)) | |||
{ | |||
Assert.IsFalse(found_placeholder); | |||
found_placeholder = true; | |||
} | |||
/*else if (IsScalarConst(n, 3)) | |||
{ | |||
Assert.IsFalse(found_scalar_const); | |||
found_scalar_const = true; | |||
} | |||
else if (IsAddN(n, 2)) | |||
{ | |||
Assert.IsFalse(found_add); | |||
found_add = true; | |||
} | |||
else | |||
{ | |||
ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n); | |||
}*/ | |||
} | |||
} | |||
} | |||
} |
@@ -19,7 +19,7 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
var handle = c_api.TF_GetAllOpList(); | |||
var buffer = new Buffer(handle); | |||
Assert.IsTrue(buffer.Length == buffer.Data.Length); | |||
Assert.IsTrue(buffer.Length == buffer.Length); | |||
} | |||
[TestMethod] | |||
@@ -39,11 +39,20 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
var buffer = new Buffer(); | |||
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||
attr_value = AttrValue.Parser.ParseFrom(buffer.Data); | |||
attr_value = AttrValue.Parser.ParseFrom(buffer); | |||
buffer.Dispose(); | |||
return s.Code == TF_Code.TF_OK; | |||
} | |||
public static GraphDef GetGraphDef(Graph graph) | |||
{ | |||
var s = new Status(); | |||
var buffer = new Buffer(); | |||
c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
s.Check(); | |||
return GraphDef.Parser.ParseFrom(buffer); | |||
} | |||
public static bool GetNodeDef(Operation oper, ref NodeDef node_def) | |||
{ | |||
var s = new Status(); | |||
@@ -53,6 +62,37 @@ namespace TensorFlowNET.UnitTest | |||
return s.Code == TF_Code.TF_OK; | |||
} | |||
public static bool IsPlaceholder(NodeDef node_def) | |||
{ | |||
if (node_def.Op != "Placeholder" || node_def.Name != "feed") | |||
{ | |||
return false; | |||
} | |||
bool found_dtype = false; | |||
bool found_shape = false; | |||
foreach (var attr in node_def.Attr) | |||
{ | |||
if (attr.Key == "dtype") | |||
{ | |||
if (attr.Value.Type == DataType.DtInt32) | |||
{ | |||
found_dtype = true; | |||
} | |||
else | |||
{ | |||
return false; | |||
} | |||
} | |||
else if (attr.Key == "shape") | |||
{ | |||
found_shape = true; | |||
} | |||
} | |||
return found_dtype && found_shape; | |||
} | |||
public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) | |||
{ | |||
var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||