Browse Source

test string const

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
d7e04d5cb4
11 changed files with 73 additions and 15 deletions
  1. +15
    -2
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +13
    -1
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  3. +3
    -0
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  4. +10
    -0
      src/TensorFlowNET.Core/Sessions/Session.cs
  5. +5
    -2
      src/TensorFlowNET.Core/Status/Status.cs
  6. +9
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  7. +5
    -4
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  8. +2
    -2
      src/TensorFlowNET.Core/Tensors/tf.constant.cs
  9. +1
    -1
      test/TensorFlowNET.Examples/HelloWorld.cs
  10. +2
    -0
      test/TensorFlowNET.Examples/Program.cs
  11. +8
    -2
      test/TensorFlowNET.UnitTest/CApiColocationTest.cs

+ 15
- 2
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -24,23 +24,36 @@ namespace Tensorflow
private List<String> _unfetchable_ops = new List<string>(); private List<String> _unfetchable_ops = new List<string>();


private string _name_stack; private string _name_stack;
public Status Status { get; }


public Graph() public Graph()
{ {
_handle = c_api.TF_NewGraph(); _handle = c_api.TF_NewGraph();
Status = new Status();
} }


public Graph(IntPtr graph) public Graph(IntPtr graph)
{ {
_handle = graph; _handle = graph;
Status = new Status();
_nodes_by_id = new Dictionary<int, Operation>(); _nodes_by_id = new Dictionary<int, Operation>();
_nodes_by_name = new Dictionary<string, Operation>(); _nodes_by_name = new Dictionary<string, Operation>();
_names_in_use = new Dictionary<string, int>(); _names_in_use = new Dictionary<string, int>();
} }


public OperationDescription NewOperation(string opType, string opName)
public Operation NewOperation(string opType, string opName, Tensor t)
{ {
return c_api.TF_NewOperation(_handle, opType, opName);
var desc = c_api.TF_NewOperation(_handle, opType, opName);
c_api.TF_SetAttrTensor(desc, "value", t, Status);
Status.Check();

c_api.TF_SetAttrType(desc, "dtype", t.dtype);

var op = c_api.TF_FinishOperation(desc, Status);
Status.Check();

return op;
} }


public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true) public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true)


+ 13
- 1
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -7,6 +7,18 @@ namespace Tensorflow
{ {
public static partial class c_api public static partial class c_api
{ {
/// <summary>
/// Request that `desc` be co-located on the device where `op`
/// is placed.
///
/// Use of this is discouraged since the implementation of device placement is
/// subject to change. Primarily intended for internal libraries
/// </summary>
/// <param name="desc"></param>
/// <param name="op"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ColocateWith(IntPtr desc, IntPtr op);

/// <summary> /// <summary>
/// Get the OpList of all OpDefs defined in this address space. /// Get the OpList of all OpDefs defined in this address space.
/// </summary> /// </summary>
@@ -209,7 +221,7 @@ namespace Tensorflow
/// <param name="value">const void*</param> /// <param name="value">const void*</param>
/// <param name="length">size_t</param> /// <param name="length">size_t</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrString(IntPtr desc, string attr_name, string value, uint length);
public static extern void TF_SetAttrString(IntPtr desc, string attr_name, IntPtr value, uint length);


/// <summary> /// <summary>
/// ///


+ 3
- 0
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -119,6 +119,9 @@ namespace Tensorflow
.Select(x => (object)*(float*)x) .Select(x => (object)*(float*)x)
.ToArray(); .ToArray();


var op = new Operation(fetch_list[0].oper);
//var metadata = c_api.TF_OperationGetAttrMetadata(fetch_list[0].oper, "dtype", status);

return result; return result;
} }




+ 10
- 0
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -7,9 +7,19 @@ namespace Tensorflow
public class Session : BaseSession public class Session : BaseSession
{ {
private IntPtr _handle; private IntPtr _handle;
public Status Status { get; }
public SessionOptions Options { get; }


public Session(string target = "", Graph graph = null) public Session(string target = "", Graph graph = null)
{ {
Status = new Status();
if(graph == null)
{
graph = tf.get_default_graph();
}
Options = new SessionOptions();
_handle = c_api.TF_NewSession(graph, Options, Status);
Status.Check();
} }


public Session(IntPtr handle) public Session(IntPtr handle)


+ 5
- 2
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -36,12 +36,15 @@ namespace Tensorflow
/// Check status /// Check status
/// Throw exception with error message if code != TF_OK /// Throw exception with error message if code != TF_OK
/// </summary> /// </summary>
public void Check()
public void Check(bool throwException = false)
{ {
if(Code != TF_Code.TF_OK) if(Code != TF_Code.TF_OK)
{ {
Console.WriteLine(Message); Console.WriteLine(Message);
// throw new Exception(Message);
if (throwException)
{
throw new Exception(Message);
}
} }
} }




+ 9
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -69,6 +69,7 @@ namespace Tensorflow
private IntPtr Allocate(NDArray nd) private IntPtr Allocate(NDArray nd)
{ {
var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size);
ulong size = (ulong)(nd.size * nd.dtypesize);


switch (nd.dtype.Name) switch (nd.dtype.Name)
{ {
@@ -81,16 +82,21 @@ namespace Tensorflow
case "Double": case "Double":
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
break; break;
case "String":
dotHandle = Marshal.StringToHGlobalAuto(nd.Data<string>()[0]);
size = (ulong)nd.Data<string>()[0].Length;
break;
default: default:
throw new NotImplementedException("Marshal.Copy failed."); throw new NotImplementedException("Marshal.Copy failed.");
} }


var dataType = ToTFDataType(nd.dtype); var dataType = ToTFDataType(nd.dtype);
var tfHandle = c_api.TF_NewTensor(dataType, var tfHandle = c_api.TF_NewTensor(dataType,
nd.shape.Select(x => (long)x).ToArray(), // shape nd.shape.Select(x => (long)x).ToArray(), // shape
nd.ndim, nd.ndim,
dotHandle, dotHandle,
(ulong)(nd.size * nd.dtypesize),
size,
(IntPtr values, IntPtr len, ref bool closure) => (IntPtr values, IntPtr len, ref bool closure) =>
{ {
// Free the original buffer and set flag // Free the original buffer and set flag
@@ -154,6 +160,8 @@ namespace Tensorflow
return TF_DataType.TF_FLOAT; return TF_DataType.TF_FLOAT;
case "Double": case "Double":
return TF_DataType.TF_DOUBLE; return TF_DataType.TF_DOUBLE;
case "String":
return TF_DataType.TF_STRING;
} }


return TF_DataType.DtInvalid; return TF_DataType.DtInvalid;


+ 5
- 4
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -34,12 +34,13 @@ namespace Tensorflow
attrs["dtype"] = dtype_value; attrs["dtype"] = dtype_value;
attrs["value"] = tensor_value; attrs["value"] = tensor_value;


var const_tensor = g.create_op("Const",
null,
new TF_DataType[] { (TF_DataType)dtype_value.Type },
var op = g.create_op("Const",
null,
new TF_DataType[] { (TF_DataType)dtype_value.Type },
attrs: attrs, attrs: attrs,
name: name).outputs[0];
name: name);


var const_tensor = op.outputs[0];
const_tensor.value = nd.Data(); const_tensor.value = nd.Data();


return const_tensor; return const_tensor;


+ 2
- 2
src/TensorFlowNET.Core/Tensors/tf.constant.cs View File

@@ -7,9 +7,9 @@ namespace Tensorflow
{ {
public static partial class tf public static partial class tf
{ {
public static Tensor constant(NDArray value, string name = "Const", bool verify_shape = false)
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false)
{ {
return constant_op.Create(value, name, verify_shape);
return constant_op.Create(nd, name, verify_shape);
} }
} }
} }

+ 1
- 1
test/TensorFlowNET.Examples/HelloWorld.cs View File

@@ -24,7 +24,7 @@ namespace TensorFlowNET.Examples
var sess = tf.Session(); var sess = tf.Session();


// Run the op // Run the op
sess.run(hello);
Console.WriteLine(sess.run(hello));
} }
} }
} }

+ 2
- 0
test/TensorFlowNET.Examples/Program.cs View File

@@ -23,6 +23,8 @@ namespace TensorFlowNET.Examples
Console.ReadLine(); Console.ReadLine();
} }
} }

Console.ReadLine();
} }
} }
} }

+ 8
- 2
test/TensorFlowNET.UnitTest/CApiColocationTest.cs View File

@@ -25,12 +25,17 @@ namespace TensorFlowNET.UnitTest
public void SetUp() public void SetUp()
{ {
feed1_ = c_test_util.Placeholder(graph_, s_, "feed1"); feed1_ = c_test_util.Placeholder(graph_, s_, "feed1");
s_.Check();
feed2_ = c_test_util.Placeholder(graph_, s_, "feed2"); feed2_ = c_test_util.Placeholder(graph_, s_, "feed2");
s_.Check();
constant_ = c_test_util.ScalarConst(10, graph_, s_); constant_ = c_test_util.ScalarConst(10, graph_, s_);
desc_ = graph_.NewOperation("AddN", "add");
s_.Check();
desc_ = c_api.TF_NewOperation(graph_, "AddN", "add");
s_.Check();


TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) }; TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) };
desc_.AddInputList(inputs); desc_.AddInputList(inputs);
s_.Check();
} }


private void SetViaStringList(OperationDescription desc, string[] list) private void SetViaStringList(OperationDescription desc, string[] list)
@@ -85,7 +90,8 @@ namespace TensorFlowNET.UnitTest
[TestMethod] [TestMethod]
public void ColocateWith() public void ColocateWith()
{ {

c_api.TF_ColocateWith(desc_, feed1_);
FinishAndVerify(desc_, new string[] { "loc:@feed1" });
} }


[TestMethod] [TestMethod]


Loading…
Cancel
Save