@@ -526,8 +526,19 @@ namespace Tensorflow | |||||
var type = data.GetType(); | var type = data.GetType(); | ||||
switch (data) | switch (data) | ||||
{ | { | ||||
case Shape shape: | |||||
case TensorShape: | |||||
case Shape: | |||||
return TF_DataType.TF_INT64; | return TF_DataType.TF_INT64; | ||||
case Axis: | |||||
return TF_DataType.TF_INT32; | |||||
case NDArray nd: | |||||
return nd.dtype; | |||||
case Tensor tensor: | |||||
return tensor.dtype; | |||||
case Tensor[] tensor: | |||||
return tensor[0].dtype; | |||||
case ResourceVariable variable: | |||||
return variable.dtype; | |||||
default: | default: | ||||
return type.as_tf_dtype(); | return type.as_tf_dtype(); | ||||
} | } | ||||
@@ -142,7 +142,7 @@ namespace Tensorflow.Contexts | |||||
bool has_graph_arg = !tf.Context.executing_eagerly(); | bool has_graph_arg = !tf.Context.executing_eagerly(); | ||||
foreach (var el in flatten_args) | foreach (var el in flatten_args) | ||||
{ | { | ||||
if (el is Tensor tensor && !tensor.IsEagerTensor) | |||||
if (el is Tensor tensor && tensor.IsCreatedInGraphMode) | |||||
{ | { | ||||
has_graph_arg = true; | has_graph_arg = true; | ||||
break; | break; | ||||
@@ -50,9 +50,6 @@ namespace Tensorflow.Eager | |||||
public EagerTensor(Shape shape, TF_DataType dtype) : base(shape, dtype) | public EagerTensor(Shape shape, TF_DataType dtype) : base(shape, dtype) | ||||
=> NewEagerTensorHandle(_handle); | => NewEagerTensorHandle(_handle); | ||||
internal unsafe EagerTensor(string value) : base(value) | |||||
=> NewEagerTensorHandle(_handle); | |||||
internal unsafe EagerTensor(Array array, Shape shape) : base(array, shape) | internal unsafe EagerTensor(Array array, Shape shape) : base(array, shape) | ||||
=> NewEagerTensorHandle(_handle); | => NewEagerTensorHandle(_handle); | ||||
@@ -141,7 +141,7 @@ namespace Tensorflow.Functions | |||||
src_graph: _func_graph); | src_graph: _func_graph); | ||||
var captures_from_forward = backwards_graph.external_captures | var captures_from_forward = backwards_graph.external_captures | ||||
.Where(x => !x.IsEagerTensor && x.graph == _func_graph) | |||||
.Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph) | |||||
.ToArray(); | .ToArray(); | ||||
foreach(var capture in captures_from_forward) | foreach(var capture in captures_from_forward) | ||||
{ | { | ||||
@@ -8,20 +8,47 @@ namespace Tensorflow.NumPy | |||||
{ | { | ||||
public partial class NDArray | public partial class NDArray | ||||
{ | { | ||||
public NDArray(bool value) => _tensor = new EagerTensor(value); | |||||
public NDArray(byte value) => _tensor = new EagerTensor(value); | |||||
public NDArray(short value) => _tensor = new EagerTensor(value); | |||||
public NDArray(int value) => _tensor = new EagerTensor(value); | |||||
public NDArray(long value) => _tensor = new EagerTensor(value); | |||||
public NDArray(float value) => _tensor = new EagerTensor(value); | |||||
public NDArray(double value) => _tensor = new EagerTensor(value); | |||||
public NDArray(bool value) => Init(value); | |||||
public NDArray(byte value) => Init(value); | |||||
public NDArray(short value) => Init(value); | |||||
public NDArray(int value) => Init(value); | |||||
public NDArray(long value) => Init(value); | |||||
public NDArray(float value) => Init(value); | |||||
public NDArray(double value) => Init(value); | |||||
public NDArray(Array value, Shape? shape = null) => Init(value, shape); | |||||
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => Init(shape, dtype: dtype); | |||||
public NDArray(Tensor value, Shape? shape = null) => Init(value, shape); | |||||
public NDArray(Array value, Shape? shape = null) => _tensor = new EagerTensor(value, shape); | |||||
public static NDArray Scalar<T>(T value) where T : unmanaged | |||||
=> value switch | |||||
{ | |||||
bool val => new NDArray(val), | |||||
byte val => new NDArray(val), | |||||
int val => new NDArray(val), | |||||
float val => new NDArray(val), | |||||
double val => new NDArray(val), | |||||
_ => throw new NotImplementedException("") | |||||
}; | |||||
void Init<T>(T value) where T : unmanaged | |||||
{ | |||||
_tensor = new EagerTensor(value); | |||||
_tensor.SetReferencedByNDArray(); | |||||
} | |||||
void Init(Array value, Shape? shape = null) | |||||
{ | |||||
_tensor = new EagerTensor(value, shape ?? value.GetShape()); | |||||
_tensor.SetReferencedByNDArray(); | |||||
} | |||||
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | |||||
=> _tensor = new EagerTensor(shape, dtype: dtype); | |||||
void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | |||||
{ | |||||
_tensor = new EagerTensor(shape, dtype: dtype); | |||||
_tensor.SetReferencedByNDArray(); | |||||
} | |||||
public NDArray(Tensor value, Shape? shape = null) | |||||
void Init(Tensor value, Shape? shape = null) | |||||
{ | { | ||||
if (shape is not null) | if (shape is not null) | ||||
_tensor = tf.reshape(value, shape); | _tensor = tf.reshape(value, shape); | ||||
@@ -30,18 +57,8 @@ namespace Tensorflow.NumPy | |||||
if (_tensor.TensorDataPointer == IntPtr.Zero) | if (_tensor.TensorDataPointer == IntPtr.Zero) | ||||
_tensor = tf.get_default_session().eval(_tensor); | _tensor = tf.get_default_session().eval(_tensor); | ||||
} | |||||
public static NDArray Scalar<T>(T value) where T : unmanaged | |||||
{ | |||||
return value switch | |||||
{ | |||||
bool val => new NDArray(val), | |||||
int val => new NDArray(val), | |||||
float val => new NDArray(val), | |||||
double val => new NDArray(val), | |||||
_ => throw new NotImplementedException("") | |||||
}; | |||||
_tensor.SetReferencedByNDArray(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -21,6 +21,7 @@ using System.Linq; | |||||
using System.Numerics; | using System.Numerics; | ||||
using System.Text; | using System.Text; | ||||
using static Tensorflow.c_api; | using static Tensorflow.c_api; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -31,7 +32,7 @@ namespace Tensorflow | |||||
public Tensor() | public Tensor() | ||||
{ | { | ||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -41,60 +42,7 @@ namespace Tensorflow | |||||
public Tensor(IntPtr handle) | public Tensor(IntPtr handle) | ||||
{ | { | ||||
_handle = handle; | _handle = handle; | ||||
//no need to set AllocationType = AllocationType.None; | |||||
#if TRACK_TENSOR_LIFE | |||||
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); | |||||
#endif | |||||
} | |||||
unsafe internal Tensor(Shape shape, TF_DataType dtype) | |||||
=> _handle = TF_NewTensor(shape, dtype, null); | |||||
internal Tensor(Array array, Shape? shape = null) | |||||
=> InitTensor(array, shape); | |||||
unsafe void InitTensor(Array array, Shape? shape = null) | |||||
{ | |||||
shape = shape ?? array.GetShape(); | |||||
var dtype = array.GetType().GetElementType().as_tf_dtype(); | |||||
switch (array) | |||||
{ | |||||
case bool[] val: | |||||
fixed (void* addr = &val[0]) | |||||
_handle = TF_NewTensor(shape, dtype, addr); | |||||
break; | |||||
case int[] val: | |||||
fixed (void* addr = &val[0]) | |||||
_handle = TF_NewTensor(shape, dtype, addr); | |||||
break; | |||||
case int[,] val: | |||||
fixed (void* addr = &val[0, 0]) | |||||
_handle = TF_NewTensor(shape, dtype, addr); | |||||
break; | |||||
case long[] val: | |||||
fixed (void* addr = &val[0]) | |||||
_handle = TF_NewTensor(shape, dtype, addr); | |||||
break; | |||||
case float[] val: | |||||
fixed (void* addr = &val[0]) | |||||
_handle = TF_NewTensor(shape, dtype, addr); | |||||
break; | |||||
case float[,] val: | |||||
fixed (void* addr = &val[0, 0]) | |||||
_handle = TF_NewTensor(shape, dtype, addr); | |||||
break; | |||||
case double[] val: | |||||
fixed (void* addr = &val[0]) | |||||
_handle = TF_NewTensor(shape, dtype, addr); | |||||
break; | |||||
case double[,] val: | |||||
fixed (void* addr = &val[0, 0]) | |||||
_handle = TF_NewTensor(shape, dtype, addr); | |||||
break; | |||||
default: | |||||
throw new NotImplementedException(""); | |||||
} | |||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -109,22 +57,26 @@ namespace Tensorflow | |||||
public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes) | public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes) | ||||
{ | { | ||||
_handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (ulong)num_bytes); | _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (ulong)num_bytes); | ||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | } | ||||
public unsafe Tensor(NDArray nd) | public unsafe Tensor(NDArray nd) | ||||
=> _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); | |||||
{ | |||||
_handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); | |||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | |||||
#region scala | #region scala | ||||
public Tensor(bool value) => _handle = TF_NewTensor(value); | |||||
public Tensor(byte value) => _handle = TF_NewTensor(value); | |||||
public Tensor(sbyte value) => _handle = TF_NewTensor(value); | |||||
public Tensor(short value) => _handle = TF_NewTensor(value); | |||||
public Tensor(int value) => _handle = TF_NewTensor(value); | |||||
public Tensor(uint value) => _handle = TF_NewTensor(value); | |||||
public Tensor(long value) => _handle = TF_NewTensor(value); | |||||
public Tensor(ulong value) => _handle = TF_NewTensor(value); | |||||
public Tensor(float value) => _handle = TF_NewTensor(value); | |||||
public Tensor(double value) => _handle = TF_NewTensor(value); | |||||
public Tensor(bool value) => InitTensor(value); | |||||
public Tensor(byte value) => InitTensor(value); | |||||
public Tensor(sbyte value) => InitTensor(value); | |||||
public Tensor(short value) => InitTensor(value); | |||||
public Tensor(int value) => InitTensor(value); | |||||
public Tensor(uint value) => InitTensor(value); | |||||
public Tensor(long value) => InitTensor(value); | |||||
public Tensor(ulong value) => InitTensor(value); | |||||
public Tensor(float value) => InitTensor(value); | |||||
public Tensor(double value) => InitTensor(value); | |||||
#endregion | #endregion | ||||
#region 1d array | #region 1d array | ||||
@@ -142,31 +94,74 @@ namespace Tensorflow | |||||
public Tensor(Complex[] data, Shape? shape = null) => InitTensor(data, shape); | public Tensor(Complex[] data, Shape? shape = null) => InitTensor(data, shape); | ||||
#endregion | #endregion | ||||
/// <summary> | |||||
/// Create a string Tensor from the given string | |||||
/// </summary> | |||||
public Tensor(string str) | |||||
public Tensor(Operation op, int value_index, TF_DataType dtype) | |||||
{ | |||||
_op = op; | |||||
_value_index = value_index; | |||||
_override_dtype = dtype; | |||||
_id = ops.uid(); | |||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | |||||
internal Tensor(Shape shape, TF_DataType dtype) => InitTensor(shape, dtype); | |||||
internal Tensor(Array array, Shape? shape = null) => InitTensor(array, shape); | |||||
internal Tensor(string value) => InitTensor(value); | |||||
protected unsafe void InitTensor<T>(T data) where T : unmanaged | |||||
{ | { | ||||
_handle = StringTensor(new string[] { str }, TensorShape.Scalar); | |||||
#if TRACK_TENSOR_LIFE | |||||
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); | |||||
#endif | |||||
_handle = TF_NewTensor(data); | |||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | } | ||||
public Tensor(string[] strings) | |||||
protected unsafe void InitTensor(Shape shape, TF_DataType dtype) | |||||
{ | { | ||||
_handle = StringTensor(strings, new TensorShape(strings.Length)); | |||||
#if TRACK_TENSOR_LIFE | |||||
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); | |||||
#endif | |||||
_handle = TF_NewTensor(shape, dtype, null); | |||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | } | ||||
public Tensor(Operation op, int value_index, TF_DataType dtype) | |||||
protected void InitTensor(string value) | |||||
{ | { | ||||
_op = op; | |||||
_value_index = value_index; | |||||
_override_dtype = dtype; | |||||
_id = ops.uid(); | |||||
_handle = StringTensor(new[] { value }, TensorShape.Scalar); | |||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | |||||
protected unsafe void InitTensor(Array array, Shape? shape = null) | |||||
{ | |||||
shape = shape ?? array.GetShape(); | |||||
var dtype = array.GetType().GetElementType().as_tf_dtype(); | |||||
switch (array) | |||||
{ | |||||
case bool[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case bool[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case bool[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case bool[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case byte[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case byte[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case byte[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case byte[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case int[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case int[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case int[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case int[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case long[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case long[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case long[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case long[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case float[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case float[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case float[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case float[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case double[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case double[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case double[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case double[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
case string[] val: _handle = StringTensor(val, shape); break; | |||||
default: | |||||
throw new NotImplementedException(""); | |||||
} | |||||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -23,7 +23,7 @@ namespace Tensorflow | |||||
public IntPtr StringTensor(byte[][] buffer, TensorShape shape) | public IntPtr StringTensor(byte[][] buffer, TensorShape shape) | ||||
{ | { | ||||
var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, | var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, | ||||
shape.ndim == 0 ? null : shape.dims.Select(x => (long)x).ToArray(), | |||||
shape.ndim == 0 ? null : shape.dims, | |||||
shape.ndim, | shape.ndim, | ||||
(ulong)shape.size * TF_TSRING_SIZE); | (ulong)shape.size * TF_TSRING_SIZE); | ||||
@@ -93,9 +93,13 @@ namespace Tensorflow | |||||
/// TFE_TensorHandle | /// TFE_TensorHandle | ||||
/// </summary> | /// </summary> | ||||
public SafeTensorHandleHandle EagerTensorHandle { get; set; } | public SafeTensorHandleHandle EagerTensorHandle { get; set; } | ||||
protected bool _createdInGraphMode; | |||||
public bool CreatedInGraphMode => _createdInGraphMode; | |||||
public bool IsEagerTensor => this is EagerTensor; | |||||
protected bool isReferencedByNDArray; | |||||
public bool IsReferencedByNDArray => isReferencedByNDArray; | |||||
protected bool isCreatedInGraphMode; | |||||
public bool IsCreatedInGraphMode => isCreatedInGraphMode; | |||||
public bool IsSparseTensor => this is SparseTensor; | public bool IsSparseTensor => this is SparseTensor; | ||||
/// <summary> | /// <summary> | ||||
@@ -207,6 +211,8 @@ namespace Tensorflow | |||||
return _tf_output.Value; | return _tf_output.Value; | ||||
} | } | ||||
public void SetReferencedByNDArray() => isReferencedByNDArray = true; | |||||
public Tensor MaybeMove() | public Tensor MaybeMove() | ||||
{ | { | ||||
var tensor = c_api.TF_TensorMaybeMove(_handle); | var tensor = c_api.TF_TensorMaybeMove(_handle); | ||||
@@ -1,4 +1,5 @@ | |||||
using Tensorflow.NumPy; | |||||
using System.Linq; | |||||
using Tensorflow.NumPy; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -13,7 +14,7 @@ namespace Tensorflow | |||||
public static implicit operator TensorShape(Shape shape) => new TensorShape((long[])shape.dims.Clone()); | public static implicit operator TensorShape(Shape shape) => new TensorShape((long[])shape.dims.Clone()); | ||||
public static implicit operator Shape(TensorShape shape) => shape == null ? null : new Shape((long[])shape.dims.Clone()); | public static implicit operator Shape(TensorShape shape) => shape == null ? null : new Shape((long[])shape.dims.Clone()); | ||||
public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes | |||||
public static implicit operator int[](TensorShape shape) => shape == null ? null : shape.dims.Select(x => (int)x).ToArray(); //we clone to avoid any changes | |||||
public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims); | public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims); | ||||
public static implicit operator long[](TensorShape shape) => shape == null ? null : (long[])shape.dims.Clone(); //we clone to avoid any changes | public static implicit operator long[](TensorShape shape) => shape == null ? null : (long[])shape.dims.Clone(); //we clone to avoid any changes | ||||
@@ -21,7 +21,7 @@ namespace Tensorflow | |||||
public TensorShape shape => items.First().TensorShape; | public TensorShape shape => items.First().TensorShape; | ||||
public int rank => items.First().rank; | public int rank => items.First().rank; | ||||
public Graph graph => items.First().graph; | public Graph graph => items.First().graph; | ||||
public bool IsEagerTensor => items.First().IsEagerTensor; | |||||
public bool IsCreatedInGraphMode => items.First().IsCreatedInGraphMode; | |||||
public bool IsList { get; set; } | public bool IsList { get; set; } | ||||
public int Length => items.Count(); | public int Length => items.Count(); | ||||
@@ -98,7 +98,6 @@ namespace Tensorflow | |||||
attrs: attrs, | attrs: attrs, | ||||
name: name); | name: name); | ||||
var o = op.outputs; | |||||
return op.outputs[0]; | return op.outputs[0]; | ||||
} | } | ||||
@@ -167,9 +166,9 @@ namespace Tensorflow | |||||
case TensorShape val: | case TensorShape val: | ||||
return new EagerTensor(val.dims, ctx.DeviceName); | return new EagerTensor(val.dims, ctx.DeviceName); | ||||
case string val: | case string val: | ||||
return new EagerTensor(val); | |||||
return new EagerTensor(new[] { val }, Shape.Scalar); | |||||
case string[] val: | case string[] val: | ||||
return new EagerTensor(val, ctx.DeviceName); | |||||
return new EagerTensor(val, new Shape(val.Length)); | |||||
case bool val: | case bool val: | ||||
return new EagerTensor(new[] { val }, Shape.Scalar); | return new EagerTensor(new[] { val }, Shape.Scalar); | ||||
case byte val: | case byte val: | ||||
@@ -75,7 +75,7 @@ namespace Tensorflow | |||||
case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | ||||
return typeof(Complex); | return typeof(Complex); | ||||
default: | default: | ||||
return null; | |||||
throw new NotSupportedException($"Unable to convert {type} to a system data type."); | |||||
} | } | ||||
} | } | ||||
@@ -83,24 +83,25 @@ namespace Tensorflow | |||||
/// | /// | ||||
/// </summary> | /// </summary> | ||||
/// <param name="type"></param> | /// <param name="type"></param> | ||||
/// <param name="dtype"></param> | |||||
/// <returns></returns> | /// <returns></returns> | ||||
/// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception> | /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception> | ||||
public static TF_DataType as_tf_dtype(this Type type, TF_DataType? dtype = null) | |||||
public static TF_DataType as_tf_dtype(this Type type) | |||||
{ | { | ||||
while (type.IsArray) | while (type.IsArray) | ||||
type = type.GetElementType(); | type = type.GetElementType(); | ||||
TF_DataType dtype = TF_DataType.DtInvalid; | |||||
switch (type.Name) | switch (type.Name) | ||||
{ | { | ||||
case "Char": | case "Char": | ||||
dtype = dtype ?? TF_DataType.TF_UINT8; | |||||
dtype = TF_DataType.TF_UINT8; | |||||
break; | break; | ||||
case "SByte": | case "SByte": | ||||
dtype = TF_DataType.TF_INT8; | dtype = TF_DataType.TF_INT8; | ||||
break; | break; | ||||
case "Byte": | case "Byte": | ||||
dtype = dtype ?? TF_DataType.TF_UINT8; | |||||
dtype = TF_DataType.TF_UINT8; | |||||
break; | break; | ||||
case "Int16": | case "Int16": | ||||
dtype = TF_DataType.TF_INT16; | dtype = TF_DataType.TF_INT16; | ||||
@@ -136,60 +137,32 @@ namespace Tensorflow | |||||
dtype = TF_DataType.TF_BOOL; | dtype = TF_DataType.TF_BOOL; | ||||
break; | break; | ||||
default: | default: | ||||
throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||||
throw new NotSupportedException($"Unable to convert {type} to a TensorFlow data type."); | |||||
} | } | ||||
return dtype.Value; | |||||
return dtype; | |||||
} | } | ||||
public static TF_DataType tf_dtype_from_name(string name) | public static TF_DataType tf_dtype_from_name(string name) | ||||
{ | { | ||||
TF_DataType dtype = TF_DataType.DtInvalid; | |||||
switch (name.ToLower()) | |||||
TF_DataType dtype = name.ToLower() switch | |||||
{ | { | ||||
case "char": | |||||
dtype = TF_DataType.TF_UINT8; | |||||
break; | |||||
case "boolean": | |||||
dtype = TF_DataType.TF_BOOL; | |||||
break; | |||||
case "sbyte": | |||||
dtype = TF_DataType.TF_INT8; | |||||
break; | |||||
case "byte": | |||||
dtype = TF_DataType.TF_UINT8; | |||||
break; | |||||
case "int16": | |||||
dtype = TF_DataType.TF_INT16; | |||||
break; | |||||
case "uint16": | |||||
dtype = TF_DataType.TF_UINT16; | |||||
break; | |||||
case "int32": | |||||
dtype = TF_DataType.TF_INT32; | |||||
break; | |||||
case "uint32": | |||||
dtype = TF_DataType.TF_UINT32; | |||||
break; | |||||
case "int64": | |||||
dtype = TF_DataType.TF_INT64; | |||||
break; | |||||
case "uint64": | |||||
dtype = TF_DataType.TF_UINT64; | |||||
break; | |||||
case "single": | |||||
dtype = TF_DataType.TF_FLOAT; | |||||
break; | |||||
case "double": | |||||
dtype = TF_DataType.TF_DOUBLE; | |||||
break; | |||||
case "complex": | |||||
dtype = TF_DataType.TF_COMPLEX128; | |||||
break; | |||||
case "string": | |||||
dtype = TF_DataType.TF_STRING; | |||||
break; | |||||
} | |||||
"char" => TF_DataType.TF_UINT8, | |||||
"boolean" => TF_DataType.TF_BOOL, | |||||
"sbyte" => TF_DataType.TF_INT8, | |||||
"byte" => TF_DataType.TF_UINT8, | |||||
"int16" => TF_DataType.TF_INT16, | |||||
"uint16" => TF_DataType.TF_UINT16, | |||||
"int32" => TF_DataType.TF_INT32, | |||||
"uint32" => TF_DataType.TF_UINT32, | |||||
"int64" => TF_DataType.TF_INT64, | |||||
"uint64" => TF_DataType.TF_UINT64, | |||||
"single" => TF_DataType.TF_FLOAT, | |||||
"double" => TF_DataType.TF_DOUBLE, | |||||
"complex" => TF_DataType.TF_COMPLEX128, | |||||
"string" => TF_DataType.TF_STRING, | |||||
_ => TF_DataType.DtInvalid | |||||
}; | |||||
return dtype; | return dtype; | ||||
} | } | ||||
@@ -108,7 +108,7 @@ namespace Tensorflow | |||||
if (values is TensorProto tp) | if (values is TensorProto tp) | ||||
return tp; | return tp; | ||||
dtype = values.GetType().as_tf_dtype(); | |||||
dtype = values.GetDataType(); | |||||
shape = shape ?? values.GetShape(); | shape = shape ?? values.GetShape(); | ||||
var tensor_proto = new TensorProto | var tensor_proto = new TensorProto | ||||
{ | { | ||||
@@ -117,7 +117,13 @@ namespace Tensorflow | |||||
}; | }; | ||||
// scalar | // scalar | ||||
if (!values.GetType().IsArray) | |||||
if (values is NDArray nd) | |||||
{ | |||||
var len = nd.dtypesize * nd.size; | |||||
byte[] bytes = nd.ToByteArray(); | |||||
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); | |||||
} | |||||
else if (!values.GetType().IsArray) | |||||
{ | { | ||||
switch (values) | switch (values) | ||||
{ | { | ||||
@@ -154,7 +160,7 @@ namespace Tensorflow | |||||
else if (values is byte[] byte_values) | else if (values is byte[] byte_values) | ||||
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values); | tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values); | ||||
} | } | ||||
else if(values is Array array) | |||||
else if (values is Array array) | |||||
{ | { | ||||
// array | // array | ||||
var len = dtype.get_datatype_size() * (int)shape.size; | var len = dtype.get_datatype_size() * (int)shape.size; | ||||
@@ -68,7 +68,7 @@ namespace Tensorflow | |||||
// when this object is garbage collected the deleter will be too. This | // when this object is garbage collected the deleter will be too. This | ||||
// means ResourceVariables can be part of reference cycles without those | // means ResourceVariables can be part of reference cycles without those | ||||
// cycles being uncollectable. | // cycles being uncollectable. | ||||
if (handle.IsEagerTensor) | |||||
if (!handle.IsCreatedInGraphMode) | |||||
{ | { | ||||
_handle = handle.EagerTensorHandle.DangerousGetHandle(); | _handle = handle.EagerTensorHandle.DangerousGetHandle(); | ||||
eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); | eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); | ||||
@@ -123,7 +123,7 @@ namespace Tensorflow | |||||
if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
dtype = preferred_dtype; | dtype = preferred_dtype; | ||||
if (value is EagerTensor eager_tensor) | |||||
if (value is EagerTensor eager_tensor && !eager_tensor.IsCreatedInGraphMode) | |||||
{ | { | ||||
if (tf.executing_eagerly()) | if (tf.executing_eagerly()) | ||||
{ | { | ||||
@@ -140,7 +140,13 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
else if (value is NDArray nd) | else if (value is NDArray nd) | ||||
{ | |||||
return nd; | return nd; | ||||
} | |||||
else if (value is Tensor tensor && tensor.IsReferencedByNDArray) | |||||
{ | |||||
return tensor; | |||||
} | |||||
// graph mode | // graph mode | ||||
Tensor ret = value switch | Tensor ret = value switch | ||||
@@ -115,7 +115,7 @@ namespace Tensorflow.Keras.Engine | |||||
bool _in_functional_construction_mode(Tensors inputs) | bool _in_functional_construction_mode(Tensors inputs) | ||||
{ | { | ||||
return tf.Context.executing_eagerly() | return tf.Context.executing_eagerly() | ||||
&& inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); | |||||
&& inputs.Count(x => x.IsCreatedInGraphMode) == inputs.Count(); | |||||
} | } | ||||
public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) | public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) | ||||
@@ -177,7 +177,7 @@ namespace Tensorflow.Keras.Engine | |||||
tf.init_scope(); | tf.init_scope(); | ||||
bool need_restore_mode = false; | bool need_restore_mode = false; | ||||
if (inputs.IsEagerTensor || tf.Context.is_build_function()) | |||||
if (!inputs.IsCreatedInGraphMode || tf.Context.is_build_function()) | |||||
{ | { | ||||
need_restore_mode = true; | need_restore_mode = true; | ||||
tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); | tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); | ||||
@@ -148,10 +148,10 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
{ | { | ||||
var dataset = tf.data.Dataset.range(10); | var dataset = tf.data.Dataset.range(10); | ||||
var cardinality = dataset.cardinality(); | var cardinality = dataset.cardinality(); | ||||
Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||||
Assert.AreEqual(cardinality.numpy(), 10L); | |||||
dataset = dataset.map(x => x[0] + 1); | dataset = dataset.map(x => x[0] + 1); | ||||
cardinality = dataset.cardinality(); | cardinality = dataset.cardinality(); | ||||
Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||||
Assert.AreEqual(cardinality.numpy(), 10L); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -160,7 +160,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
var dataset = tf.data.Dataset.range(10); | var dataset = tf.data.Dataset.range(10); | ||||
dataset = dataset.map(x => x, num_parallel_calls: -1); | dataset = dataset.map(x => x, num_parallel_calls: -1); | ||||
var cardinality = dataset.cardinality(); | var cardinality = dataset.cardinality(); | ||||
Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||||
Assert.AreEqual(cardinality.numpy(), 10L); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -7,7 +7,7 @@ namespace TensorFlowNET.UnitTest | |||||
[TestClass] | [TestClass] | ||||
public class MnistModelLoaderTest | public class MnistModelLoaderTest | ||||
{ | { | ||||
[TestMethod] | |||||
[TestMethod, Ignore] | |||||
public async Task TestLoad() | public async Task TestLoad() | ||||
{ | { | ||||
var loader = new MnistModelLoader(); | var loader = new MnistModelLoader(); | ||||