Browse Source

add implicit for Graph, Operation, Tensor, Status.

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
9e42e3c67f
20 changed files with 282 additions and 127 deletions
  1. +9
    -11
      src/TensorFlowNET.Core/Buffers/Buffer.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Buffers/TF_Buffer.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Buffers/c_api.buffer.cs
  4. +7
    -3
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +33
    -0
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  6. +0
    -3
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  7. +9
    -3
      src/TensorFlowNET.Core/Operations/Operation.cs
  8. +0
    -0
      src/TensorFlowNET.Core/Operations/TF_Input.cs
  9. +0
    -0
      src/TensorFlowNET.Core/Operations/TF_Output.cs
  10. +13
    -6
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  11. +3
    -3
      src/TensorFlowNET.Core/Operations/ops.cs
  12. +10
    -10
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  13. +0
    -59
      src/TensorFlowNET.Core/Sessions/FeedDict.cs
  14. +1
    -1
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  15. +22
    -2
      src/TensorFlowNET.Core/Status/Status.cs
  16. +48
    -19
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  17. +11
    -1
      src/TensorFlowNET.Core/c_api.cs
  18. +12
    -3
      test/TensorFlowNET.UnitTest/OperationsTest.cs
  19. +65
    -1
      test/TensorFlowNET.UnitTest/TensorTest.cs
  20. +37
    -0
      test/TensorFlowNET.UnitTest/c_test_util.cs

+ 9
- 11
src/TensorFlowNET.Core/Buffers/Buffer.cs View File

@@ -9,21 +9,19 @@ namespace Tensorflow
{
private IntPtr _handle;
public IntPtr Handle => _handle;
//public TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle);

public unsafe Buffer()
{
_handle = Marshal.AllocHGlobal(sizeof(TF_Buffer));
}
private TF_Buffer buffer;

public byte[] GetBuffer()
{
var buffer = Marshal.PtrToStructure<TF_Buffer>(_handle);
public byte[] Data;

var data = Marshal.AllocHGlobal(buffer.length);
//var bytes = c_api.TF_GetBuffer(buffer.data);
public int Length => (int)buffer.length;

return null;
public unsafe Buffer(IntPtr handle)
{
_handle = handle;
buffer = Marshal.PtrToStructure<TF_Buffer>(_handle);
Data = new byte[buffer.length];
Marshal.Copy(buffer.data, Data, 0, (int)buffer.length);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Buffers/TF_Buffer.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow
public struct TF_Buffer
{
public IntPtr data;
public int length;
public ulong length;
public IntPtr data_deallocator;
}
}

+ 1
- 1
src/TensorFlowNET.Core/Buffers/c_api.buffer.cs View File

@@ -8,6 +8,6 @@ namespace Tensorflow
public static partial class c_api
{
[DllImport(TensorFlowLibName)]
public static extern string TF_GetBuffer(IntPtr buffer);
public static extern IntPtr TF_GetBuffer(TF_Buffer buffer);
}
}

+ 7
- 3
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -15,8 +15,7 @@ namespace Tensorflow
/// </summary>
public class Graph
{
private IntPtr _c_graph;
public IntPtr Handle => _c_graph;
private IntPtr _handle;
private Dictionary<int, Operation> _nodes_by_id;
private Dictionary<string, Operation> _nodes_by_name;
private Dictionary<string, int> _names_in_use;
@@ -28,7 +27,7 @@ namespace Tensorflow

public Graph(IntPtr graph)
{
this._c_graph = graph;
_handle = graph;
_nodes_by_id = new Dictionary<int, Operation>();
_nodes_by_name = new Dictionary<string, Operation>();
_names_in_use = new Dictionary<string, int>();
@@ -171,5 +170,10 @@ namespace Tensorflow
{
return _nodes_by_name.Values.Select(x => x).ToArray();
}

public static implicit operator IntPtr(Graph graph)
{
return graph._handle;
}
}
}

+ 33
- 0
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -10,6 +10,39 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status);

/// <summary>
/// Returns the shape of the Tensor referenced by `output` in `graph`
/// into `dims`. `dims` must be an array large enough to hold `num_dims`
/// entries (e.g., the return value of TF_GraphGetTensorNumDims).
/// </summary>
/// <param name="graph"></param>
/// <param name="output"></param>
/// <param name="dims"></param>
/// <param name="num_dims"></param>
/// <param name="status"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status);

/// <summary>
/// Sets the shape of the Tensor referenced by `output` in `graph` to
/// the shape described by `dims` and `num_dims`.
/// </summary>
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status);

/// <summary>
/// Returns the number of dimensions of the Tensor referenced by `output`
/// in `graph`.
///
/// If the number of dimensions in the shape is unknown, returns -1.
/// </summary>
/// <param name="graph"></param>
/// <param name="output"></param>
/// <param name="status"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status);

[DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_NewGraph();
}


+ 0
- 3
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -28,9 +28,6 @@ namespace Tensorflow
{
var op_def = _ops[op_type_name];

var status = new Status();
var buffer = new Buffer();

var g = ops.get_default_graph();

if (String.IsNullOrEmpty(name))


+ 9
- 3
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -1,12 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;
using TF_DataType = Tensorflow.DataType;

namespace Tensorflow
{
public class Operation
{
public IntPtr Handle { get; }

private Graph _graph;
public Graph graph => _graph;
public IntPtr _c_op;
@@ -17,15 +18,20 @@ namespace Tensorflow
public Tensor[] outputs => _outputs;
public Tensor[] inputs;

public Operation(IntPtr handle)
{
Handle = handle;
}

public Operation(Graph g, string opType, string oper_name)
{
_graph = g;

var status = new Status();

var desc = c_api.TF_NewOperation(g.Handle, opType, oper_name);
var desc = c_api.TF_NewOperation(g, opType, oper_name);
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32);
c_api.TF_FinishOperation(desc, status.Handle);
c_api.TF_FinishOperation(desc, status);
}

public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)


src/TensorFlowNET.Core/Graphs/TF_Input.cs → src/TensorFlowNET.Core/Operations/TF_Input.cs View File


src/TensorFlowNET.Core/Graphs/TF_Output.cs → src/TensorFlowNET.Core/Operations/TF_Output.cs View File


+ 13
- 6
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -7,30 +7,37 @@ namespace Tensorflow
{
public static partial class c_api
{
/// <summary>
/// Get the OpList of all OpDefs defined in this address space.
/// </summary>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_GetAllOpList();

/// <summary>
/// For inputs that take a single tensor.
/// </summary>
/// <param name="desc"></param>
/// <param name="input"></param>
[DllImport(TensorFlowLibName)]
public static unsafe extern void TF_AddInput(TF_OperationDescription desc, TF_Output input);
public static unsafe extern void TF_AddInput(IntPtr desc, TF_Output input);

[DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_FinishOperation(TF_OperationDescription desc, IntPtr status);
public static unsafe extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status);

[DllImport(TensorFlowLibName)]
public static unsafe extern TF_OperationDescription TF_NewOperation(IntPtr graph, string opType, string oper_name);
public static unsafe extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name);

[DllImport(TensorFlowLibName)]
public static extern unsafe int TF_OperationNumOutputs(IntPtr oper);

[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status);
public static extern unsafe void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status);

[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, IntPtr value, IntPtr status);
public static extern unsafe void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status);

[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value);
public static extern unsafe void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value);
}
}

+ 3
- 3
src/TensorFlowNET.Core/Operations/ops.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow

public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs)
{
var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name);
var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name);

// Add inputs
if(inputs != null)
@@ -45,12 +45,12 @@ namespace Tensorflow
var bytes = attr.Value.ToByteArray();
var proto = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, proto, bytes.Length);
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle);
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status);

if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message);
}

var c_op = c_api.TF_FinishOperation(op_desc, status.Handle);
var c_op = c_api.TF_FinishOperation(op_desc, status);

if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message);



+ 10
- 10
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow
_target = UTF8Encoding.UTF8.GetBytes(target);
var opts = c_api.TF_NewSessionOptions();
var status = new Status();
_session = c_api.TF_NewSession(_graph.Handle, opts, status.Handle);
_session = c_api.TF_NewSession(_graph, opts, status);

c_api.TF_DeleteSessionOptions(opts);
}
@@ -40,30 +40,30 @@ namespace Tensorflow
}

public virtual object run(Tensor fetches, FeedDict feed_dict = null)
public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
{
var result = _run(fetches, feed_dict);

return result;
}

private unsafe object _run(Tensor fetches, FeedDict feed_dict = null)
private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
{
var feed_dict_tensor = new FeedDict();
var feed_dict_tensor = new Dictionary<Tensor, object>();

if (feed_dict != null)
{
NDArray np_val = null;
foreach (FeedValue feed in feed_dict)
foreach (var feed in feed_dict)
{
switch (feed.feed_val)
switch (feed.Value)
{
case float value:
np_val = np.asarray(value);
break;
}

feed_dict_tensor[feed.feed] = np_val;
feed_dict_tensor[feed.Key] = np_val;
}
}

@@ -85,9 +85,9 @@ namespace Tensorflow
return fetch_handler.build_results(null, results);
}

private object[] _do_run(List<Tensor> fetch_list, FeedDict feed_dict)
private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, object> feed_dict)
{
var feeds = feed_dict.items().Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray();
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray();
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();

return _call_tf_sessionrun(feeds, fetches);
@@ -113,7 +113,7 @@ namespace Tensorflow
target_opers: new IntPtr[] { },
ntargets: 0,
run_metadata: IntPtr.Zero,
status: status.Handle);
status: status);

var result = output_values.Select(x => c_api.TF_TensorData(x))
.Select(x => (object)*(float*)x)


+ 0
- 59
src/TensorFlowNET.Core/Sessions/FeedDict.cs View File

@@ -1,59 +0,0 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class FeedDict : IEnumerable
{
private Dictionary<Tensor, object> feed_dict;

public FeedDict()
{
feed_dict = new Dictionary<Tensor, object>();
}

public object this[Tensor feed]
{
get
{
return feed_dict[feed];
}

set
{
feed_dict[feed] = value;
}
}

public FeedDict Add(Tensor feed, object value)
{
feed_dict.Add(feed, value);
return this;
}

public IEnumerator GetEnumerator()
{
foreach (KeyValuePair<Tensor, object> feed in feed_dict)
{
yield return new FeedValue
{
feed = feed.Key,
feed_val = feed.Value
};
}
}

public Dictionary<Tensor, object> items()
{
return feed_dict;
}
}

public struct FeedValue
{
public Tensor feed { get; set; }
public object feed_val { get; set; }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow
private List<Tensor> _final_fetches = new List<Tensor>();
private List<object> _targets = new List<object>();

public _FetchHandler(Graph graph, Tensor fetches, FeedDict feeds = null, object feed_handles = null)
public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, object> feeds = null, object feed_handles = null)
{
_fetch_mapper = new _FetchMapper().for_fetch(fetches);
foreach(var fetch in _fetch_mapper.unique_fetches())


+ 22
- 2
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -4,10 +4,13 @@ using System.Text;

namespace Tensorflow
{
public class Status : IDisposable
/// <summary>
/// TF_Status holds error information. It either has an OK code, or
/// else an error code with an associated error message.
/// </summary>
public class Status
{
private readonly IntPtr _handle;
public IntPtr Handle => _handle;

/// <summary>
/// Error message
@@ -29,6 +32,23 @@ namespace Tensorflow
c_api.TF_SetStatus(_handle, code, msg);
}

/// <summary>
/// Check status
/// Throw exception with error message if code != TF_OK
/// </summary>
public void Check()
{
if(Code != TF_Code.TF_OK)
{
throw new Exception(Message);
}
}

public static implicit operator IntPtr(Status status)
{
return status._handle;
}

public void Dispose()
{
c_api.TF_DeleteStatus(_handle);


+ 48
- 19
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -13,6 +13,8 @@ namespace Tensorflow
/// </summary>
public class Tensor
{
public IntPtr Handle { get; }

public Graph graph => op.graph;
public Operation op { get; }

@@ -21,7 +23,6 @@ namespace Tensorflow
public int value_index { get; }

public TF_DataType dtype { get; }
public IntPtr handle { get; }
public ulong bytesize { get; }
public ulong dataTypeSize { get;}
public ulong size => bytesize / dataTypeSize;
@@ -45,7 +46,7 @@ namespace Tensorflow

public Tensor(IntPtr handle)
{
this.handle = handle;
Handle = handle;
dtype = c_api.TF_TensorType(handle);
rank = c_api.TF_NumDims(handle);
bytesize = c_api.TF_TensorByteSize(handle);
@@ -59,33 +60,52 @@ namespace Tensorflow

public Tensor(NDArray nd)
{
var data = Marshal.AllocHGlobal(sizeof(float) * nd.size);
Marshal.Copy(nd.Data<float>(), 0, data, nd.size);
var dataType = ToTFDataType(nd.dtype);
Handle = Allocate(nd);
dtype = c_api.TF_TensorType(Handle);
rank = c_api.TF_NumDims(Handle);
bytesize = c_api.TF_TensorByteSize(Handle);
buffer = c_api.TF_TensorData(Handle);
dataTypeSize = c_api.TF_DataTypeSize(dtype);

shape = new long[rank];
for (int i = 0; i < rank; i++)
shape[i] = c_api.TF_Dim(Handle, i);
}

private IntPtr Allocate(NDArray nd)
{
var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size);

switch (nd.dtype.Name)
{
case "Int32":
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size);
break;
case "Single":
Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size);
break;
case "Double":
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
break;
default:
throw new NotImplementedException("Marshal.Copy failed.");
}

var handle = c_api.TF_NewTensor(dataType,
var dataType = ToTFDataType(nd.dtype);
var tfHandle = c_api.TF_NewTensor(dataType,
nd.shape.Select(x => (long)x).ToArray(), // shape
nd.ndim,
data,
(UIntPtr)(nd.size * sizeof(float)),
dotHandle,
(UIntPtr)(nd.size * nd.dtypesize),
(IntPtr values, IntPtr len, ref bool closure) =>
{
// Free the original buffer and set flag
Marshal.FreeHGlobal(data);
Marshal.FreeHGlobal(dotHandle);
closure = true;
},
ref deallocator_called);

this.handle = handle;
dtype = c_api.TF_TensorType(handle);
rank = c_api.TF_NumDims(handle);
bytesize = c_api.TF_TensorByteSize(handle);
buffer = c_api.TF_TensorData(handle);
dataTypeSize = c_api.TF_DataTypeSize(dtype);

shape = new long[rank];
for (int i = 0; i < rank; i++)
shape[i] = c_api.TF_Dim(handle, i);
return tfHandle;
}

public Tensor(Operation op, int value_index, TF_DataType dtype)
@@ -129,11 +149,20 @@ namespace Tensorflow
{
switch (type.Name)
{
case "Int32":
return TF_DataType.TF_INT32;
case "Single":
return TF_DataType.TF_FLOAT;
case "Double":
return TF_DataType.TF_DOUBLE;
}

return TF_DataType.DtInvalid;
}

public static implicit operator IntPtr(Tensor tensor)
{
return tensor.Handle;
}
}
}

+ 11
- 1
src/TensorFlowNET.Core/c_api.cs View File

@@ -10,12 +10,22 @@ namespace Tensorflow
///
/// The API leans towards simplicity and uniformity instead of convenience
/// since most usage will be by language specific wrappers.
///
/// The params type mapping between .net and c_api
/// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op)
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph)
/// struct => struct (TF_Output output) => (TF_Output output)
/// const char* => string
/// int32_t => int
/// int64_t* => long[]
/// size_t* => unlong[]
/// void* => IntPtr
/// </summary>
public static partial class c_api
{
public const string TensorFlowLibName = "tensorflow";

public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocatorData);
public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocator);

[DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_Version();


+ 12
- 3
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -3,12 +3,21 @@ using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using Buffer = Tensorflow.Buffer;

namespace TensorFlowNET.UnitTest
{
[TestClass]
public class OperationsTest
{
[TestMethod]
public void GetAllOpList()
{
var handle = c_api.TF_GetAllOpList();
var buffer = new Buffer(handle);
Assert.IsTrue(buffer.Length == buffer.Data.Length);
}

[TestMethod]
public void addInPlaceholder()
{
@@ -18,9 +27,9 @@ namespace TensorFlowNET.UnitTest

using(var sess = tf.Session())
{
var feed_dict = new FeedDict()
.Add(a, 3.0f)
.Add(b, 2.0f);
var feed_dict = new Dictionary<Tensor, object>();
feed_dict.Add(a, 3.0f);
feed_dict.Add(b, 2.0f);

var o = sess.run(c, feed_dict);
}


+ 65
- 1
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest
public class TensorTest
{
[TestMethod]
public unsafe void NewTensor()
public void NewTensor()
{
var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);

@@ -27,5 +27,69 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float));
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array));
}

/// <summary>
/// Port from tensorflow\c\c_api_test.cc
/// </summary>
[TestMethod]
public void SetShape()
{
var s = new Status();
var graph = tf.get_default_graph();

var desc = c_api.TF_NewOperation(graph, "Placeholder", "");
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_FLOAT);
//if (!dims.empty())
{
//TF_SetAttrShape(desc, "shape", dims.data(), dims.size());
}
var op = c_api.TF_FinishOperation(desc, s);

Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.IsNotNull(op);

// Fetch the shape, it should be completely unknown.
var feed_out_0 = new TF_Output { oper = op, index = 0 };
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);

Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.AreEqual(-1, num_dims);

// Set the shape to be unknown, expect no change.
c_api.TF_GraphSetTensorShape(graph, feed_out_0, new int[0], -1, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
Assert.AreEqual(-1, num_dims);

// Set the shape to be 2 x Unknown
var dims = new int[] { 2, -1 };
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
Assert.AreEqual(2, num_dims);

// Get the dimension vector appropriately.
var returned_dims = new int[dims.Length];
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

// Set to a new valid shape: [2, 3]
dims[1] = 3;
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
//Assert.IsTrue(s.Code == TF_Code.TF_OK);

// Fetch and see that the new value is returned.
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
//Assert.IsTrue(s.Code == TF_Code.TF_OK);
//Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

// Test for a scalar.
var three = c_test_util.ScalarConst(3, graph, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
var three_out_0 = new TF_Output { oper = three.Handle };
num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s);
Assert.AreEqual(0, num_dims);
}
}
}

+ 37
- 0
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -0,0 +1,37 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.UnitTest
{
public static class c_test_util
{
public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref IntPtr op)
{
var desc = c_api.TF_NewOperation(graph, "Const", name);
c_api.TF_SetAttrTensor(desc, "value", t.Handle, s);
s.Check();
c_api.TF_SetAttrType(desc, "dtype", t.dtype);
op = c_api.TF_FinishOperation(desc, s);
s.Check();
if(op == null)
{
throw new Exception("c_api.TF_FinishOperation failed.");
}
}

public static Operation Const(Tensor t, Graph graph, Status s, string name)
{
IntPtr op = IntPtr.Zero;
ConstHelper(t, graph, s, name, ref op);
return new Operation(op);
}

public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const")
{
return Const(new Tensor(v), graph, s, name);
}
}
}

Loading…
Cancel
Save