@@ -53,8 +53,7 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
object value = null; | object value = null; | ||||
byte isList = 0; | byte isList = 0; | ||||
using var status = new Status(); | |||||
var attrType = c_api.TFE_OpNameGetAttrType(tf.context, Name, attr_name, ref isList, status); | |||||
var attrType = c_api.TFE_OpNameGetAttrType(tf.context, Name, attr_name, ref isList, tf.status); | |||||
switch (attrType) | switch (attrType) | ||||
{ | { | ||||
case TF_AttrType.TF_ATTR_BOOL: | case TF_AttrType.TF_ATTR_BOOL: | ||||
@@ -22,13 +22,13 @@ namespace Tensorflow.Eager | |||||
public EagerTensor(string value, string device_name) : base(value) | public EagerTensor(string value, string device_name) : base(value) | ||||
{ | { | ||||
EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, status); | |||||
EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status); | |||||
Resolve(); | Resolve(); | ||||
} | } | ||||
public EagerTensor(NDArray value, string device_name) : base(value) | public EagerTensor(NDArray value, string device_name) : base(value) | ||||
{ | { | ||||
EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, status); | |||||
EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status); | |||||
Resolve(); | Resolve(); | ||||
} | } | ||||
@@ -37,7 +37,7 @@ namespace Tensorflow.Eager | |||||
_id = get_uid(); | _id = get_uid(); | ||||
if (_handle == IntPtr.Zero) | if (_handle == IntPtr.Zero) | ||||
_handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, status); | |||||
_handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.status); | |||||
//print($"new Tensor {Id} {_handle.ToString("x16")}"); | //print($"new Tensor {Id} {_handle.ToString("x16")}"); | ||||
//print($"new TensorHandle {Id} {EagerTensorHandle.ToString("x16")}"); | //print($"new TensorHandle {Id} {EagerTensorHandle.ToString("x16")}"); | ||||
@@ -8,26 +8,23 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
public partial class EagerTensor : Tensor | public partial class EagerTensor : Tensor | ||||
{ | { | ||||
Status status = new Status(); | |||||
public IntPtr EagerTensorHandle; | public IntPtr EagerTensorHandle; | ||||
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, status)); | |||||
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.status)); | |||||
public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, status); | |||||
public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.status); | |||||
public static int GetRank(IntPtr handle) | public static int GetRank(IntPtr handle) | ||||
{ | { | ||||
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | ||||
using var status = new Status(); | |||||
return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status); | |||||
return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.status); | |||||
} | } | ||||
public static int[] GetDims(IntPtr handle) | public static int[] GetDims(IntPtr handle) | ||||
{ | { | ||||
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | ||||
using var status = new Status(); | |||||
var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status)]; | |||||
var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.status)]; | |||||
for (int i = 0; i < dims.Length; i++) | for (int i = 0; i < dims.Length; i++) | ||||
dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, status); | |||||
dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.status); | |||||
return dims; | return dims; | ||||
} | } | ||||
@@ -512,7 +512,7 @@ namespace Tensorflow | |||||
public TensorShape GetTensorShape(TF_Output output) | public TensorShape GetTensorShape(TF_Output output) | ||||
{ | { | ||||
var status = new Status(); | |||||
var status = tf.status; | |||||
var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status); | var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status); | ||||
status.Check(); | status.Check(); | ||||
@@ -17,6 +17,7 @@ | |||||
using System; | using System; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -30,11 +31,8 @@ namespace Tensorflow | |||||
public int InputListLength(string name) | public int InputListLength(string name) | ||||
{ | { | ||||
int num = 0; | int num = 0; | ||||
using(var status = new Status()) | |||||
{ | |||||
num = c_api.TF_OperationInputListLength(_handle, name, status); | |||||
status.Check(true); | |||||
} | |||||
num = c_api.TF_OperationInputListLength(_handle, name, tf.status); | |||||
tf.status.Check(true); | |||||
return num; | return num; | ||||
} | } | ||||
public int NumInputs => c_api.TF_OperationNumInputs(_handle); | public int NumInputs => c_api.TF_OperationNumInputs(_handle); | ||||
@@ -28,12 +28,8 @@ namespace Tensorflow | |||||
public int OutputListLength(string name) | public int OutputListLength(string name) | ||||
{ | { | ||||
int num = 0; | |||||
using (var status = new Status()) | |||||
{ | |||||
num = c_api.TF_OperationOutputListLength(_handle, name, status); | |||||
status.Check(true); | |||||
} | |||||
int num = c_api.TF_OperationOutputListLength(_handle, name, tf.status); | |||||
tf.status.Check(true); | |||||
return num; | return num; | ||||
} | } | ||||
@@ -20,6 +20,7 @@ using System.Collections.Generic; | |||||
using System.IO; | using System.IO; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -233,14 +234,13 @@ namespace Tensorflow | |||||
AttrValue x = null; | AttrValue x = null; | ||||
lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
using (var status = new Status()) | |||||
using (var buf = new Buffer()) | |||||
{ | |||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||||
status.Check(true); | |||||
{ | |||||
using var buf = new Buffer(); | |||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.status); | |||||
tf.status.Check(true); | |||||
x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | |||||
} | |||||
x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | |||||
} | |||||
string oneof_value = x.ValueCase.ToString(); | string oneof_value = x.ValueCase.ToString(); | ||||
if (string.IsNullOrEmpty(oneof_value)) | if (string.IsNullOrEmpty(oneof_value)) | ||||
@@ -295,11 +295,10 @@ namespace Tensorflow | |||||
// after the c_api call next time _inputs is accessed | // after the c_api call next time _inputs is accessed | ||||
// the updated inputs are reloaded from the c_api | // the updated inputs are reloaded from the c_api | ||||
lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
using (var status = new Status()) | |||||
{ | { | ||||
c_api.UpdateEdge(_graph, output, input, status); | |||||
c_api.UpdateEdge(_graph, output, input, tf.status); | |||||
//var updated_inputs = inputs; | //var updated_inputs = inputs; | ||||
status.Check(); | |||||
tf.status.Check(); | |||||
} | } | ||||
} | } | ||||
@@ -43,30 +43,26 @@ namespace Tensorflow | |||||
allow_broadcast: false); | allow_broadcast: false); | ||||
public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | ||||
{ | |||||
dtype = dtype.as_base_dtype(); | |||||
return tf_with(ops.name_scope(name, "zeros", shape), scope => | |||||
=> tf_with(ops.name_scope(name, "zeros", shape), scope => | |||||
{ | { | ||||
dtype = dtype.as_base_dtype(); | |||||
name = scope; | name = scope; | ||||
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||||
Tensor zeros = null; | |||||
switch (dtype) | switch (dtype) | ||||
{ | { | ||||
case TF_DataType.TF_BOOL: | |||||
return _constant_if_small(false, shape, dtype, name); | |||||
case TF_DataType.TF_DOUBLE: | case TF_DataType.TF_DOUBLE: | ||||
return _constant_if_small(0.0D, shape, dtype, name); | |||||
zeros = constant(0d); | |||||
break; | |||||
case TF_DataType.TF_FLOAT: | case TF_DataType.TF_FLOAT: | ||||
return _constant_if_small(0.0F, shape, dtype, name); | |||||
case TF_DataType.TF_INT64: | |||||
return _constant_if_small(0L, shape, dtype, name); | |||||
case TF_DataType.TF_INT32: | |||||
return _constant_if_small(0, shape, dtype, name); | |||||
case TF_DataType.TF_INT8: | |||||
return _constant_if_small<byte>(0, shape, dtype, name); | |||||
zeros = constant(0f); | |||||
break; | |||||
default: | default: | ||||
throw new TypeError("can't find type for zeros"); | |||||
zeros = constant(0); | |||||
break; | |||||
} | } | ||||
return fill(shape_tensor, zeros, name: name); | |||||
}); | }); | ||||
} | |||||
public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) | public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) | ||||
{ | { | ||||
@@ -22,7 +22,7 @@ using System.Linq; | |||||
using System.Numerics; | using System.Numerics; | ||||
using System.Text; | using System.Text; | ||||
using Google.Protobuf; | using Google.Protobuf; | ||||
using NumSharp.Backends; | |||||
using static Tensorflow.Binding; | |||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -236,7 +236,7 @@ namespace Tensorflow | |||||
// Ensure any changes to the graph are reflected in the runtime. | // Ensure any changes to the graph are reflected in the runtime. | ||||
_extend_graph(); | _extend_graph(); | ||||
var status = new Status(); | |||||
var status = tf.status; | |||||
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | ||||
@@ -46,7 +46,7 @@ namespace Tensorflow | |||||
lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
{ | { | ||||
var graph = c_api.TF_NewGraph(); | var graph = c_api.TF_NewGraph(); | ||||
var status = new Status(); | |||||
using var status = new Status(); | |||||
var opt = new SessionOptions(); | var opt = new SessionOptions(); | ||||
var tags = new string[] {"serve"}; | var tags = new string[] {"serve"}; | ||||
@@ -66,7 +66,6 @@ namespace Tensorflow | |||||
status.Check(true); | status.Check(true); | ||||
} catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) | } catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) | ||||
{ | { | ||||
status = new Status(); | |||||
sess = c_api.TF_LoadSessionFromSavedModel(opt, | sess = c_api.TF_LoadSessionFromSavedModel(opt, | ||||
IntPtr.Zero, | IntPtr.Zero, | ||||
Path.GetFullPath(path), | Path.GetFullPath(path), | ||||
@@ -13,14 +13,12 @@ namespace Tensorflow | |||||
public class EagerTensorV2 : DisposableObject, ITensor | public class EagerTensorV2 : DisposableObject, ITensor | ||||
{ | { | ||||
IntPtr EagerTensorHandle; | IntPtr EagerTensorHandle; | ||||
public string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, status)); | |||||
static Status status = new Status(); | |||||
public string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.status)); | |||||
public EagerTensorV2(IntPtr handle) | public EagerTensorV2(IntPtr handle) | ||||
{ | { | ||||
EagerTensorHandle = c_api.TFE_EagerTensorHandle(handle); | EagerTensorHandle = c_api.TFE_EagerTensorHandle(handle); | ||||
_handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, status); | |||||
_handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.status); | |||||
} | } | ||||
public unsafe EagerTensorV2(NDArray nd, string device_name = "") | public unsafe EagerTensorV2(NDArray nd, string device_name = "") | ||||
@@ -40,7 +38,7 @@ namespace Tensorflow | |||||
}, IntPtr.Zero); | }, IntPtr.Zero); | ||||
EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, status); | |||||
EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status); | |||||
} | } | ||||
/*public unsafe EagerTensorV2(float[,] value) | /*public unsafe EagerTensorV2(float[,] value) | ||||
@@ -21,6 +21,7 @@ using System.Globalization; | |||||
using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
using System.Text; | using System.Text; | ||||
using NumSharp.Utilities; | using NumSharp.Utilities; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -69,11 +70,8 @@ namespace Tensorflow | |||||
IntPtr stringStartAddress = IntPtr.Zero; | IntPtr stringStartAddress = IntPtr.Zero; | ||||
UIntPtr dstLen = UIntPtr.Zero; | UIntPtr dstLen = UIntPtr.Zero; | ||||
using (var status = new Status()) | |||||
{ | |||||
c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, status); | |||||
status.Check(true); | |||||
} | |||||
c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, tf.status); | |||||
tf.status.Check(true); | |||||
var dstLenInt = checked((int) dstLen); | var dstLenInt = checked((int) dstLen); | ||||
var value = Encoding.UTF8.GetString((byte*) stringStartAddress, dstLenInt); | var value = Encoding.UTF8.GetString((byte*) stringStartAddress, dstLenInt); | ||||
@@ -451,7 +451,6 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public unsafe Tensor(string str) | public unsafe Tensor(string str) | ||||
{ | { | ||||
var status = new Status(); | |||||
var buffer = Encoding.UTF8.GetBytes(str); | var buffer = Encoding.UTF8.GetBytes(str); | ||||
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | ||||
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | ||||
@@ -460,9 +459,9 @@ namespace Tensorflow | |||||
IntPtr tensor = c_api.TF_TensorData(handle); | IntPtr tensor = c_api.TF_TensorData(handle); | ||||
Marshal.WriteInt64(tensor, 0); | Marshal.WriteInt64(tensor, 0); | ||||
fixed (byte* src = buffer) | fixed (byte* src = buffer) | ||||
c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(long)), size, status); | |||||
c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(long)), size, tf.status); | |||||
_handle = handle; | _handle = handle; | ||||
status.Check(true); | |||||
tf.status.Check(true); | |||||
} | } | ||||
public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | ||||
@@ -483,10 +482,8 @@ namespace Tensorflow | |||||
IntPtr tensor = c_api.TF_TensorData(handle); | IntPtr tensor = c_api.TF_TensorData(handle); | ||||
Marshal.WriteInt64(tensor, 0); | Marshal.WriteInt64(tensor, 0); | ||||
var status = new Status(); | |||||
c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||||
status.Check(true); | |||||
c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, tf.status); | |||||
tf.status.Check(true); | |||||
_handle = handle; | _handle = handle; | ||||
} else | } else | ||||
{ | { | ||||
@@ -498,11 +495,10 @@ namespace Tensorflow | |||||
IntPtr tensor = c_api.TF_TensorData(handle); | IntPtr tensor = c_api.TF_TensorData(handle); | ||||
Marshal.WriteInt64(tensor, 0); | Marshal.WriteInt64(tensor, 0); | ||||
var status = new Status(); | |||||
fixed (byte* src = buffer) | fixed (byte* src = buffer) | ||||
c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||||
c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, tf.status); | |||||
status.Check(true); | |||||
tf.status.Check(true); | |||||
_handle = handle; | _handle = handle; | ||||
} | } | ||||
@@ -607,11 +603,10 @@ namespace Tensorflow | |||||
IntPtr tensor = c_api.TF_TensorData(handle); | IntPtr tensor = c_api.TF_TensorData(handle); | ||||
Marshal.WriteInt64(tensor, 0); | Marshal.WriteInt64(tensor, 0); | ||||
var status = new Status(); | |||||
fixed (byte* src = buffer) | fixed (byte* src = buffer) | ||||
c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||||
c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(long)), size, tf.status); | |||||
status.Check(true); | |||||
tf.status.Check(true); | |||||
return handle; | return handle; | ||||
} | } | ||||
@@ -3,7 +3,7 @@ using NumSharp.Backends; | |||||
using NumSharp.Backends.Unmanaged; | using NumSharp.Backends.Unmanaged; | ||||
using NumSharp.Utilities; | using NumSharp.Utilities; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | |||||
using static Tensorflow.Binding; | |||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using System.Text; | using System.Text; | ||||
@@ -237,18 +237,15 @@ namespace Tensorflow | |||||
var src = c_api.TF_TensorData(_handle); | var src = c_api.TF_TensorData(_handle); | ||||
var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); | var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); | ||||
src += (int)(size * 8); | src += (int)(size * 8); | ||||
using (var status = new Status()) | |||||
for (int i = 0; i < buffer.Length; i++) | |||||
{ | { | ||||
for (int i = 0; i < buffer.Length; i++) | |||||
{ | |||||
IntPtr dst = IntPtr.Zero; | |||||
UIntPtr dstLen = UIntPtr.Zero; | |||||
var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); | |||||
status.Check(true); | |||||
buffer[i] = new byte[(int)dstLen]; | |||||
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | |||||
src += (int)read; | |||||
} | |||||
IntPtr dst = IntPtr.Zero; | |||||
UIntPtr dstLen = UIntPtr.Zero; | |||||
var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, tf.status); | |||||
tf.status.Check(true); | |||||
buffer[i] = new byte[(int)dstLen]; | |||||
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | |||||
src += (int)read; | |||||
} | } | ||||
var _str = new string[buffer.Length]; | var _str = new string[buffer.Length]; | ||||
@@ -22,7 +22,7 @@ using System.Globalization; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using System.Text; | using System.Text; | ||||
using System.Threading.Tasks; | |||||
using static Tensorflow.Binding; | |||||
using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -109,11 +109,7 @@ namespace Tensorflow | |||||
if (_handle == IntPtr.Zero) | if (_handle == IntPtr.Zero) | ||||
{ | { | ||||
using (var status = new Status()) | |||||
{ | |||||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | |||||
status.Check(); | |||||
} | |||||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.status); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -126,15 +122,12 @@ namespace Tensorflow | |||||
set | set | ||||
{ | { | ||||
using (var status = new Status()) | |||||
{ | |||||
if (value == null) | |||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, status); | |||||
else | |||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||||
if (value == null) | |||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.status); | |||||
else | |||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, tf.status); | |||||
status.Check(true); | |||||
} | |||||
tf.status.Check(true); | |||||
} | } | ||||
} | } | ||||
@@ -178,13 +171,9 @@ namespace Tensorflow | |||||
{ | { | ||||
if (_handle == IntPtr.Zero) | if (_handle == IntPtr.Zero) | ||||
{ | { | ||||
using (var status = new Status()) | |||||
{ | |||||
var output = _as_tf_output(); | |||||
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); | |||||
status.Check(); | |||||
return ndim; | |||||
} | |||||
var output = _as_tf_output(); | |||||
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.status); | |||||
return ndim; | |||||
} | } | ||||
return c_api.TF_NumDims(_handle); | return c_api.TF_NumDims(_handle); | ||||
@@ -176,30 +176,29 @@ namespace Tensorflow | |||||
throw new NotImplementedException("_create_c_op"); | throw new NotImplementedException("_create_c_op"); | ||||
} | } | ||||
using (var status = new Status()) | |||||
var status = tf.status; | |||||
// Add control inputs | |||||
foreach (var control_input in control_inputs) | |||||
c_api.TF_AddControlInput(op_desc, control_input); | |||||
// Add attrs | |||||
foreach (var attr in node_def.Attr) | |||||
{ | { | ||||
// Add control inputs | |||||
foreach (var control_input in control_inputs) | |||||
c_api.TF_AddControlInput(op_desc, control_input); | |||||
// Add attrs | |||||
foreach (var attr in node_def.Attr) | |||||
{ | |||||
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
var protoHandle = Marshal.AllocHGlobal(bytes.Length); | |||||
Marshal.Copy(bytes, 0, protoHandle, bytes.Length); | |||||
uint len = (uint)bytes.Length; | |||||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status); | |||||
status.Check(true); | |||||
Marshal.FreeHGlobal(protoHandle); | |||||
} | |||||
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
var protoHandle = Marshal.AllocHGlobal(bytes.Length); | |||||
Marshal.Copy(bytes, 0, protoHandle, bytes.Length); | |||||
uint len = (uint)bytes.Length; | |||||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status); | |||||
status.Check(true); | |||||
Marshal.FreeHGlobal(protoHandle); | |||||
} | |||||
var c_op = c_api.TF_FinishOperation(op_desc, status); | |||||
var c_op = c_api.TF_FinishOperation(op_desc, status); | |||||
status.Check(true); | |||||
status.Check(true); | |||||
return c_op; | |||||
} | |||||
return c_op; | |||||
} | } | ||||
} | } | ||||
@@ -22,7 +22,6 @@ using System.Runtime.InteropServices; | |||||
using System.Threading; | using System.Threading; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.Gradients; | using Tensorflow.Gradients; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -42,11 +41,12 @@ namespace Tensorflow | |||||
public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); | public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); | ||||
public Status status = new Status(); | |||||
public OpDefLibrary _op_def_lib = new OpDefLibrary(); | public OpDefLibrary _op_def_lib = new OpDefLibrary(); | ||||
public Context context = new Context(new ContextOptions(), new Status()); | |||||
public Execute _execute = new Execute(); | public Execute _execute = new Execute(); | ||||
public IEagerRunner Runner = new EagerRunner(); | public IEagerRunner Runner = new EagerRunner(); | ||||
public Context context = new Context(new ContextOptions(), new Status()); | |||||
public tensorflow() | public tensorflow() | ||||
{ | { | ||||
enable_eager_execution(); | enable_eager_execution(); | ||||
@@ -96,14 +96,14 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
public void ZerosConst() | public void ZerosConst() | ||||
{ | { | ||||
// small size | // small size | ||||
var tensor = tf.zeros(new Shape(3, 2), tf.int32, "small"); | |||||
var tensor = tf.zeros((3, 2), tf.int32, "small"); | |||||
Assert.AreEqual(tensor.shape[0], 3); | Assert.AreEqual(tensor.shape[0], 3); | ||||
Assert.AreEqual(tensor.shape[1], 2); | Assert.AreEqual(tensor.shape[1], 2); | ||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, tensor.numpy().ToArray<int>())); | Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, tensor.numpy().ToArray<int>())); | ||||
// big size | // big size | ||||
tensor = tf.zeros(new Shape(200, 100), tf.int32, "big"); | |||||
tensor = tf.zeros((200, 100), tf.int32, "big"); | |||||
Assert.AreEqual(tensor.shape[0], 200); | Assert.AreEqual(tensor.shape[0], 200); | ||||
Assert.AreEqual(tensor.shape[1], 100); | Assert.AreEqual(tensor.shape[1], 100); | ||||
@@ -35,7 +35,26 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
var dz_dx = tape.gradient(z, x); | var dz_dx = tape.gradient(z, x); | ||||
var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; | var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; | ||||
Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.numpy().ToArray<float>(), expected)); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected)); | |||||
} | |||||
[TestMethod] | |||||
public void PersistentTape() | |||||
{ | |||||
var x = tf.ones((2, 2)); | |||||
using var tape = tf.GradientTape(persistent: true); | |||||
tape.watch(x); | |||||
var y = tf.reduce_sum(x); | |||||
var z = tf.multiply(y, y); | |||||
var dz_dx = tape.gradient(z, x); | |||||
var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; | |||||
Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected)); | |||||
var dz_dy = tape.gradient(z, y); | |||||
expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; | |||||
Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray<float>(), expected)); | |||||
} | } | ||||
} | } | ||||
} | } |