@@ -9,21 +9,19 @@ namespace Tensorflow | |||
{ | |||
private IntPtr _handle; | |||
public IntPtr Handle => _handle; | |||
//public TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle); | |||
public unsafe Buffer() | |||
{ | |||
_handle = Marshal.AllocHGlobal(sizeof(TF_Buffer)); | |||
} | |||
private TF_Buffer buffer; | |||
public byte[] GetBuffer() | |||
{ | |||
var buffer = Marshal.PtrToStructure<TF_Buffer>(_handle); | |||
public byte[] Data; | |||
var data = Marshal.AllocHGlobal(buffer.length); | |||
//var bytes = c_api.TF_GetBuffer(buffer.data); | |||
public int Length => (int)buffer.length; | |||
return null; | |||
public unsafe Buffer(IntPtr handle) | |||
{ | |||
_handle = handle; | |||
buffer = Marshal.PtrToStructure<TF_Buffer>(_handle); | |||
Data = new byte[buffer.length]; | |||
Marshal.Copy(buffer.data, Data, 0, (int)buffer.length); | |||
} | |||
} | |||
} |
@@ -9,7 +9,7 @@ namespace Tensorflow | |||
public struct TF_Buffer | |||
{ | |||
public IntPtr data; | |||
public int length; | |||
public ulong length; | |||
public IntPtr data_deallocator; | |||
} | |||
} |
@@ -8,6 +8,6 @@ namespace Tensorflow | |||
public static partial class c_api | |||
{ | |||
[DllImport(TensorFlowLibName)] | |||
public static extern string TF_GetBuffer(IntPtr buffer); | |||
public static extern IntPtr TF_GetBuffer(TF_Buffer buffer); | |||
} | |||
} |
@@ -15,8 +15,7 @@ namespace Tensorflow | |||
/// </summary> | |||
public class Graph | |||
{ | |||
private IntPtr _c_graph; | |||
public IntPtr Handle => _c_graph; | |||
private IntPtr _handle; | |||
private Dictionary<int, Operation> _nodes_by_id; | |||
private Dictionary<string, Operation> _nodes_by_name; | |||
private Dictionary<string, int> _names_in_use; | |||
@@ -28,7 +27,7 @@ namespace Tensorflow | |||
public Graph(IntPtr graph) | |||
{ | |||
this._c_graph = graph; | |||
_handle = graph; | |||
_nodes_by_id = new Dictionary<int, Operation>(); | |||
_nodes_by_name = new Dictionary<string, Operation>(); | |||
_names_in_use = new Dictionary<string, int>(); | |||
@@ -171,5 +170,10 @@ namespace Tensorflow | |||
{ | |||
return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
} | |||
public static implicit operator IntPtr(Graph graph) | |||
{ | |||
return graph._handle; | |||
} | |||
} | |||
} |
@@ -10,6 +10,39 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | |||
/// <summary> | |||
/// Returns the shape of the Tensor referenced by `output` in `graph` | |||
/// into `dims`. `dims` must be an array large enough to hold `num_dims` | |||
/// entries (e.g., the return value of TF_GraphGetTensorNumDims). | |||
/// </summary> | |||
/// <param name="graph"></param> | |||
/// <param name="output"></param> | |||
/// <param name="dims"></param> | |||
/// <param name="num_dims"></param> | |||
/// <param name="status"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); | |||
/// <summary> | |||
/// Sets the shape of the Tensor referenced by `output` in `graph` to | |||
/// the shape described by `dims` and `num_dims`. | |||
/// </summary> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); | |||
/// <summary> | |||
/// Returns the number of dimensions of the Tensor referenced by `output` | |||
/// in `graph`. | |||
/// | |||
/// If the number of dimensions in the shape is unknown, returns -1. | |||
/// </summary> | |||
/// <param name="graph"></param> | |||
/// <param name="output"></param> | |||
/// <param name="status"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status); | |||
[DllImport(TensorFlowLibName)] | |||
public static unsafe extern IntPtr TF_NewGraph(); | |||
} | |||
@@ -28,9 +28,6 @@ namespace Tensorflow | |||
{ | |||
var op_def = _ops[op_type_name]; | |||
var status = new Status(); | |||
var buffer = new Buffer(); | |||
var g = ops.get_default_graph(); | |||
if (String.IsNullOrEmpty(name)) | |||
@@ -1,12 +1,13 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using TF_DataType = Tensorflow.DataType; | |||
namespace Tensorflow | |||
{ | |||
public class Operation | |||
{ | |||
public IntPtr Handle { get; } | |||
private Graph _graph; | |||
public Graph graph => _graph; | |||
public IntPtr _c_op; | |||
@@ -17,15 +18,20 @@ namespace Tensorflow | |||
public Tensor[] outputs => _outputs; | |||
public Tensor[] inputs; | |||
public Operation(IntPtr handle) | |||
{ | |||
Handle = handle; | |||
} | |||
public Operation(Graph g, string opType, string oper_name) | |||
{ | |||
_graph = g; | |||
var status = new Status(); | |||
var desc = c_api.TF_NewOperation(g.Handle, opType, oper_name); | |||
var desc = c_api.TF_NewOperation(g, opType, oper_name); | |||
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); | |||
c_api.TF_FinishOperation(desc, status.Handle); | |||
c_api.TF_FinishOperation(desc, status); | |||
} | |||
public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||
@@ -7,30 +7,37 @@ namespace Tensorflow | |||
{ | |||
public static partial class c_api | |||
{ | |||
/// <summary> | |||
/// Get the OpList of all OpDefs defined in this address space. | |||
/// </summary> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static unsafe extern IntPtr TF_GetAllOpList(); | |||
/// <summary> | |||
/// For inputs that take a single tensor. | |||
/// </summary> | |||
/// <param name="desc"></param> | |||
/// <param name="input"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static unsafe extern void TF_AddInput(TF_OperationDescription desc, TF_Output input); | |||
public static unsafe extern void TF_AddInput(IntPtr desc, TF_Output input); | |||
[DllImport(TensorFlowLibName)] | |||
public static unsafe extern IntPtr TF_FinishOperation(TF_OperationDescription desc, IntPtr status); | |||
public static unsafe extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); | |||
[DllImport(TensorFlowLibName)] | |||
public static unsafe extern TF_OperationDescription TF_NewOperation(IntPtr graph, string opType, string oper_name); | |||
public static unsafe extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe int TF_OperationNumOutputs(IntPtr oper); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); | |||
public static extern unsafe void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, IntPtr value, IntPtr status); | |||
public static extern unsafe void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | |||
public static extern unsafe void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value); | |||
} | |||
} |
@@ -24,7 +24,7 @@ namespace Tensorflow | |||
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs) | |||
{ | |||
var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); | |||
var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name); | |||
// Add inputs | |||
if(inputs != null) | |||
@@ -45,12 +45,12 @@ namespace Tensorflow | |||
var bytes = attr.Value.ToByteArray(); | |||
var proto = Marshal.AllocHGlobal(bytes.Length); | |||
Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle); | |||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status); | |||
if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message); | |||
} | |||
var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); | |||
var c_op = c_api.TF_FinishOperation(op_desc, status); | |||
if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message); | |||
@@ -30,7 +30,7 @@ namespace Tensorflow | |||
_target = UTF8Encoding.UTF8.GetBytes(target); | |||
var opts = c_api.TF_NewSessionOptions(); | |||
var status = new Status(); | |||
_session = c_api.TF_NewSession(_graph.Handle, opts, status.Handle); | |||
_session = c_api.TF_NewSession(_graph, opts, status); | |||
c_api.TF_DeleteSessionOptions(opts); | |||
} | |||
@@ -40,30 +40,30 @@ namespace Tensorflow | |||
} | |||
public virtual object run(Tensor fetches, FeedDict feed_dict = null) | |||
public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||
{ | |||
var result = _run(fetches, feed_dict); | |||
return result; | |||
} | |||
private unsafe object _run(Tensor fetches, FeedDict feed_dict = null) | |||
private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||
{ | |||
var feed_dict_tensor = new FeedDict(); | |||
var feed_dict_tensor = new Dictionary<Tensor, object>(); | |||
if (feed_dict != null) | |||
{ | |||
NDArray np_val = null; | |||
foreach (FeedValue feed in feed_dict) | |||
foreach (var feed in feed_dict) | |||
{ | |||
switch (feed.feed_val) | |||
switch (feed.Value) | |||
{ | |||
case float value: | |||
np_val = np.asarray(value); | |||
break; | |||
} | |||
feed_dict_tensor[feed.feed] = np_val; | |||
feed_dict_tensor[feed.Key] = np_val; | |||
} | |||
} | |||
@@ -85,9 +85,9 @@ namespace Tensorflow | |||
return fetch_handler.build_results(null, results); | |||
} | |||
private object[] _do_run(List<Tensor> fetch_list, FeedDict feed_dict) | |||
private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, object> feed_dict) | |||
{ | |||
var feeds = feed_dict.items().Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); | |||
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); | |||
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||
return _call_tf_sessionrun(feeds, fetches); | |||
@@ -113,7 +113,7 @@ namespace Tensorflow | |||
target_opers: new IntPtr[] { }, | |||
ntargets: 0, | |||
run_metadata: IntPtr.Zero, | |||
status: status.Handle); | |||
status: status); | |||
var result = output_values.Select(x => c_api.TF_TensorData(x)) | |||
.Select(x => (object)*(float*)x) | |||
@@ -1,59 +0,0 @@ | |||
using System; | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class FeedDict : IEnumerable | |||
{ | |||
private Dictionary<Tensor, object> feed_dict; | |||
public FeedDict() | |||
{ | |||
feed_dict = new Dictionary<Tensor, object>(); | |||
} | |||
public object this[Tensor feed] | |||
{ | |||
get | |||
{ | |||
return feed_dict[feed]; | |||
} | |||
set | |||
{ | |||
feed_dict[feed] = value; | |||
} | |||
} | |||
public FeedDict Add(Tensor feed, object value) | |||
{ | |||
feed_dict.Add(feed, value); | |||
return this; | |||
} | |||
public IEnumerator GetEnumerator() | |||
{ | |||
foreach (KeyValuePair<Tensor, object> feed in feed_dict) | |||
{ | |||
yield return new FeedValue | |||
{ | |||
feed = feed.Key, | |||
feed_val = feed.Value | |||
}; | |||
} | |||
} | |||
public Dictionary<Tensor, object> items() | |||
{ | |||
return feed_dict; | |||
} | |||
} | |||
public struct FeedValue | |||
{ | |||
public Tensor feed { get; set; } | |||
public object feed_val { get; set; } | |||
} | |||
} |
@@ -15,7 +15,7 @@ namespace Tensorflow | |||
private List<Tensor> _final_fetches = new List<Tensor>(); | |||
private List<object> _targets = new List<object>(); | |||
public _FetchHandler(Graph graph, Tensor fetches, FeedDict feeds = null, object feed_handles = null) | |||
public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, object> feeds = null, object feed_handles = null) | |||
{ | |||
_fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||
foreach(var fetch in _fetch_mapper.unique_fetches()) | |||
@@ -4,10 +4,13 @@ using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class Status : IDisposable | |||
/// <summary> | |||
/// TF_Status holds error information. It either has an OK code, or | |||
/// else an error code with an associated error message. | |||
/// </summary> | |||
public class Status | |||
{ | |||
private readonly IntPtr _handle; | |||
public IntPtr Handle => _handle; | |||
/// <summary> | |||
/// Error message | |||
@@ -29,6 +32,23 @@ namespace Tensorflow | |||
c_api.TF_SetStatus(_handle, code, msg); | |||
} | |||
/// <summary> | |||
/// Check status | |||
/// Throw exception with error message if code != TF_OK | |||
/// </summary> | |||
public void Check() | |||
{ | |||
if(Code != TF_Code.TF_OK) | |||
{ | |||
throw new Exception(Message); | |||
} | |||
} | |||
public static implicit operator IntPtr(Status status) | |||
{ | |||
return status._handle; | |||
} | |||
public void Dispose() | |||
{ | |||
c_api.TF_DeleteStatus(_handle); | |||
@@ -13,6 +13,8 @@ namespace Tensorflow | |||
/// </summary> | |||
public class Tensor | |||
{ | |||
public IntPtr Handle { get; } | |||
public Graph graph => op.graph; | |||
public Operation op { get; } | |||
@@ -21,7 +23,6 @@ namespace Tensorflow | |||
public int value_index { get; } | |||
public TF_DataType dtype { get; } | |||
public IntPtr handle { get; } | |||
public ulong bytesize { get; } | |||
public ulong dataTypeSize { get;} | |||
public ulong size => bytesize / dataTypeSize; | |||
@@ -45,7 +46,7 @@ namespace Tensorflow | |||
public Tensor(IntPtr handle) | |||
{ | |||
this.handle = handle; | |||
Handle = handle; | |||
dtype = c_api.TF_TensorType(handle); | |||
rank = c_api.TF_NumDims(handle); | |||
bytesize = c_api.TF_TensorByteSize(handle); | |||
@@ -59,33 +60,52 @@ namespace Tensorflow | |||
public Tensor(NDArray nd) | |||
{ | |||
var data = Marshal.AllocHGlobal(sizeof(float) * nd.size); | |||
Marshal.Copy(nd.Data<float>(), 0, data, nd.size); | |||
var dataType = ToTFDataType(nd.dtype); | |||
Handle = Allocate(nd); | |||
dtype = c_api.TF_TensorType(Handle); | |||
rank = c_api.TF_NumDims(Handle); | |||
bytesize = c_api.TF_TensorByteSize(Handle); | |||
buffer = c_api.TF_TensorData(Handle); | |||
dataTypeSize = c_api.TF_DataTypeSize(dtype); | |||
shape = new long[rank]; | |||
for (int i = 0; i < rank; i++) | |||
shape[i] = c_api.TF_Dim(Handle, i); | |||
} | |||
private IntPtr Allocate(NDArray nd) | |||
{ | |||
var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); | |||
switch (nd.dtype.Name) | |||
{ | |||
case "Int32": | |||
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Single": | |||
Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Double": | |||
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); | |||
break; | |||
default: | |||
throw new NotImplementedException("Marshal.Copy failed."); | |||
} | |||
var handle = c_api.TF_NewTensor(dataType, | |||
var dataType = ToTFDataType(nd.dtype); | |||
var tfHandle = c_api.TF_NewTensor(dataType, | |||
nd.shape.Select(x => (long)x).ToArray(), // shape | |||
nd.ndim, | |||
data, | |||
(UIntPtr)(nd.size * sizeof(float)), | |||
dotHandle, | |||
(UIntPtr)(nd.size * nd.dtypesize), | |||
(IntPtr values, IntPtr len, ref bool closure) => | |||
{ | |||
// Free the original buffer and set flag | |||
Marshal.FreeHGlobal(data); | |||
Marshal.FreeHGlobal(dotHandle); | |||
closure = true; | |||
}, | |||
ref deallocator_called); | |||
this.handle = handle; | |||
dtype = c_api.TF_TensorType(handle); | |||
rank = c_api.TF_NumDims(handle); | |||
bytesize = c_api.TF_TensorByteSize(handle); | |||
buffer = c_api.TF_TensorData(handle); | |||
dataTypeSize = c_api.TF_DataTypeSize(dtype); | |||
shape = new long[rank]; | |||
for (int i = 0; i < rank; i++) | |||
shape[i] = c_api.TF_Dim(handle, i); | |||
return tfHandle; | |||
} | |||
public Tensor(Operation op, int value_index, TF_DataType dtype) | |||
@@ -129,11 +149,20 @@ namespace Tensorflow | |||
{ | |||
switch (type.Name) | |||
{ | |||
case "Int32": | |||
return TF_DataType.TF_INT32; | |||
case "Single": | |||
return TF_DataType.TF_FLOAT; | |||
case "Double": | |||
return TF_DataType.TF_DOUBLE; | |||
} | |||
return TF_DataType.DtInvalid; | |||
} | |||
public static implicit operator IntPtr(Tensor tensor) | |||
{ | |||
return tensor.Handle; | |||
} | |||
} | |||
} |
@@ -10,12 +10,22 @@ namespace Tensorflow | |||
/// | |||
/// The API leans towards simplicity and uniformity instead of convenience | |||
/// since most usage will be by language specific wrappers. | |||
/// | |||
/// The params type mapping between .net and c_api | |||
/// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op) | |||
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | |||
/// struct => struct (TF_Output output) => (TF_Output output) | |||
/// const char* => string | |||
/// int32_t => int | |||
/// int64_t* => long[] | |||
/// size_t* => unlong[] | |||
/// void* => IntPtr | |||
/// </summary> | |||
public static partial class c_api | |||
{ | |||
public const string TensorFlowLibName = "tensorflow"; | |||
public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocatorData); | |||
public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocator); | |||
[DllImport(TensorFlowLibName)] | |||
public static unsafe extern IntPtr TF_Version(); | |||
@@ -3,12 +3,21 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow; | |||
using Buffer = Tensorflow.Buffer; | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
[TestClass] | |||
public class OperationsTest | |||
{ | |||
[TestMethod] | |||
public void GetAllOpList() | |||
{ | |||
var handle = c_api.TF_GetAllOpList(); | |||
var buffer = new Buffer(handle); | |||
Assert.IsTrue(buffer.Length == buffer.Data.Length); | |||
} | |||
[TestMethod] | |||
public void addInPlaceholder() | |||
{ | |||
@@ -18,9 +27,9 @@ namespace TensorFlowNET.UnitTest | |||
using(var sess = tf.Session()) | |||
{ | |||
var feed_dict = new FeedDict() | |||
.Add(a, 3.0f) | |||
.Add(b, 2.0f); | |||
var feed_dict = new Dictionary<Tensor, object>(); | |||
feed_dict.Add(a, 3.0f); | |||
feed_dict.Add(b, 2.0f); | |||
var o = sess.run(c, feed_dict); | |||
} | |||
@@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest | |||
public class TensorTest | |||
{ | |||
[TestMethod] | |||
public unsafe void NewTensor() | |||
public void NewTensor() | |||
{ | |||
var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | |||
@@ -27,5 +27,69 @@ namespace TensorFlowNET.UnitTest | |||
Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float)); | |||
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); | |||
} | |||
/// <summary> | |||
/// Port from tensorflow\c\c_api_test.cc | |||
/// </summary> | |||
[TestMethod] | |||
public void SetShape() | |||
{ | |||
var s = new Status(); | |||
var graph = tf.get_default_graph(); | |||
var desc = c_api.TF_NewOperation(graph, "Placeholder", ""); | |||
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_FLOAT); | |||
//if (!dims.empty()) | |||
{ | |||
//TF_SetAttrShape(desc, "shape", dims.data(), dims.size()); | |||
} | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.IsNotNull(op); | |||
// Fetch the shape, it should be completely unknown. | |||
var feed_out_0 = new TF_Output { oper = op, index = 0 }; | |||
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.AreEqual(-1, num_dims); | |||
// Set the shape to be unknown, expect no change. | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, new int[0], -1, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
Assert.AreEqual(-1, num_dims); | |||
// Set the shape to be 2 x Unknown | |||
var dims = new int[] { 2, -1 }; | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
Assert.AreEqual(2, num_dims); | |||
// Get the dimension vector appropriately. | |||
var returned_dims = new int[dims.Length]; | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||
// Set to a new valid shape: [2, 3] | |||
dims[1] = 3; | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||
//Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
// Fetch and see that the new value is returned. | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
//Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
//Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||
// Test for a scalar. | |||
var three = c_test_util.ScalarConst(3, graph, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
var three_out_0 = new TF_Output { oper = three.Handle }; | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | |||
Assert.AreEqual(0, num_dims); | |||
} | |||
} | |||
} |
@@ -0,0 +1,37 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
using Tensorflow; | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
public static class c_test_util | |||
{ | |||
public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref IntPtr op) | |||
{ | |||
var desc = c_api.TF_NewOperation(graph, "Const", name); | |||
c_api.TF_SetAttrTensor(desc, "value", t.Handle, s); | |||
s.Check(); | |||
c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||
op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
if(op == null) | |||
{ | |||
throw new Exception("c_api.TF_FinishOperation failed."); | |||
} | |||
} | |||
public static Operation Const(Tensor t, Graph graph, Status s, string name) | |||
{ | |||
IntPtr op = IntPtr.Zero; | |||
ConstHelper(t, graph, s, name, ref op); | |||
return new Operation(op); | |||
} | |||
public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const") | |||
{ | |||
return Const(new Tensor(v), graph, s, name); | |||
} | |||
} | |||
} |