Browse Source

add GraphDef and FunctionDef.

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
d5b81e9c1f
12 changed files with 1270 additions and 16 deletions
  1. +5
    -0
      src/TensorFlowNET.Core/Buffers/Buffer.cs
  2. +10
    -0
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  3. +4
    -4
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  5. +604
    -0
      src/TensorFlowNET.Core/Protobuf/Function.cs
  6. +309
    -0
      src/TensorFlowNET.Core/Protobuf/Graph.cs
  7. +11
    -8
      src/TensorFlowNET.Core/Protobuf/README.md
  8. +247
    -0
      src/TensorFlowNET.Core/Protobuf/Versions.cs
  9. +1
    -1
      src/TensorFlowNET.Core/c_api.cs
  10. +36
    -0
      test/TensorFlowNET.UnitTest/GraphTest.cs
  11. +1
    -1
      test/TensorFlowNET.UnitTest/OperationsTest.cs
  12. +41
    -1
      test/TensorFlowNET.UnitTest/c_test_util.cs

+ 5
- 0
src/TensorFlowNET.Core/Buffers/Buffer.cs View File

@@ -39,6 +39,11 @@ namespace Tensorflow
return buffer._handle;
}

public static implicit operator byte[](Buffer buffer)
{
return buffer.Data;
}

public void Dispose()
{
c_api.TF_DeleteBuffer(_handle);


+ 10
- 0
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -38,6 +38,16 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status);

/// <summary>
/// Write out a serialized representation of `graph` (as a GraphDef protocol
/// message) to `output_graph_def` (allocated by TF_NewBuffer()).
/// </summary>
/// <param name="graph"></param>
/// <param name="output_graph_def"></param>
/// <param name="status"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphToGraphDef(IntPtr graph, IntPtr output_graph_def, IntPtr status);
/// <summary>
/// Returns the number of dimensions of the Tensor referenced by `output`
/// in `graph`.


+ 4
- 4
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -26,15 +26,15 @@ namespace Tensorflow
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status);
public int NumInputs => c_api.TF_OperationNumInputs(_handle);
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));
public TF_Input[] OutputConsumers(int index, int max_consumers)
public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
{
IntPtr handle = IntPtr.Zero;
int size = Marshal.SizeOf<TF_Input>();
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), ref handle, max_consumers);
var handle = (TF_Input*)Marshal.AllocHGlobal(size);
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers);
var consumers = new TF_Input[num];
for(int i = 0; i < num; i++)
{
consumers[0] = Marshal.PtrToStructure<TF_Input>(handle + i * size);
consumers[i] = new TF_Input((*handle).oper + i * size, (*handle).index);
}

return consumers;


+ 1
- 1
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -112,7 +112,7 @@ namespace Tensorflow
/// <param name="max_consumers"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputConsumers(TF_Output oper_out, ref IntPtr consumers, int max_consumers);
public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input * consumers, int max_consumers);

[DllImport(TensorFlowLibName)]
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out);


+ 604
- 0
src/TensorFlowNET.Core/Protobuf/Function.cs View File

@@ -0,0 +1,604 @@
// <auto-generated>
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: function.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#region Designer generated code

using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;
namespace Tensorflow {

/// <summary>Holder for reflection information generated from function.proto</summary>
public static partial class FunctionReflection {

#region Descriptor
/// <summary>File descriptor for function.proto</summary>
public static pbr::FileDescriptor Descriptor {
get { return descriptor; }
}
private static pbr::FileDescriptor descriptor;

static FunctionReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"Cg5mdW5jdGlvbi5wcm90bxIKdGVuc29yZmxvdxoQYXR0cl92YWx1ZS5wcm90",
"bxoObm9kZV9kZWYucHJvdG8aDG9wX2RlZi5wcm90byJqChJGdW5jdGlvbkRl",
"ZkxpYnJhcnkSKQoIZnVuY3Rpb24YASADKAsyFy50ZW5zb3JmbG93LkZ1bmN0",
"aW9uRGVmEikKCGdyYWRpZW50GAIgAygLMhcudGVuc29yZmxvdy5HcmFkaWVu",
"dERlZiKwAgoLRnVuY3Rpb25EZWYSJAoJc2lnbmF0dXJlGAEgASgLMhEudGVu",
"c29yZmxvdy5PcERlZhIvCgRhdHRyGAUgAygLMiEudGVuc29yZmxvdy5GdW5j",
"dGlvbkRlZi5BdHRyRW50cnkSJQoIbm9kZV9kZWYYAyADKAsyEy50ZW5zb3Jm",
"bG93Lk5vZGVEZWYSLQoDcmV0GAQgAygLMiAudGVuc29yZmxvdy5GdW5jdGlv",
"bkRlZi5SZXRFbnRyeRpCCglBdHRyRW50cnkSCwoDa2V5GAEgASgJEiQKBXZh",
"bHVlGAIgASgLMhUudGVuc29yZmxvdy5BdHRyVmFsdWU6AjgBGioKCFJldEVu",
"dHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgCIAEoCToCOAFKBAgCEAMiOwoL",
"R3JhZGllbnREZWYSFQoNZnVuY3Rpb25fbmFtZRgBIAEoCRIVCg1ncmFkaWVu",
"dF9mdW5jGAIgASgJQm4KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IORnVu",
"Y3Rpb25Qcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZs",
"b3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z"));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.OpDefReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDefLibrary), global::Tensorflow.FunctionDefLibrary.Parser, new[]{ "Function", "Gradient" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDef), global::Tensorflow.FunctionDef.Parser, new[]{ "Signature", "Attr", "NodeDef", "Ret" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, null, }),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GradientDef), global::Tensorflow.GradientDef.Parser, new[]{ "FunctionName", "GradientFunc" }, null, null, null)
}));
}
#endregion

}
#region Messages
/// <summary>
/// A library is a set of named functions.
/// </summary>
public sealed partial class FunctionDefLibrary : pb::IMessage<FunctionDefLibrary> {
private static readonly pb::MessageParser<FunctionDefLibrary> _parser = new pb::MessageParser<FunctionDefLibrary>(() => new FunctionDefLibrary());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<FunctionDefLibrary> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public FunctionDefLibrary() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public FunctionDefLibrary(FunctionDefLibrary other) : this() {
function_ = other.function_.Clone();
gradient_ = other.gradient_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public FunctionDefLibrary Clone() {
return new FunctionDefLibrary(this);
}

/// <summary>Field number for the "function" field.</summary>
public const int FunctionFieldNumber = 1;
private static readonly pb::FieldCodec<global::Tensorflow.FunctionDef> _repeated_function_codec
= pb::FieldCodec.ForMessage(10, global::Tensorflow.FunctionDef.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.FunctionDef> function_ = new pbc::RepeatedField<global::Tensorflow.FunctionDef>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.FunctionDef> Function {
get { return function_; }
}

/// <summary>Field number for the "gradient" field.</summary>
public const int GradientFieldNumber = 2;
private static readonly pb::FieldCodec<global::Tensorflow.GradientDef> _repeated_gradient_codec
= pb::FieldCodec.ForMessage(18, global::Tensorflow.GradientDef.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.GradientDef> gradient_ = new pbc::RepeatedField<global::Tensorflow.GradientDef>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.GradientDef> Gradient {
get { return gradient_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as FunctionDefLibrary);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(FunctionDefLibrary other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if(!function_.Equals(other.function_)) return false;
if(!gradient_.Equals(other.gradient_)) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
hash ^= function_.GetHashCode();
hash ^= gradient_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
function_.WriteTo(output, _repeated_function_codec);
gradient_.WriteTo(output, _repeated_gradient_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
size += function_.CalculateSize(_repeated_function_codec);
size += gradient_.CalculateSize(_repeated_gradient_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(FunctionDefLibrary other) {
if (other == null) {
return;
}
function_.Add(other.function_);
gradient_.Add(other.gradient_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10: {
function_.AddEntriesFrom(input, _repeated_function_codec);
break;
}
case 18: {
gradient_.AddEntriesFrom(input, _repeated_gradient_codec);
break;
}
}
}
}

}

/// <summary>
/// A function can be instantiated when the runtime can bind every attr
/// with a value. When a GraphDef has a call to a function, it must
/// have binding for every attr defined in the signature.
///
/// TODO(zhifengc):
/// * device spec, etc.
/// </summary>
public sealed partial class FunctionDef : pb::IMessage<FunctionDef> {
private static readonly pb::MessageParser<FunctionDef> _parser = new pb::MessageParser<FunctionDef>(() => new FunctionDef());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<FunctionDef> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[1]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public FunctionDef() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public FunctionDef(FunctionDef other) : this() {
signature_ = other.signature_ != null ? other.signature_.Clone() : null;
attr_ = other.attr_.Clone();
nodeDef_ = other.nodeDef_.Clone();
ret_ = other.ret_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public FunctionDef Clone() {
return new FunctionDef(this);
}

/// <summary>Field number for the "signature" field.</summary>
public const int SignatureFieldNumber = 1;
private global::Tensorflow.OpDef signature_;
/// <summary>
/// The definition of the function's name, arguments, return values,
/// attrs etc.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Tensorflow.OpDef Signature {
get { return signature_; }
set {
signature_ = value;
}
}

/// <summary>Field number for the "attr" field.</summary>
public const int AttrFieldNumber = 5;
private static readonly pbc::MapField<string, global::Tensorflow.AttrValue>.Codec _map_attr_codec
= new pbc::MapField<string, global::Tensorflow.AttrValue>.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 42);
private readonly pbc::MapField<string, global::Tensorflow.AttrValue> attr_ = new pbc::MapField<string, global::Tensorflow.AttrValue>();
/// <summary>
/// Attributes specific to this function definition.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::MapField<string, global::Tensorflow.AttrValue> Attr {
get { return attr_; }
}

/// <summary>Field number for the "node_def" field.</summary>
public const int NodeDefFieldNumber = 3;
private static readonly pb::FieldCodec<global::Tensorflow.NodeDef> _repeated_nodeDef_codec
= pb::FieldCodec.ForMessage(26, global::Tensorflow.NodeDef.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.NodeDef> nodeDef_ = new pbc::RepeatedField<global::Tensorflow.NodeDef>();
/// <summary>
/// By convention, "op" in node_def is resolved by consulting with a
/// user-defined library first. If not resolved, "func" is assumed to
/// be a builtin op.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.NodeDef> NodeDef {
get { return nodeDef_; }
}

/// <summary>Field number for the "ret" field.</summary>
public const int RetFieldNumber = 4;
private static readonly pbc::MapField<string, string>.Codec _map_ret_codec
= new pbc::MapField<string, string>.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForString(18), 34);
private readonly pbc::MapField<string, string> ret_ = new pbc::MapField<string, string>();
/// <summary>
/// A mapping from the output arg names from `signature` to the
/// outputs from `node_def` that should be returned by the function.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::MapField<string, string> Ret {
get { return ret_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as FunctionDef);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(FunctionDef other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (!object.Equals(Signature, other.Signature)) return false;
if (!Attr.Equals(other.Attr)) return false;
if(!nodeDef_.Equals(other.nodeDef_)) return false;
if (!Ret.Equals(other.Ret)) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (signature_ != null) hash ^= Signature.GetHashCode();
hash ^= Attr.GetHashCode();
hash ^= nodeDef_.GetHashCode();
hash ^= Ret.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (signature_ != null) {
output.WriteRawTag(10);
output.WriteMessage(Signature);
}
nodeDef_.WriteTo(output, _repeated_nodeDef_codec);
ret_.WriteTo(output, _map_ret_codec);
attr_.WriteTo(output, _map_attr_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (signature_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(Signature);
}
size += attr_.CalculateSize(_map_attr_codec);
size += nodeDef_.CalculateSize(_repeated_nodeDef_codec);
size += ret_.CalculateSize(_map_ret_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(FunctionDef other) {
if (other == null) {
return;
}
if (other.signature_ != null) {
if (signature_ == null) {
signature_ = new global::Tensorflow.OpDef();
}
Signature.MergeFrom(other.Signature);
}
attr_.Add(other.attr_);
nodeDef_.Add(other.nodeDef_);
ret_.Add(other.ret_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10: {
if (signature_ == null) {
signature_ = new global::Tensorflow.OpDef();
}
input.ReadMessage(signature_);
break;
}
case 26: {
nodeDef_.AddEntriesFrom(input, _repeated_nodeDef_codec);
break;
}
case 34: {
ret_.AddEntriesFrom(input, _map_ret_codec);
break;
}
case 42: {
attr_.AddEntriesFrom(input, _map_attr_codec);
break;
}
}
}
}

}

/// <summary>
/// GradientDef defines the gradient function of a function defined in
/// a function library.
///
/// A gradient function g (specified by gradient_func) for a function f
/// (specified by function_name) must follow the following:
///
/// The function 'f' must be a numerical function which takes N inputs
/// and produces M outputs. Its gradient function 'g', which is a
/// function taking N + M inputs and produces N outputs.
///
/// I.e. if we have
/// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
/// then, g is
/// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
/// dL/dy1, dL/dy2, ..., dL/dy_M),
/// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
/// loss function). dL/dx_i is the partial derivative of L with respect
/// to x_i.
/// </summary>
public sealed partial class GradientDef : pb::IMessage<GradientDef> {
private static readonly pb::MessageParser<GradientDef> _parser = new pb::MessageParser<GradientDef>(() => new GradientDef());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<GradientDef> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[2]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public GradientDef() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public GradientDef(GradientDef other) : this() {
functionName_ = other.functionName_;
gradientFunc_ = other.gradientFunc_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public GradientDef Clone() {
return new GradientDef(this);
}

/// <summary>Field number for the "function_name" field.</summary>
public const int FunctionNameFieldNumber = 1;
private string functionName_ = "";
/// <summary>
/// The function name.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string FunctionName {
get { return functionName_; }
set {
functionName_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

/// <summary>Field number for the "gradient_func" field.</summary>
public const int GradientFuncFieldNumber = 2;
private string gradientFunc_ = "";
/// <summary>
/// The gradient function's name.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string GradientFunc {
get { return gradientFunc_; }
set {
gradientFunc_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as GradientDef);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(GradientDef other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (FunctionName != other.FunctionName) return false;
if (GradientFunc != other.GradientFunc) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (FunctionName.Length != 0) hash ^= FunctionName.GetHashCode();
if (GradientFunc.Length != 0) hash ^= GradientFunc.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (FunctionName.Length != 0) {
output.WriteRawTag(10);
output.WriteString(FunctionName);
}
if (GradientFunc.Length != 0) {
output.WriteRawTag(18);
output.WriteString(GradientFunc);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (FunctionName.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(FunctionName);
}
if (GradientFunc.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(GradientFunc);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(GradientDef other) {
if (other == null) {
return;
}
if (other.FunctionName.Length != 0) {
FunctionName = other.FunctionName;
}
if (other.GradientFunc.Length != 0) {
GradientFunc = other.GradientFunc;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10: {
FunctionName = input.ReadString();
break;
}
case 18: {
GradientFunc = input.ReadString();
break;
}
}
}
}

}

#endregion

}

#endregion Designer generated code

+ 309
- 0
src/TensorFlowNET.Core/Protobuf/Graph.cs View File

@@ -0,0 +1,309 @@
// <auto-generated>
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: graph.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#region Designer generated code

using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;
namespace Tensorflow {

/// <summary>Holder for reflection information generated from graph.proto</summary>
public static partial class GraphReflection {

#region Descriptor
/// <summary>File descriptor for graph.proto</summary>
public static pbr::FileDescriptor Descriptor {
get { return descriptor; }
}
private static pbr::FileDescriptor descriptor;

static GraphReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CgtncmFwaC5wcm90bxIKdGVuc29yZmxvdxoObm9kZV9kZWYucHJvdG8aDmZ1",
"bmN0aW9uLnByb3RvGg52ZXJzaW9ucy5wcm90byKdAQoIR3JhcGhEZWYSIQoE",
"bm9kZRgBIAMoCzITLnRlbnNvcmZsb3cuTm9kZURlZhIoCgh2ZXJzaW9ucxgE",
"IAEoCzIWLnRlbnNvcmZsb3cuVmVyc2lvbkRlZhITCgd2ZXJzaW9uGAMgASgF",
"QgIYARIvCgdsaWJyYXJ5GAIgASgLMh4udGVuc29yZmxvdy5GdW5jdGlvbkRl",
"ZkxpYnJhcnlCawoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQgtHcmFwaFBy",
"b3Rvc1ABWj1naXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5z",
"b3JmbG93L2dvL2NvcmUvZnJhbWV3b3Jr+AEBYgZwcm90bzM="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.FunctionReflection.Descriptor, global::Tensorflow.VersionsReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphDef), global::Tensorflow.GraphDef.Parser, new[]{ "Node", "Versions", "Version", "Library" }, null, null, null)
}));
}
#endregion

}
#region Messages
/// <summary>
/// Represents the graph of operations
/// </summary>
public sealed partial class GraphDef : pb::IMessage<GraphDef> {
private static readonly pb::MessageParser<GraphDef> _parser = new pb::MessageParser<GraphDef>(() => new GraphDef());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<GraphDef> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.GraphReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public GraphDef() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public GraphDef(GraphDef other) : this() {
node_ = other.node_.Clone();
versions_ = other.versions_ != null ? other.versions_.Clone() : null;
version_ = other.version_;
library_ = other.library_ != null ? other.library_.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public GraphDef Clone() {
return new GraphDef(this);
}

/// <summary>Field number for the "node" field.</summary>
public const int NodeFieldNumber = 1;
private static readonly pb::FieldCodec<global::Tensorflow.NodeDef> _repeated_node_codec
= pb::FieldCodec.ForMessage(10, global::Tensorflow.NodeDef.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.NodeDef> node_ = new pbc::RepeatedField<global::Tensorflow.NodeDef>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.NodeDef> Node {
get { return node_; }
}

/// <summary>Field number for the "versions" field.</summary>
public const int VersionsFieldNumber = 4;
private global::Tensorflow.VersionDef versions_;
/// <summary>
/// Compatibility versions of the graph. See core/public/version.h for version
/// history. The GraphDef version is distinct from the TensorFlow version, and
/// each release of TensorFlow will support a range of GraphDef versions.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Tensorflow.VersionDef Versions {
get { return versions_; }
set {
versions_ = value;
}
}

/// <summary>Field number for the "version" field.</summary>
public const int VersionFieldNumber = 3;
private int version_;
/// <summary>
/// Deprecated single version field; use versions above instead. Since all
/// GraphDef changes before "versions" was introduced were forward
/// compatible, this field is entirely ignored.
/// </summary>
[global::System.ObsoleteAttribute]
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int Version {
get { return version_; }
set {
version_ = value;
}
}

/// <summary>Field number for the "library" field.</summary>
public const int LibraryFieldNumber = 2;
private global::Tensorflow.FunctionDefLibrary library_;
/// <summary>
/// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
///
/// "library" provides user-defined functions.
///
/// Naming:
/// * library.function.name are in a flat namespace.
/// NOTE: We may need to change it to be hierarchical to support
/// different orgs. E.g.,
/// { "/google/nn", { ... }},
/// { "/google/vision", { ... }}
/// { "/org_foo/module_bar", { ... }}
/// map&lt;string, FunctionDefLib> named_lib;
/// * If node[i].op is the name of one function in "library",
/// node[i] is deemed as a function call. Otherwise, node[i].op
/// must be a primitive operation supported by the runtime.
///
/// Function call semantics:
///
/// * The callee may start execution as soon as some of its inputs
/// are ready. The caller may want to use Tuple() mechanism to
/// ensure all inputs are ready in the same time.
///
/// * The consumer of return values may start executing as soon as
/// the return values the consumer depends on are ready. The
/// consumer may want to use Tuple() mechanism to ensure the
/// consumer does not start until all return values of the callee
/// function are ready.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::Tensorflow.FunctionDefLibrary Library {
get { return library_; }
set {
library_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as GraphDef);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(GraphDef other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if(!node_.Equals(other.node_)) return false;
if (!object.Equals(Versions, other.Versions)) return false;
if (Version != other.Version) return false;
if (!object.Equals(Library, other.Library)) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
hash ^= node_.GetHashCode();
if (versions_ != null) hash ^= Versions.GetHashCode();
if (Version != 0) hash ^= Version.GetHashCode();
if (library_ != null) hash ^= Library.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
node_.WriteTo(output, _repeated_node_codec);
if (library_ != null) {
output.WriteRawTag(18);
output.WriteMessage(Library);
}
if (Version != 0) {
output.WriteRawTag(24);
output.WriteInt32(Version);
}
if (versions_ != null) {
output.WriteRawTag(34);
output.WriteMessage(Versions);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
size += node_.CalculateSize(_repeated_node_codec);
if (versions_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(Versions);
}
if (Version != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Version);
}
if (library_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(Library);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(GraphDef other) {
if (other == null) {
return;
}
node_.Add(other.node_);
if (other.versions_ != null) {
if (versions_ == null) {
versions_ = new global::Tensorflow.VersionDef();
}
Versions.MergeFrom(other.Versions);
}
if (other.Version != 0) {
Version = other.Version;
}
if (other.library_ != null) {
if (library_ == null) {
library_ = new global::Tensorflow.FunctionDefLibrary();
}
Library.MergeFrom(other.Library);
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10: {
node_.AddEntriesFrom(input, _repeated_node_codec);
break;
}
case 18: {
if (library_ == null) {
library_ = new global::Tensorflow.FunctionDefLibrary();
}
input.ReadMessage(library_);
break;
}
case 24: {
Version = input.ReadInt32();
break;
}
case 34: {
if (versions_ == null) {
versions_ = new global::Tensorflow.VersionDef();
}
input.ReadMessage(versions_);
break;
}
}
}
}

}

#endregion

}

#endregion Designer generated code

+ 11
- 8
src/TensorFlowNET.Core/Protobuf/README.md View File

@@ -1,12 +1,15 @@
### Download compiler from https://github.com/protocolbuffers/protobuf/releases
```shell
set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\framework
set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Tensorflow
set SRC_DIR=D:\Projects\tensorflow-1.12.0\tensorflow\core\framework
set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf

.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto
.\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% versions.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% function.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% graph.proto
```

+ 247
- 0
src/TensorFlowNET.Core/Protobuf/Versions.cs View File

@@ -0,0 +1,247 @@
// <auto-generated>
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: versions.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#region Designer generated code

using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;
namespace Tensorflow {

/// <summary>Holder for reflection information generated from versions.proto</summary>
public static partial class VersionsReflection {

#region Descriptor
/// <summary>File descriptor for versions.proto</summary>
public static pbr::FileDescriptor Descriptor {
get { return descriptor; }
}
private static pbr::FileDescriptor descriptor;

static VersionsReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"Cg52ZXJzaW9ucy5wcm90bxIKdGVuc29yZmxvdyJLCgpWZXJzaW9uRGVmEhAK",
"CHByb2R1Y2VyGAEgASgFEhQKDG1pbl9jb25zdW1lchgCIAEoBRIVCg1iYWRf",
"Y29uc3VtZXJzGAMgAygFQm4KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IO",
"VmVyc2lvbnNQcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNv",
"cmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z"));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VersionDef), global::Tensorflow.VersionDef.Parser, new[]{ "Producer", "MinConsumer", "BadConsumers" }, null, null, null)
}));
}
#endregion

}
#region Messages
/// <summary>
/// Version information for a piece of serialized data
///
/// There are different types of versions for each type of data
/// (GraphDef, etc.), but they all have the same common shape
/// described here.
///
/// Each consumer has "consumer" and "min_producer" versions (specified
/// elsewhere). A consumer is allowed to consume this data if
///
/// producer >= min_producer
/// consumer >= min_consumer
/// consumer not in bad_consumers
/// </summary>
public sealed partial class VersionDef : pb::IMessage<VersionDef> {
private static readonly pb::MessageParser<VersionDef> _parser = new pb::MessageParser<VersionDef>(() => new VersionDef());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<VersionDef> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::Tensorflow.VersionsReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public VersionDef() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public VersionDef(VersionDef other) : this() {
producer_ = other.producer_;
minConsumer_ = other.minConsumer_;
badConsumers_ = other.badConsumers_.Clone();
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public VersionDef Clone() {
return new VersionDef(this);
}

/// <summary>Field number for the "producer" field.</summary>
public const int ProducerFieldNumber = 1;
private int producer_;
/// <summary>
/// The version of the code that produced this data.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int Producer {
get { return producer_; }
set {
producer_ = value;
}
}

/// <summary>Field number for the "min_consumer" field.</summary>
public const int MinConsumerFieldNumber = 2;
private int minConsumer_;
/// <summary>
/// Any consumer below this version is not allowed to consume this data.
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int MinConsumer {
get { return minConsumer_; }
set {
minConsumer_ = value;
}
}

/// <summary>Field number for the "bad_consumers" field.</summary>
public const int BadConsumersFieldNumber = 3;
private static readonly pb::FieldCodec<int> _repeated_badConsumers_codec
= pb::FieldCodec.ForInt32(26);
private readonly pbc::RepeatedField<int> badConsumers_ = new pbc::RepeatedField<int>();
/// <summary>
/// Specific consumer versions which are disallowed (e.g. due to bugs).
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<int> BadConsumers {
get { return badConsumers_; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as VersionDef);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(VersionDef other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (Producer != other.Producer) return false;
if (MinConsumer != other.MinConsumer) return false;
if(!badConsumers_.Equals(other.badConsumers_)) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (Producer != 0) hash ^= Producer.GetHashCode();
if (MinConsumer != 0) hash ^= MinConsumer.GetHashCode();
hash ^= badConsumers_.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (Producer != 0) {
output.WriteRawTag(8);
output.WriteInt32(Producer);
}
if (MinConsumer != 0) {
output.WriteRawTag(16);
output.WriteInt32(MinConsumer);
}
badConsumers_.WriteTo(output, _repeated_badConsumers_codec);
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (Producer != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Producer);
}
if (MinConsumer != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinConsumer);
}
size += badConsumers_.CalculateSize(_repeated_badConsumers_codec);
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(VersionDef other) {
if (other == null) {
return;
}
if (other.Producer != 0) {
Producer = other.Producer;
}
if (other.MinConsumer != 0) {
MinConsumer = other.MinConsumer;
}
badConsumers_.Add(other.badConsumers_);
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 8: {
Producer = input.ReadInt32();
break;
}
case 16: {
MinConsumer = input.ReadInt32();
break;
}
case 26:
case 24: {
badConsumers_.AddEntriesFrom(input, _repeated_badConsumers_codec);
break;
}
}
}
}

}

#endregion

}

#endregion Designer generated code

+ 1
- 1
src/TensorFlowNET.Core/c_api.cs View File

@@ -17,7 +17,7 @@ namespace Tensorflow
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph)
/// struct => struct (TF_Output output) => (TF_Output output)
/// struct* => struct (TF_Output* output) => (TF_Output[] output)
/// struct* => ref IntPtr (TF_Input* consumers) => (ref IntPtr handle), if output is struct[]
/// struct* => struct* for ref
/// const char* => string
/// int32_t => int
/// int64_t* => long[]


+ 36
- 0
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -83,6 +83,42 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(1, feed_port.Length);
Assert.AreEqual(add, feed_port[0].oper);
Assert.AreEqual(0, feed_port[0].index);

// The scalar const oper also has a consumer.
Assert.AreEqual(1, three.OutputNumConsumers(0));
TF_Input[] three_port = three.OutputConsumers(0, 1);
Assert.AreEqual(add, three_port[0].oper);
Assert.AreEqual(1, three_port[0].index);

// Serialize to GraphDef.
var graph_def = c_test_util.GetGraphDef(graph);

// Validate GraphDef is what we expect.
bool found_placeholder = false;
bool found_scalar_const = false;
bool found_add = false;
foreach (var n in graph_def.Node)
{
if (c_test_util.IsPlaceholder(n))
{
Assert.IsFalse(found_placeholder);
found_placeholder = true;
}
/*else if (IsScalarConst(n, 3))
{
Assert.IsFalse(found_scalar_const);
found_scalar_const = true;
}
else if (IsAddN(n, 2))
{
Assert.IsFalse(found_add);
found_add = true;
}
else
{
ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n);
}*/
}
}
}
}

+ 1
- 1
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -19,7 +19,7 @@ namespace TensorFlowNET.UnitTest
{
var handle = c_api.TF_GetAllOpList();
var buffer = new Buffer(handle);
Assert.IsTrue(buffer.Length == buffer.Data.Length);
Assert.IsTrue(buffer.Length == buffer.Length);
}

[TestMethod]


+ 41
- 1
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -39,11 +39,20 @@ namespace TensorFlowNET.UnitTest
{
var buffer = new Buffer();
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
attr_value = AttrValue.Parser.ParseFrom(buffer.Data);
attr_value = AttrValue.Parser.ParseFrom(buffer);
buffer.Dispose();
return s.Code == TF_Code.TF_OK;
}

public static GraphDef GetGraphDef(Graph graph)
{
var s = new Status();
var buffer = new Buffer();
c_api.TF_GraphToGraphDef(graph, buffer, s);
s.Check();
return GraphDef.Parser.ParseFrom(buffer);
}

public static bool GetNodeDef(Operation oper, ref NodeDef node_def)
{
var s = new Status();
@@ -53,6 +62,37 @@ namespace TensorFlowNET.UnitTest
return s.Code == TF_Code.TF_OK;
}

public static bool IsPlaceholder(NodeDef node_def)
{
if (node_def.Op != "Placeholder" || node_def.Name != "feed")
{
return false;
}

bool found_dtype = false;
bool found_shape = false;
foreach (var attr in node_def.Attr)
{
if (attr.Key == "dtype")
{
if (attr.Value.Type == DataType.DtInt32)
{
found_dtype = true;
}
else
{
return false;
}
}
else if (attr.Key == "shape")
{
found_shape = true;
}
}

return found_dtype && found_shape;
}

public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op)
{
var desc = c_api.TF_NewOperation(graph, "Placeholder", name);


Loading…
Cancel
Save