@@ -24,23 +24,36 @@ namespace Tensorflow | |||||
private List<String> _unfetchable_ops = new List<string>(); | private List<String> _unfetchable_ops = new List<string>(); | ||||
private string _name_stack; | private string _name_stack; | ||||
public Status Status { get; } | |||||
public Graph() | public Graph() | ||||
{ | { | ||||
_handle = c_api.TF_NewGraph(); | _handle = c_api.TF_NewGraph(); | ||||
Status = new Status(); | |||||
} | } | ||||
public Graph(IntPtr graph) | public Graph(IntPtr graph) | ||||
{ | { | ||||
_handle = graph; | _handle = graph; | ||||
Status = new Status(); | |||||
_nodes_by_id = new Dictionary<int, Operation>(); | _nodes_by_id = new Dictionary<int, Operation>(); | ||||
_nodes_by_name = new Dictionary<string, Operation>(); | _nodes_by_name = new Dictionary<string, Operation>(); | ||||
_names_in_use = new Dictionary<string, int>(); | _names_in_use = new Dictionary<string, int>(); | ||||
} | } | ||||
public OperationDescription NewOperation(string opType, string opName) | |||||
public Operation NewOperation(string opType, string opName, Tensor t) | |||||
{ | { | ||||
return c_api.TF_NewOperation(_handle, opType, opName); | |||||
var desc = c_api.TF_NewOperation(_handle, opType, opName); | |||||
c_api.TF_SetAttrTensor(desc, "value", t, Status); | |||||
Status.Check(); | |||||
c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||||
var op = c_api.TF_FinishOperation(desc, Status); | |||||
Status.Check(); | |||||
return op; | |||||
} | } | ||||
public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true) | public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true) | ||||
@@ -7,6 +7,18 @@ namespace Tensorflow | |||||
{ | { | ||||
public static partial class c_api | public static partial class c_api | ||||
{ | { | ||||
/// <summary> | |||||
/// Request that `desc` be co-located on the device where `op` | |||||
/// is placed. | |||||
/// | |||||
/// Use of this is discouraged since the implementation of device placement is | |||||
/// subject to change. Primarily intended for internal libraries | |||||
/// </summary> | |||||
/// <param name="desc"></param> | |||||
/// <param name="op"></param> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TF_ColocateWith(IntPtr desc, IntPtr op); | |||||
/// <summary> | /// <summary> | ||||
/// Get the OpList of all OpDefs defined in this address space. | /// Get the OpList of all OpDefs defined in this address space. | ||||
/// </summary> | /// </summary> | ||||
@@ -209,7 +221,7 @@ namespace Tensorflow | |||||
/// <param name="value">const void*</param> | /// <param name="value">const void*</param> | ||||
/// <param name="length">size_t</param> | /// <param name="length">size_t</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_SetAttrString(IntPtr desc, string attr_name, string value, uint length); | |||||
public static extern void TF_SetAttrString(IntPtr desc, string attr_name, IntPtr value, uint length); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -119,6 +119,9 @@ namespace Tensorflow | |||||
.Select(x => (object)*(float*)x) | .Select(x => (object)*(float*)x) | ||||
.ToArray(); | .ToArray(); | ||||
var op = new Operation(fetch_list[0].oper); | |||||
//var metadata = c_api.TF_OperationGetAttrMetadata(fetch_list[0].oper, "dtype", status); | |||||
return result; | return result; | ||||
} | } | ||||
@@ -7,9 +7,19 @@ namespace Tensorflow | |||||
public class Session : BaseSession | public class Session : BaseSession | ||||
{ | { | ||||
private IntPtr _handle; | private IntPtr _handle; | ||||
public Status Status { get; } | |||||
public SessionOptions Options { get; } | |||||
public Session(string target = "", Graph graph = null) | public Session(string target = "", Graph graph = null) | ||||
{ | { | ||||
Status = new Status(); | |||||
if(graph == null) | |||||
{ | |||||
graph = tf.get_default_graph(); | |||||
} | |||||
Options = new SessionOptions(); | |||||
_handle = c_api.TF_NewSession(graph, Options, Status); | |||||
Status.Check(); | |||||
} | } | ||||
public Session(IntPtr handle) | public Session(IntPtr handle) | ||||
@@ -36,12 +36,15 @@ namespace Tensorflow | |||||
/// Check status | /// Check status | ||||
/// Throw exception with error message if code != TF_OK | /// Throw exception with error message if code != TF_OK | ||||
/// </summary> | /// </summary> | ||||
public void Check() | |||||
public void Check(bool throwException = false) | |||||
{ | { | ||||
if(Code != TF_Code.TF_OK) | if(Code != TF_Code.TF_OK) | ||||
{ | { | ||||
Console.WriteLine(Message); | Console.WriteLine(Message); | ||||
// throw new Exception(Message); | |||||
if (throwException) | |||||
{ | |||||
throw new Exception(Message); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -69,6 +69,7 @@ namespace Tensorflow | |||||
private IntPtr Allocate(NDArray nd) | private IntPtr Allocate(NDArray nd) | ||||
{ | { | ||||
var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); | var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); | ||||
ulong size = (ulong)(nd.size * nd.dtypesize); | |||||
switch (nd.dtype.Name) | switch (nd.dtype.Name) | ||||
{ | { | ||||
@@ -81,16 +82,21 @@ namespace Tensorflow | |||||
case "Double": | case "Double": | ||||
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); | Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); | ||||
break; | break; | ||||
case "String": | |||||
dotHandle = Marshal.StringToHGlobalAuto(nd.Data<string>()[0]); | |||||
size = (ulong)nd.Data<string>()[0].Length; | |||||
break; | |||||
default: | default: | ||||
throw new NotImplementedException("Marshal.Copy failed."); | throw new NotImplementedException("Marshal.Copy failed."); | ||||
} | } | ||||
var dataType = ToTFDataType(nd.dtype); | var dataType = ToTFDataType(nd.dtype); | ||||
var tfHandle = c_api.TF_NewTensor(dataType, | var tfHandle = c_api.TF_NewTensor(dataType, | ||||
nd.shape.Select(x => (long)x).ToArray(), // shape | nd.shape.Select(x => (long)x).ToArray(), // shape | ||||
nd.ndim, | nd.ndim, | ||||
dotHandle, | dotHandle, | ||||
(ulong)(nd.size * nd.dtypesize), | |||||
size, | |||||
(IntPtr values, IntPtr len, ref bool closure) => | (IntPtr values, IntPtr len, ref bool closure) => | ||||
{ | { | ||||
// Free the original buffer and set flag | // Free the original buffer and set flag | ||||
@@ -154,6 +160,8 @@ namespace Tensorflow | |||||
return TF_DataType.TF_FLOAT; | return TF_DataType.TF_FLOAT; | ||||
case "Double": | case "Double": | ||||
return TF_DataType.TF_DOUBLE; | return TF_DataType.TF_DOUBLE; | ||||
case "String": | |||||
return TF_DataType.TF_STRING; | |||||
} | } | ||||
return TF_DataType.DtInvalid; | return TF_DataType.DtInvalid; | ||||
@@ -34,12 +34,13 @@ namespace Tensorflow | |||||
attrs["dtype"] = dtype_value; | attrs["dtype"] = dtype_value; | ||||
attrs["value"] = tensor_value; | attrs["value"] = tensor_value; | ||||
var const_tensor = g.create_op("Const", | |||||
null, | |||||
new TF_DataType[] { (TF_DataType)dtype_value.Type }, | |||||
var op = g.create_op("Const", | |||||
null, | |||||
new TF_DataType[] { (TF_DataType)dtype_value.Type }, | |||||
attrs: attrs, | attrs: attrs, | ||||
name: name).outputs[0]; | |||||
name: name); | |||||
var const_tensor = op.outputs[0]; | |||||
const_tensor.value = nd.Data(); | const_tensor.value = nd.Data(); | ||||
return const_tensor; | return const_tensor; | ||||
@@ -7,9 +7,9 @@ namespace Tensorflow | |||||
{ | { | ||||
public static partial class tf | public static partial class tf | ||||
{ | { | ||||
public static Tensor constant(NDArray value, string name = "Const", bool verify_shape = false) | |||||
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) | |||||
{ | { | ||||
return constant_op.Create(value, name, verify_shape); | |||||
return constant_op.Create(nd, name, verify_shape); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -24,7 +24,7 @@ namespace TensorFlowNET.Examples | |||||
var sess = tf.Session(); | var sess = tf.Session(); | ||||
// Run the op | // Run the op | ||||
sess.run(hello); | |||||
Console.WriteLine(sess.run(hello)); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -23,6 +23,8 @@ namespace TensorFlowNET.Examples | |||||
Console.ReadLine(); | Console.ReadLine(); | ||||
} | } | ||||
} | } | ||||
Console.ReadLine(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -25,12 +25,17 @@ namespace TensorFlowNET.UnitTest | |||||
public void SetUp() | public void SetUp() | ||||
{ | { | ||||
feed1_ = c_test_util.Placeholder(graph_, s_, "feed1"); | feed1_ = c_test_util.Placeholder(graph_, s_, "feed1"); | ||||
s_.Check(); | |||||
feed2_ = c_test_util.Placeholder(graph_, s_, "feed2"); | feed2_ = c_test_util.Placeholder(graph_, s_, "feed2"); | ||||
s_.Check(); | |||||
constant_ = c_test_util.ScalarConst(10, graph_, s_); | constant_ = c_test_util.ScalarConst(10, graph_, s_); | ||||
desc_ = graph_.NewOperation("AddN", "add"); | |||||
s_.Check(); | |||||
desc_ = c_api.TF_NewOperation(graph_, "AddN", "add"); | |||||
s_.Check(); | |||||
TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) }; | TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) }; | ||||
desc_.AddInputList(inputs); | desc_.AddInputList(inputs); | ||||
s_.Check(); | |||||
} | } | ||||
private void SetViaStringList(OperationDescription desc, string[] list) | private void SetViaStringList(OperationDescription desc, string[] list) | ||||
@@ -85,7 +90,8 @@ namespace TensorFlowNET.UnitTest | |||||
[TestMethod] | [TestMethod] | ||||
public void ColocateWith() | public void ColocateWith() | ||||
{ | { | ||||
c_api.TF_ColocateWith(desc_, feed1_); | |||||
FinishAndVerify(desc_, new string[] { "loc:@feed1" }); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||