Browse Source

test string const

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
d7e04d5cb4
11 changed files with 73 additions and 15 deletions
  1. +15
    -2
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +13
    -1
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  3. +3
    -0
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  4. +10
    -0
      src/TensorFlowNET.Core/Sessions/Session.cs
  5. +5
    -2
      src/TensorFlowNET.Core/Status/Status.cs
  6. +9
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  7. +5
    -4
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  8. +2
    -2
      src/TensorFlowNET.Core/Tensors/tf.constant.cs
  9. +1
    -1
      test/TensorFlowNET.Examples/HelloWorld.cs
  10. +2
    -0
      test/TensorFlowNET.Examples/Program.cs
  11. +8
    -2
      test/TensorFlowNET.UnitTest/CApiColocationTest.cs

+ 15
- 2
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -24,23 +24,36 @@ namespace Tensorflow
private List<String> _unfetchable_ops = new List<string>();

private string _name_stack;
public Status Status { get; }

public Graph()
{
_handle = c_api.TF_NewGraph();
Status = new Status();
}

public Graph(IntPtr graph)
{
_handle = graph;
Status = new Status();
_nodes_by_id = new Dictionary<int, Operation>();
_nodes_by_name = new Dictionary<string, Operation>();
_names_in_use = new Dictionary<string, int>();
}

public OperationDescription NewOperation(string opType, string opName)
public Operation NewOperation(string opType, string opName, Tensor t)
{
return c_api.TF_NewOperation(_handle, opType, opName);
var desc = c_api.TF_NewOperation(_handle, opType, opName);
c_api.TF_SetAttrTensor(desc, "value", t, Status);
Status.Check();

c_api.TF_SetAttrType(desc, "dtype", t.dtype);

var op = c_api.TF_FinishOperation(desc, Status);
Status.Check();

return op;
}

public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true)


+ 13
- 1
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -7,6 +7,18 @@ namespace Tensorflow
{
public static partial class c_api
{
/// <summary>
/// Request that `desc` be co-located on the device where `op`
/// is placed.
///
/// Use of this is discouraged since the implementation of device placement is
/// subject to change. Primarily intended for internal libraries
/// </summary>
/// <param name="desc"></param>
/// <param name="op"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_ColocateWith(IntPtr desc, IntPtr op);

/// <summary>
/// Get the OpList of all OpDefs defined in this address space.
/// </summary>
@@ -209,7 +221,7 @@ namespace Tensorflow
/// <param name="value">const void*</param>
/// <param name="length">size_t</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrString(IntPtr desc, string attr_name, string value, uint length);
public static extern void TF_SetAttrString(IntPtr desc, string attr_name, IntPtr value, uint length);

/// <summary>
///


+ 3
- 0
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -119,6 +119,9 @@ namespace Tensorflow
.Select(x => (object)*(float*)x)
.ToArray();

var op = new Operation(fetch_list[0].oper);
//var metadata = c_api.TF_OperationGetAttrMetadata(fetch_list[0].oper, "dtype", status);

return result;
}



+ 10
- 0
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -7,9 +7,19 @@ namespace Tensorflow
public class Session : BaseSession
{
private IntPtr _handle;
public Status Status { get; }
public SessionOptions Options { get; }

public Session(string target = "", Graph graph = null)
{
Status = new Status();
if(graph == null)
{
graph = tf.get_default_graph();
}
Options = new SessionOptions();
_handle = c_api.TF_NewSession(graph, Options, Status);
Status.Check();
}

public Session(IntPtr handle)


+ 5
- 2
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -36,12 +36,15 @@ namespace Tensorflow
/// Check status
/// Throw exception with error message if code != TF_OK
/// </summary>
public void Check()
public void Check(bool throwException = false)
{
if(Code != TF_Code.TF_OK)
{
Console.WriteLine(Message);
// throw new Exception(Message);
if (throwException)
{
throw new Exception(Message);
}
}
}



+ 9
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -69,6 +69,7 @@ namespace Tensorflow
private IntPtr Allocate(NDArray nd)
{
var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size);
ulong size = (ulong)(nd.size * nd.dtypesize);

switch (nd.dtype.Name)
{
@@ -81,16 +82,21 @@ namespace Tensorflow
case "Double":
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
break;
case "String":
dotHandle = Marshal.StringToHGlobalAuto(nd.Data<string>()[0]);
size = (ulong)nd.Data<string>()[0].Length;
break;
default:
throw new NotImplementedException("Marshal.Copy failed.");
}

var dataType = ToTFDataType(nd.dtype);
var tfHandle = c_api.TF_NewTensor(dataType,
nd.shape.Select(x => (long)x).ToArray(), // shape
nd.ndim,
dotHandle,
(ulong)(nd.size * nd.dtypesize),
size,
(IntPtr values, IntPtr len, ref bool closure) =>
{
// Free the original buffer and set flag
@@ -154,6 +160,8 @@ namespace Tensorflow
return TF_DataType.TF_FLOAT;
case "Double":
return TF_DataType.TF_DOUBLE;
case "String":
return TF_DataType.TF_STRING;
}

return TF_DataType.DtInvalid;


+ 5
- 4
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -34,12 +34,13 @@ namespace Tensorflow
attrs["dtype"] = dtype_value;
attrs["value"] = tensor_value;

var const_tensor = g.create_op("Const",
null,
new TF_DataType[] { (TF_DataType)dtype_value.Type },
var op = g.create_op("Const",
null,
new TF_DataType[] { (TF_DataType)dtype_value.Type },
attrs: attrs,
name: name).outputs[0];
name: name);

var const_tensor = op.outputs[0];
const_tensor.value = nd.Data();

return const_tensor;


+ 2
- 2
src/TensorFlowNET.Core/Tensors/tf.constant.cs View File

@@ -7,9 +7,9 @@ namespace Tensorflow
{
public static partial class tf
{
public static Tensor constant(NDArray value, string name = "Const", bool verify_shape = false)
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false)
{
return constant_op.Create(value, name, verify_shape);
return constant_op.Create(nd, name, verify_shape);
}
}
}

+ 1
- 1
test/TensorFlowNET.Examples/HelloWorld.cs View File

@@ -24,7 +24,7 @@ namespace TensorFlowNET.Examples
var sess = tf.Session();

// Run the op
sess.run(hello);
Console.WriteLine(sess.run(hello));
}
}
}

+ 2
- 0
test/TensorFlowNET.Examples/Program.cs View File

@@ -23,6 +23,8 @@ namespace TensorFlowNET.Examples
Console.ReadLine();
}
}

Console.ReadLine();
}
}
}

+ 8
- 2
test/TensorFlowNET.UnitTest/CApiColocationTest.cs View File

@@ -25,12 +25,17 @@ namespace TensorFlowNET.UnitTest
public void SetUp()
{
feed1_ = c_test_util.Placeholder(graph_, s_, "feed1");
s_.Check();
feed2_ = c_test_util.Placeholder(graph_, s_, "feed2");
s_.Check();
constant_ = c_test_util.ScalarConst(10, graph_, s_);
desc_ = graph_.NewOperation("AddN", "add");
s_.Check();
desc_ = c_api.TF_NewOperation(graph_, "AddN", "add");
s_.Check();

TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) };
desc_.AddInputList(inputs);
s_.Check();
}

private void SetViaStringList(OperationDescription desc, string[] list)
@@ -85,7 +90,8 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void ColocateWith()
{

c_api.TF_ColocateWith(desc_, feed1_);
FinishAndVerify(desc_, new string[] { "loc:@feed1" });
}

[TestMethod]


Loading…
Cancel
Save