- Ignored all unit tests related to CSession as it does not use TF.NET's api directly and unable to be tested with other tests parallely.tags/v0.12
@@ -22,58 +22,54 @@ using System.Linq; | |||
using Tensorflow.Util; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Represents a graph node that performs computation on tensors. | |||
/// | |||
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or | |||
/// more `Tensor` objects as input, and produces zero or more `Tensor` | |||
/// objects as output. Objects of type `Operation` are created by | |||
/// calling an op constructor(such as `tf.matmul`) | |||
/// or `tf.Graph.create_op`. | |||
/// | |||
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type | |||
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c` | |||
/// as output. | |||
/// | |||
/// After the graph has been launched in a session, an `Operation` can | |||
/// be executed by passing it to | |||
/// `tf.Session.run`. | |||
/// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. | |||
{ | |||
/// <summary> | |||
/// Represents a graph node that performs computation on tensors. | |||
/// | |||
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or | |||
/// more `Tensor` objects as input, and produces zero or more `Tensor` | |||
/// objects as output. Objects of type `Operation` are created by | |||
/// calling an op constructor(such as `tf.matmul`) | |||
/// or `tf.Graph.create_op`. | |||
/// | |||
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type | |||
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c` | |||
/// as output. | |||
/// | |||
/// After the graph has been launched in a session, an `Operation` can | |||
/// be executed by passing it to | |||
/// `tf.Session.run`. | |||
/// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. | |||
/// </summary> | |||
public partial class Operation : ITensorOrOperation | |||
{ | |||
private readonly IntPtr _handle; // _c_op in python | |||
private readonly IntPtr _operDesc; | |||
private readonly IntPtr _operDesc; | |||
private readonly Graph _graph; | |||
private NodeDef _node_def; | |||
private Graph _graph; | |||
public string type => OpType; | |||
public Graph graph => _graph; | |||
public int _id => _id_value; | |||
public int _id_value; | |||
public Operation op => this; | |||
public TF_DataType dtype => TF_DataType.DtInvalid; | |||
public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||
public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||
public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||
private NodeDef _node_def; | |||
public NodeDef node_def | |||
{ | |||
get | |||
{ | |||
if(_node_def == null) | |||
if (_node_def == null) | |||
_node_def = GetNodeDef(); | |||
return _node_def; | |||
} | |||
} | |||
public Operation(IntPtr handle, Graph g=null) | |||
public Operation(IntPtr handle, Graph g = null) | |||
{ | |||
if (handle == IntPtr.Zero) | |||
return; | |||
@@ -97,14 +93,15 @@ namespace Tensorflow | |||
_operDesc = c_api.TF_NewOperation(g, opType, oper_name); | |||
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | |||
using (var status = new Status()) | |||
{ | |||
_handle = c_api.TF_FinishOperation(_operDesc, status); | |||
status.Check(true); | |||
} | |||
// Dict mapping op name to file and line information for op colocation | |||
// context managers. | |||
lock (Locks.ProcessWide) | |||
using (var status = new Status()) | |||
{ | |||
_handle = c_api.TF_FinishOperation(_operDesc, status); | |||
status.Check(true); | |||
} | |||
// Dict mapping op name to file and line information for op colocation | |||
// context managers. | |||
_control_flow_context = graph._get_control_flow_context(); | |||
} | |||
@@ -133,9 +130,9 @@ namespace Tensorflow | |||
// Build the list of control inputs. | |||
var control_input_ops = new List<Operation>(); | |||
if(control_inputs != null) | |||
if (control_inputs != null) | |||
{ | |||
foreach(var c in control_inputs) | |||
foreach (var c in control_inputs) | |||
{ | |||
switch (c) | |||
{ | |||
@@ -196,15 +193,13 @@ namespace Tensorflow | |||
{ | |||
if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | |||
{ | |||
input_len = (int)attrs[input_arg.NumberAttr].I; | |||
input_len = (int) attrs[input_arg.NumberAttr].I; | |||
is_sequence = true; | |||
} | |||
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||
} else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||
{ | |||
input_len = attrs[input_arg.TypeListAttr].List.Type.Count; | |||
is_sequence = true; | |||
} | |||
else | |||
} else | |||
{ | |||
input_len = 1; | |||
is_sequence = false; | |||
@@ -225,22 +220,21 @@ namespace Tensorflow | |||
{ | |||
AttrValue x = null; | |||
using (var status = new Status()) | |||
using (var buf = new Buffer()) | |||
{ | |||
unsafe | |||
lock (Locks.ProcessWide) | |||
using (var status = new Status()) | |||
using (var buf = new Buffer()) | |||
{ | |||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||
status.Check(true); | |||
x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | |||
} | |||
} | |||
string oneof_value = x.ValueCase.ToString(); | |||
if (string.IsNullOrEmpty(oneof_value)) | |||
return null; | |||
if(oneof_value == "list") | |||
if (oneof_value == "list") | |||
throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); | |||
if (oneof_value == "type") | |||
@@ -259,60 +253,63 @@ namespace Tensorflow | |||
private NodeDef GetNodeDef() | |||
{ | |||
using (var s = new Status()) | |||
using (var buffer = new Buffer()) | |||
{ | |||
c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||
s.Check(); | |||
return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
} | |||
} | |||
/// <summary> | |||
/// Update the input to this operation at the given index. | |||
/// | |||
/// NOTE: This is for TF internal use only.Please don't use it. | |||
/// </summary> | |||
/// <param name="index">the index of the input to update.</param> | |||
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | |||
public void _update_input(int index, Tensor tensor) | |||
{ | |||
_assert_same_graph(tensor); | |||
var input = _tf_input(index); | |||
var output = tensor._as_tf_output(); | |||
// Reset cached inputs. | |||
_inputs = null; | |||
// after the c_api call next time _inputs is accessed | |||
// the updated inputs are reloaded from the c_api | |||
using (var status = new Status()) | |||
{ | |||
c_api.UpdateEdge(_graph, output, input, status); | |||
//var updated_inputs = inputs; | |||
status.Check(); | |||
} | |||
} | |||
private void _assert_same_graph(Tensor tensor) | |||
{ | |||
//TODO: implement | |||
} | |||
/// <summary> | |||
/// Create and return a new TF_Output for output_idx'th output of this op. | |||
/// </summary> | |||
public TF_Output _tf_output(int output_idx) | |||
{ | |||
return new TF_Output(op, output_idx); | |||
} | |||
/// <summary> | |||
/// Create and return a new TF_Input for input_idx'th input of this op. | |||
/// </summary> | |||
public TF_Input _tf_input(int input_idx) | |||
{ | |||
return new TF_Input(op, input_idx); | |||
} | |||
} | |||
} | |||
lock (Locks.ProcessWide) | |||
using (var s = new Status()) | |||
using (var buffer = new Buffer()) | |||
{ | |||
c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||
s.Check(); | |||
return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
} | |||
} | |||
/// <summary> | |||
/// Update the input to this operation at the given index. | |||
/// | |||
/// NOTE: This is for TF internal use only.Please don't use it. | |||
/// </summary> | |||
/// <param name="index">the index of the input to update.</param> | |||
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | |||
public void _update_input(int index, Tensor tensor) | |||
{ | |||
_assert_same_graph(tensor); | |||
var input = _tf_input(index); | |||
var output = tensor._as_tf_output(); | |||
// Reset cached inputs. | |||
_inputs = null; | |||
// after the c_api call next time _inputs is accessed | |||
// the updated inputs are reloaded from the c_api | |||
lock (Locks.ProcessWide) | |||
using (var status = new Status()) | |||
{ | |||
c_api.UpdateEdge(_graph, output, input, status); | |||
//var updated_inputs = inputs; | |||
status.Check(); | |||
} | |||
} | |||
private void _assert_same_graph(Tensor tensor) | |||
{ | |||
//TODO: implement | |||
} | |||
/// <summary> | |||
/// Create and return a new TF_Output for output_idx'th output of this op. | |||
/// </summary> | |||
public TF_Output _tf_output(int output_idx) | |||
{ | |||
return new TF_Output(op, output_idx); | |||
} | |||
/// <summary> | |||
/// Create and return a new TF_Input for input_idx'th input of this op. | |||
/// </summary> | |||
public TF_Input _tf_input(int input_idx) | |||
{ | |||
return new TF_Input(op, input_idx); | |||
} | |||
} | |||
} |
@@ -36,23 +36,20 @@ namespace Tensorflow | |||
protected byte[] _target; | |||
public Graph graph => _graph; | |||
public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||
public BaseSession(string target = "", Graph g = null, SessionOptions opts = null, Status status = null) | |||
{ | |||
_graph = g ?? ops.get_default_graph(); | |||
_graph.as_default(); | |||
_target = Encoding.UTF8.GetBytes(target); | |||
SessionOptions newOpts = opts ?? new SessionOptions(); | |||
SessionOptions lopts = opts ?? new SessionOptions(); | |||
var status = new Status(); | |||
_handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | |||
// dispose opts only if not provided externally. | |||
if (opts == null) | |||
newOpts.Dispose(); | |||
status.Check(true); | |||
lock (Locks.ProcessWide) | |||
{ | |||
status = status ?? new Status(); | |||
_handle = c_api.TF_NewSession(_graph, opts ?? lopts, status); | |||
status.Check(true); | |||
} | |||
} | |||
public virtual void run(Operation op, params FeedItem[] feed_dict) | |||
@@ -72,19 +69,19 @@ namespace Tensorflow | |||
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||
var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4}, feed_dict); | |||
return (results[0], results[1], results[2], results[3]); | |||
} | |||
public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||
var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3}, feed_dict); | |||
return (results[0], results[1], results[2]); | |||
} | |||
public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||
var results = _run(new object[] {fetches.Item1, fetches.Item2}, feed_dict); | |||
return (results[0], results[1]); | |||
} | |||
@@ -95,8 +92,7 @@ namespace Tensorflow | |||
public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||
{ | |||
var feed_items = feed_dict == null ? new FeedItem[0] : | |||
feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||
var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||
return _run(fetches, feed_items); | |||
} | |||
@@ -130,7 +126,7 @@ namespace Tensorflow | |||
// We only want to really perform the run if fetches or targets are provided, | |||
// or if the call is a partial run that specifies feeds. | |||
var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||
var results = _do_run(final_targets.Select(x => (Operation) x).ToList(), final_fetches, feed_dict_tensor); | |||
return fetch_handler.build_results(this, results); | |||
} | |||
@@ -150,7 +146,6 @@ namespace Tensorflow | |||
/// </returns> | |||
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||
{ | |||
var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; | |||
int i = 0; | |||
foreach (var x in feed_dict) | |||
@@ -159,16 +154,25 @@ namespace Tensorflow | |||
{ | |||
switch (x.Value) | |||
{ | |||
case Tensor v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break; | |||
case NDArray v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break; | |||
case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||
case Tensor v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||
break; | |||
case NDArray v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||
break; | |||
case IntPtr v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
break; | |||
#if _REGEN | |||
// @formatter:off — disable formatter after this line | |||
%types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | |||
%foreach types% | |||
case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||
case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||
% | |||
// @formatter:on — enable formatter after this line | |||
#else | |||
// @formatter:off — disable formatter after this line | |||
case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||
case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||
case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||
@@ -191,9 +195,14 @@ namespace Tensorflow | |||
case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||
case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||
case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||
// @formatter:on — enable formatter after this line | |||
#endif | |||
case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break; | |||
case string v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||
case bool v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); | |||
break; | |||
case string v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||
break; | |||
default: | |||
throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}"); | |||
} | |||
@@ -217,12 +226,12 @@ namespace Tensorflow | |||
c_api.TF_SessionRun(_handle, | |||
run_options: null, | |||
inputs: feed_dict.Select(f => f.Key).ToArray(), | |||
input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), | |||
input_values: feed_dict.Select(f => (IntPtr) f.Value).ToArray(), | |||
ninputs: feed_dict.Length, | |||
outputs: fetch_list, | |||
output_values: output_values, | |||
noutputs: fetch_list.Length, | |||
target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||
target_opers: target_list.Select(f => (IntPtr) f).ToArray(), | |||
ntargets: target_list.Count, | |||
run_metadata: IntPtr.Zero, | |||
status: status); | |||
@@ -253,7 +262,7 @@ namespace Tensorflow | |||
ret = NDArray.Scalar(*(bool*) srcAddress); | |||
break; | |||
case TF_DataType.TF_STRING: | |||
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) | |||
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long) tensor.bytesize))) | |||
ret = NDArray.FromString(reader.ReadString()); | |||
break; | |||
case TF_DataType.TF_UINT8: | |||
@@ -318,81 +327,95 @@ namespace Tensorflow | |||
#endregion | |||
#else | |||
#region Compute | |||
switch (tensor.dtype) | |||
{ | |||
case TF_DataType.TF_BOOL: | |||
{ | |||
#region Compute | |||
switch (tensor.dtype) | |||
{ | |||
case TF_DataType.TF_BOOL: | |||
{ | |||
ret = new NDArray(NPTypeCode.Boolean, ndims, false); | |||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||
break; | |||
} | |||
case TF_DataType.TF_UINT8: | |||
{ | |||
break; | |||
} | |||
case TF_DataType.TF_UINT8: | |||
{ | |||
ret = new NDArray(NPTypeCode.Byte, ndims, false); | |||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||
break; | |||
} | |||
case TF_DataType.TF_INT16: | |||
{ | |||
break; | |||
} | |||
case TF_DataType.TF_INT16: | |||
{ | |||
ret = new NDArray(NPTypeCode.Int16, ndims, false); | |||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||
break; | |||
} | |||
case TF_DataType.TF_UINT16: | |||
{ | |||
break; | |||
} | |||
case TF_DataType.TF_UINT16: | |||
{ | |||
ret = new NDArray(NPTypeCode.UInt16, ndims, false); | |||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||
break; | |||
} | |||
case TF_DataType.TF_INT32: | |||
{ | |||
break; | |||
} | |||
case TF_DataType.TF_INT32: | |||
{ | |||
ret = new NDArray(NPTypeCode.Int32, ndims, false); | |||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||
break; | |||
} | |||
case TF_DataType.TF_UINT32: | |||
{ | |||
break; | |||
} | |||
case TF_DataType.TF_UINT32: | |||
{ | |||
ret = new NDArray(NPTypeCode.UInt32, ndims, false); | |||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||
break; | |||
} | |||
case TF_DataType.TF_INT64: | |||
{ | |||
break; | |||
} | |||
case TF_DataType.TF_INT64: | |||
{ | |||
ret = new NDArray(NPTypeCode.Int64, ndims, false); | |||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||
break; | |||
} | |||
case TF_DataType.TF_UINT64: | |||
{ | |||
break; | |||
} | |||
case TF_DataType.TF_UINT64: | |||
{ | |||
ret = new NDArray(NPTypeCode.UInt64, ndims, false); | |||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||
break; | |||
} | |||
case TF_DataType.TF_DOUBLE: | |||
{ | |||
break; | |||
} | |||
case TF_DataType.TF_DOUBLE: | |||
{ | |||
ret = new NDArray(NPTypeCode.Double, ndims, false); | |||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||
break; | |||
} | |||
case TF_DataType.TF_FLOAT: | |||
{ | |||
break; | |||
} | |||
case TF_DataType.TF_FLOAT: | |||
{ | |||
ret = new NDArray(NPTypeCode.Single, ndims, false); | |||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | |||
break; | |||
} | |||
break; | |||
} | |||
case TF_DataType.TF_STRING: | |||
{ | |||
throw new NotImplementedException(); | |||
//TODO:! This is not the way to handle string[], it should be done with TF_DecodeString | |||
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) | |||
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long) tensor.bytesize))) | |||
ret = NDArray.FromString(reader.ReadString()); | |||
break; | |||
} | |||
default: | |||
throw new NotSupportedException(); | |||
} | |||
#endregion | |||
default: | |||
throw new NotSupportedException(); | |||
} | |||
#endregion | |||
#endif | |||
} | |||
} | |||
@@ -411,9 +434,7 @@ namespace Tensorflow | |||
} | |||
private void _extend_graph() | |||
{ | |||
} | |||
{ } | |||
public void close() | |||
{ | |||
@@ -422,11 +443,12 @@ namespace Tensorflow | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
{ | |||
using (var status = new Status()) | |||
{ | |||
c_api.TF_DeleteSession(handle, status); | |||
status.Check(true); | |||
} | |||
lock (Locks.ProcessWide) | |||
using (var status = new Status()) | |||
{ | |||
c_api.TF_DeleteSession(handle, status); | |||
status.Check(true); | |||
} | |||
} | |||
} | |||
} |
@@ -21,24 +21,16 @@ namespace Tensorflow | |||
{ | |||
public class Session : BaseSession, IObjectLife | |||
{ | |||
public Session(string target = "", Graph g = null) | |||
: base(target, g, null) | |||
{ | |||
} | |||
public Session(string target = "", Graph g = null) : base(target, g, null) | |||
{ } | |||
public Session(IntPtr handle, Graph g = null) | |||
: base("", g, null) | |||
public Session(IntPtr handle, Graph g = null) : base("", g, null) | |||
{ | |||
_handle = handle; | |||
} | |||
public Session(Graph g, SessionOptions opts = null, Status s = null) | |||
: base("", g, opts) | |||
{ | |||
if (s == null) | |||
s = new Status(); | |||
} | |||
public Session(Graph g, SessionOptions opts = null, Status s = null) : base("", g, opts, s) | |||
{ } | |||
public Session as_default() | |||
{ | |||
@@ -21,6 +21,7 @@ using Google.Protobuf; | |||
using System.Linq; | |||
using System.Threading; | |||
using NumSharp; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
@@ -207,47 +208,49 @@ namespace Tensorflow | |||
/// <returns>A wrapped TF_Operation*.</returns> | |||
public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | |||
{ | |||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||
//TODO: Implement TF_SetDevice | |||
//if node_def.device: | |||
// c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) | |||
// Add inputs | |||
foreach (var op_input in inputs) | |||
lock (Locks.ProcessWide) | |||
{ | |||
if (op_input is Tensor[] op_inputs) | |||
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); | |||
else if (op_input is Tensor op_input1) | |||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||
//TODO: Implement TF_SetDevice | |||
//if node_def.device: | |||
// c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) | |||
// Add inputs | |||
foreach (var op_input in inputs) | |||
{ | |||
c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); | |||
if (op_input is Tensor[] op_inputs) | |||
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); | |||
else if (op_input is Tensor op_input1) | |||
{ | |||
c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); | |||
} else | |||
throw new NotImplementedException("_create_c_op"); | |||
} | |||
else | |||
throw new NotImplementedException("_create_c_op"); | |||
} | |||
var status = new Status(); | |||
var status = new Status(); | |||
// Add control inputs | |||
foreach (var control_input in control_inputs) | |||
c_api.TF_AddControlInput(op_desc, control_input); | |||
// Add control inputs | |||
foreach (var control_input in control_inputs) | |||
c_api.TF_AddControlInput(op_desc, control_input); | |||
// Add attrs | |||
foreach (var attr in node_def.Attr) | |||
{ | |||
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||
Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
uint len = (uint)bytes.Length; | |||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||
// Add attrs | |||
foreach (var attr in node_def.Attr) | |||
{ | |||
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||
Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
uint len = (uint) bytes.Length; | |||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||
status.Check(true); | |||
} | |||
status.Check(true); | |||
} | |||
var c_op = c_api.TF_FinishOperation(op_desc, status); | |||
var c_op = c_api.TF_FinishOperation(op_desc, status); | |||
status.Check(true); | |||
status.Check(true); | |||
return (c_op, op_desc); | |||
return (c_op, op_desc); | |||
} | |||
} | |||
public static OpDef _get_op_def(Graph graph, string type) | |||
@@ -11,7 +11,7 @@ namespace TensorFlowNET.UnitTest | |||
/// tensorflow\c\c_api_test.cc | |||
/// `class CApiGradientsTest` | |||
/// </summary> | |||
[TestClass] | |||
[TestClass, Ignore] | |||
public class CApiGradientsTest : CApiTest, IDisposable | |||
{ | |||
private Graph graph_ = new Graph(); | |||
@@ -2,6 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow; | |||
using Tensorflow.Util; | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
@@ -22,9 +23,12 @@ namespace TensorFlowNET.UnitTest | |||
public CSession(Graph graph, Status s, bool user_XLA = false) | |||
{ | |||
var opts = new SessionOptions(); | |||
opts.SetConfig(new ConfigProto { InterOpParallelismThreads = 4 }); | |||
session_ = new Session(graph, opts, s); | |||
lock (Locks.ProcessWide) | |||
{ | |||
var opts = new SessionOptions(); | |||
opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4}); | |||
session_ = new Session(graph, opts, s); | |||
} | |||
} | |||
public void SetInputs(Dictionary<Operation, Tensor> inputs) | |||
@@ -64,13 +68,13 @@ namespace TensorFlowNET.UnitTest | |||
public unsafe void Run(Status s) | |||
{ | |||
var inputs_ptr = inputs_.ToArray(); | |||
var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray(); | |||
var input_values_ptr = input_values_.Select(x => (IntPtr) x).ToArray(); | |||
var outputs_ptr = outputs_.ToArray(); | |||
var output_values_ptr = output_values_.Select(x => IntPtr.Zero).ToArray(); | |||
IntPtr[] targets_ptr = new IntPtr[0]; | |||
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | |||
outputs_ptr, output_values_ptr, outputs_.Count, | |||
outputs_ptr, output_values_ptr, outputs_.Count, | |||
targets_ptr, targets_.Count, | |||
IntPtr.Zero, s); | |||
@@ -90,4 +94,4 @@ namespace TensorFlowNET.UnitTest | |||
ResetOutputValues(); | |||
} | |||
} | |||
} | |||
} |
@@ -207,7 +207,7 @@ namespace TensorFlowNET.UnitTest | |||
public void ImportGraphDef() | |||
{ | |||
var s = new Status(); | |||
var graph = new Graph(); | |||
var graph = new Graph().as_default(); | |||
// Create a simple graph. | |||
c_test_util.Placeholder(graph, s); | |||
@@ -221,7 +221,7 @@ namespace TensorFlowNET.UnitTest | |||
// Import it, with a prefix, in a fresh graph. | |||
graph.Dispose(); | |||
graph = new Graph(); | |||
graph = new Graph().as_default(); | |||
var opts = c_api.TF_NewImportGraphDefOptions(); | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | |||
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
@@ -359,7 +359,7 @@ namespace TensorFlowNET.UnitTest | |||
public void ImportGraphDef_WithReturnOutputs() | |||
{ | |||
var s = new Status(); | |||
var graph = new Graph(); | |||
var graph = new Graph().as_default(); | |||
// Create a graph with two nodes: x and 3 | |||
c_test_util.Placeholder(graph, s); | |||
@@ -375,7 +375,7 @@ namespace TensorFlowNET.UnitTest | |||
// Import it in a fresh graph with return outputs. | |||
graph.Dispose(); | |||
graph = new Graph(); | |||
graph = new Graph().as_default(); | |||
var opts = new ImportGraphDefOptions(); | |||
opts.AddReturnOutput("feed", 0); | |||
opts.AddReturnOutput("scalar", 0); | |||
@@ -4,6 +4,7 @@ using System.Runtime.InteropServices; | |||
using FluentAssertions; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest | |||
@@ -14,7 +15,7 @@ namespace TensorFlowNET.UnitTest | |||
[TestMethod] | |||
public void SessionCreation() | |||
{ | |||
tf.Session(); //create one to increase next id to 1. | |||
ops.uid(); //increment id by one | |||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||
@@ -23,6 +24,28 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
tf.peak_default_graph().Should().BeNull(); | |||
using (var sess = tf.Session()) | |||
{ | |||
var default_graph = tf.peak_default_graph(); | |||
var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||
sess_graph.Should().NotBeNull(); | |||
default_graph.Should().NotBeNull() | |||
.And.BeEquivalentTo(sess_graph); | |||
} | |||
} | |||
} | |||
[TestMethod] | |||
public void SessionCreation_x2() | |||
{ | |||
ops.uid(); //increment id by one | |||
MultiThreadedUnitTestExecuter.Run(16, Core); | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
tf.peak_default_graph().Should().BeNull(); | |||
//tf.Session created an other graph | |||
using (var sess = tf.Session()) | |||
{ | |||
@@ -38,7 +61,7 @@ namespace TensorFlowNET.UnitTest | |||
[TestMethod] | |||
public void GraphCreation() | |||
{ | |||
tf.Graph(); //create one to increase next id to 1. | |||
ops.uid(); //increment id by one | |||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||
@@ -47,7 +70,7 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
tf.peak_default_graph().Should().BeNull(); | |||
var beforehand = tf.get_default_graph(); //this should create default automatically. | |||
beforehand.graph_key.Should().NotContain("0", "Already created a graph in an other thread."); | |||
beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread."); | |||
tf.peak_default_graph().Should().NotBeNull(); | |||
using (var sess = tf.Session()) | |||
@@ -67,5 +90,174 @@ namespace TensorFlowNET.UnitTest | |||
} | |||
} | |||
} | |||
[TestMethod] | |||
public void Marshal_AllocHGlobal() | |||
{ | |||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
for (int i = 0; i < 100; i++) | |||
{ | |||
Marshal.FreeHGlobal(Marshal.AllocHGlobal(sizeof(int))); | |||
} | |||
} | |||
} | |||
[TestMethod] | |||
public void TensorCreation() | |||
{ | |||
//lock (Locks.ProcessWide) | |||
// tf.Session(); //create one to increase next id to 1. | |||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
using (var sess = tf.Session()) | |||
{ | |||
Tensor t = null; | |||
for (int i = 0; i < 100; i++) | |||
{ | |||
t = new Tensor(1); | |||
} | |||
} | |||
} | |||
} | |||
[TestMethod] | |||
public void TensorCreation_Array() | |||
{ | |||
//lock (Locks.ProcessWide) | |||
// tf.Session(); //create one to increase next id to 1. | |||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
//tf.Session created an other graph | |||
using (var sess = tf.Session()) | |||
{ | |||
Tensor t = null; | |||
for (int i = 0; i < 100; i++) | |||
{ | |||
t = new Tensor(new int[] {1, 2, 3}); | |||
} | |||
} | |||
} | |||
} | |||
[TestMethod] | |||
public void TensorCreation_Undressed() | |||
{ | |||
//lock (Locks.ProcessWide) | |||
// tf.Session(); //create one to increase next id to 1. | |||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||
//the core method | |||
unsafe void Core(int tid) | |||
{ | |||
using (var sess = tf.Session()) | |||
{ | |||
Tensor t = null; | |||
for (int i = 0; i < 100; i++) | |||
{ | |||
var v = (int*) Marshal.AllocHGlobal(sizeof(int)); | |||
c_api.DeallocatorArgs _deallocatorArgs = new c_api.DeallocatorArgs(); | |||
var handle = c_api.TF_NewTensor(typeof(int).as_dtype(), dims: new long[0], num_dims: 0, | |||
data: (IntPtr) v, len: (UIntPtr) sizeof(int), | |||
deallocator: (IntPtr data, IntPtr size, ref c_api.DeallocatorArgs args) => Marshal.FreeHGlobal(data), | |||
ref _deallocatorArgs); | |||
c_api.TF_DeleteTensor(handle); | |||
} | |||
} | |||
} | |||
} | |||
[TestMethod] | |||
public void SessionRun() | |||
{ | |||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
tf.peak_default_graph().Should().BeNull(); | |||
//graph is created automatically to perform create these operations | |||
var a1 = tf.constant(new[] {2f}, shape: new[] {1}); | |||
var a2 = tf.constant(new[] {3f}, shape: new[] {1}); | |||
var math = a1 + a2; | |||
for (int i = 0; i < 100; i++) | |||
{ | |||
using (var sess = tf.Session()) | |||
{ | |||
sess.run(math).GetAtIndex<float>(0).Should().Be(5); | |||
} | |||
} | |||
} | |||
} | |||
[TestMethod] | |||
public void SessionRun_InsideSession() | |||
{ | |||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
using (var sess = tf.Session()) | |||
{ | |||
tf.peak_default_graph().Should().NotBeNull(); | |||
//graph is created automatically to perform create these operations | |||
var a1 = tf.constant(new[] {2f}, shape: new[] {1}); | |||
var a2 = tf.constant(new[] {3f}, shape: new[] {1}); | |||
var math = a1 + a2; | |||
var result = sess.run(math); | |||
result[0].GetAtIndex<float>(0).Should().Be(5); | |||
} | |||
} | |||
} | |||
[TestMethod] | |||
public void SessionRun_Initialization() | |||
{ | |||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
using (var sess = tf.Session()) | |||
{ | |||
tf.peak_default_graph().Should().NotBeNull(); | |||
//graph is created automatically to perform create these operations | |||
var a1 = tf.constant(new[] {2f}, shape: new[] {1}); | |||
var a2 = tf.constant(new[] {3f}, shape: new[] {1}); | |||
var math = a1 + a2; | |||
} | |||
} | |||
} | |||
[TestMethod] | |||
public void SessionRun_Initialization_OutsideSession() | |||
{ | |||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
tf.peak_default_graph().Should().BeNull(); | |||
//graph is created automatically to perform create these operations | |||
var a1 = tf.constant(new[] {2f}, shape: new[] {1}); | |||
var a2 = tf.constant(new[] {3f}, shape: new[] {1}); | |||
var math = a1 + a2; | |||
} | |||
} | |||
} | |||
} |
@@ -8,6 +8,7 @@ using System.Text; | |||
using FluentAssertions; | |||
using Google.Protobuf; | |||
using Tensorflow; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest | |||
@@ -19,13 +20,13 @@ namespace TensorFlowNET.UnitTest | |||
/// tensorflow\c\c_api_test.cc | |||
/// `TEST(CAPI, Session)` | |||
/// </summary> | |||
[TestMethod] | |||
[TestMethod, Ignore] | |||
public void Session() | |||
{ | |||
lock (this) | |||
lock (Locks.ProcessWide) | |||
{ | |||
var s = new Status(); | |||
var graph = new Graph(); | |||
var graph = new Graph().as_default(); | |||
// Make a placeholder operation. | |||
var feed = c_test_util.Placeholder(graph, s); | |||
@@ -117,7 +117,7 @@ namespace TensorFlowNET.UnitTest | |||
public void SetShape() | |||
{ | |||
var s = new Status(); | |||
var graph = new Graph(); | |||
var graph = new Graph().as_default(); | |||
var feed = c_test_util.Placeholder(graph, s); | |||
var feed_out_0 = new TF_Output(feed, 0); | |||
@@ -12,42 +12,51 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") | |||
{ | |||
var desc = c_api.TF_NewOperation(graph, "AddN", name); | |||
var inputs = new TF_Output[] | |||
lock (Locks.ProcessWide) | |||
{ | |||
new TF_Output(l, 0), | |||
new TF_Output(r, 0), | |||
}; | |||
var desc = c_api.TF_NewOperation(graph, "AddN", name); | |||
c_api.TF_AddInputList(desc, inputs, inputs.Length); | |||
var inputs = new TF_Output[] | |||
{ | |||
new TF_Output(l, 0), | |||
new TF_Output(r, 0), | |||
}; | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
c_api.TF_AddInputList(desc, inputs, inputs.Length); | |||
return op; | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
return op; | |||
} | |||
} | |||
[SuppressMessage("ReSharper", "RedundantAssignment")] | |||
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | |||
{ | |||
using (var buffer = new Buffer()) | |||
lock (Locks.ProcessWide) | |||
{ | |||
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||
attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
} | |||
using (var buffer = new Buffer()) | |||
{ | |||
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||
attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
} | |||
return s.Code == TF_Code.TF_OK; | |||
return s.Code == TF_Code.TF_OK; | |||
} | |||
} | |||
public static GraphDef GetGraphDef(Graph graph) | |||
{ | |||
using (var s = new Status()) | |||
using (var buffer = new Buffer()) | |||
lock (Locks.ProcessWide) | |||
{ | |||
c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
s.Check(); | |||
return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
using (var s = new Status()) | |||
using (var buffer = new Buffer()) | |||
{ | |||
c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
s.Check(); | |||
return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
} | |||
} | |||
} | |||
@@ -58,6 +67,7 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
return false; | |||
} | |||
bool found_t = false; | |||
bool found_n = false; | |||
foreach (var attr in node_def.Attr) | |||
@@ -67,19 +77,16 @@ namespace TensorFlowNET.UnitTest | |||
if (attr.Value.Type == DataType.DtInt32) | |||
{ | |||
found_t = true; | |||
} | |||
else | |||
} else | |||
{ | |||
return false; | |||
} | |||
} | |||
else if (attr.Key == "N") | |||
} else if (attr.Key == "N") | |||
{ | |||
if (attr.Value.I == n) | |||
{ | |||
found_n = true; | |||
} | |||
else | |||
} else | |||
{ | |||
return false; | |||
} | |||
@@ -92,7 +99,7 @@ namespace TensorFlowNET.UnitTest | |||
public static bool IsNeg(NodeDef node_def, string input) | |||
{ | |||
return node_def.Op == "Neg" && node_def.Name == "neg" && | |||
node_def.Input.Count == 1 && node_def.Input[0] == input; | |||
node_def.Input.Count == 1 && node_def.Input[0] == input; | |||
} | |||
public static bool IsPlaceholder(NodeDef node_def) | |||
@@ -111,13 +118,11 @@ namespace TensorFlowNET.UnitTest | |||
if (attr.Value.Type == DataType.DtInt32) | |||
{ | |||
found_dtype = true; | |||
} | |||
else | |||
} else | |||
{ | |||
return false; | |||
} | |||
} | |||
else if (attr.Key == "shape") | |||
} else if (attr.Key == "shape") | |||
{ | |||
found_shape = true; | |||
} | |||
@@ -132,72 +137,82 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
return false; | |||
} | |||
bool found_dtype = false; | |||
bool found_value = false; | |||
foreach (var attr in node_def.Attr) { | |||
foreach (var attr in node_def.Attr) | |||
{ | |||
if (attr.Key == "dtype") | |||
{ | |||
if (attr.Value.Type == DataType.DtInt32) | |||
{ | |||
found_dtype = true; | |||
} | |||
else | |||
} else | |||
{ | |||
return false; | |||
} | |||
} | |||
else if (attr.Key == "value") | |||
} else if (attr.Key == "value") | |||
{ | |||
if (attr.Value.Tensor != null && | |||
attr.Value.Tensor.IntVal.Count == 1 && | |||
attr.Value.Tensor.IntVal[0] == v) | |||
{ | |||
found_value = true; | |||
} | |||
else | |||
} else | |||
{ | |||
return false; | |||
} | |||
} | |||
} | |||
return found_dtype && found_value; | |||
} | |||
public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") | |||
{ | |||
OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | |||
var neg_input = new TF_Output(n, 0); | |||
c_api.TF_AddInput(desc, neg_input); | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
lock (Locks.ProcessWide) | |||
{ | |||
OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | |||
var neg_input = new TF_Output(n, 0); | |||
c_api.TF_AddInput(desc, neg_input); | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
return op; | |||
return op; | |||
} | |||
} | |||
public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) | |||
{ | |||
var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||
c_api.TF_SetAttrType(desc, "dtype", dtype); | |||
if (dims != null) | |||
lock (Locks.ProcessWide) | |||
{ | |||
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | |||
} | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||
c_api.TF_SetAttrType(desc, "dtype", dtype); | |||
if (dims != null) | |||
{ | |||
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | |||
} | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
return op; | |||
return op; | |||
} | |||
} | |||
public static Operation Const(Tensor t, Graph graph, Status s, string name) | |||
{ | |||
var desc = c_api.TF_NewOperation(graph, "Const", name); | |||
c_api.TF_SetAttrTensor(desc, "value", t, s); | |||
s.Check(); | |||
c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
return op; | |||
lock (Locks.ProcessWide) | |||
{ | |||
var desc = c_api.TF_NewOperation(graph, "Const", name); | |||
c_api.TF_SetAttrTensor(desc, "value", t, s); | |||
s.Check(); | |||
c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
return op; | |||
} | |||
} | |||
public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar") | |||
@@ -205,4 +220,4 @@ namespace TensorFlowNET.UnitTest | |||
return Const(new Tensor(v), graph, s, name); | |||
} | |||
} | |||
} | |||
} |