@@ -526,8 +526,19 @@ namespace Tensorflow | |||
var type = data.GetType(); | |||
switch (data) | |||
{ | |||
case Shape shape: | |||
case TensorShape: | |||
case Shape: | |||
return TF_DataType.TF_INT64; | |||
case Axis: | |||
return TF_DataType.TF_INT32; | |||
case NDArray nd: | |||
return nd.dtype; | |||
case Tensor tensor: | |||
return tensor.dtype; | |||
case Tensor[] tensor: | |||
return tensor[0].dtype; | |||
case ResourceVariable variable: | |||
return variable.dtype; | |||
default: | |||
return type.as_tf_dtype(); | |||
} | |||
@@ -142,7 +142,7 @@ namespace Tensorflow.Contexts | |||
bool has_graph_arg = !tf.Context.executing_eagerly(); | |||
foreach (var el in flatten_args) | |||
{ | |||
if (el is Tensor tensor && !tensor.IsEagerTensor) | |||
if (el is Tensor tensor && tensor.IsCreatedInGraphMode) | |||
{ | |||
has_graph_arg = true; | |||
break; | |||
@@ -50,9 +50,6 @@ namespace Tensorflow.Eager | |||
public EagerTensor(Shape shape, TF_DataType dtype) : base(shape, dtype) | |||
=> NewEagerTensorHandle(_handle); | |||
internal unsafe EagerTensor(string value) : base(value) | |||
=> NewEagerTensorHandle(_handle); | |||
internal unsafe EagerTensor(Array array, Shape shape) : base(array, shape) | |||
=> NewEagerTensorHandle(_handle); | |||
@@ -141,7 +141,7 @@ namespace Tensorflow.Functions | |||
src_graph: _func_graph); | |||
var captures_from_forward = backwards_graph.external_captures | |||
.Where(x => !x.IsEagerTensor && x.graph == _func_graph) | |||
.Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph) | |||
.ToArray(); | |||
foreach(var capture in captures_from_forward) | |||
{ | |||
@@ -8,20 +8,47 @@ namespace Tensorflow.NumPy | |||
{ | |||
public partial class NDArray | |||
{ | |||
public NDArray(bool value) => _tensor = new EagerTensor(value); | |||
public NDArray(byte value) => _tensor = new EagerTensor(value); | |||
public NDArray(short value) => _tensor = new EagerTensor(value); | |||
public NDArray(int value) => _tensor = new EagerTensor(value); | |||
public NDArray(long value) => _tensor = new EagerTensor(value); | |||
public NDArray(float value) => _tensor = new EagerTensor(value); | |||
public NDArray(double value) => _tensor = new EagerTensor(value); | |||
public NDArray(bool value) => Init(value); | |||
public NDArray(byte value) => Init(value); | |||
public NDArray(short value) => Init(value); | |||
public NDArray(int value) => Init(value); | |||
public NDArray(long value) => Init(value); | |||
public NDArray(float value) => Init(value); | |||
public NDArray(double value) => Init(value); | |||
public NDArray(Array value, Shape? shape = null) => Init(value, shape); | |||
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => Init(shape, dtype: dtype); | |||
public NDArray(Tensor value, Shape? shape = null) => Init(value, shape); | |||
public NDArray(Array value, Shape? shape = null) => _tensor = new EagerTensor(value, shape); | |||
public static NDArray Scalar<T>(T value) where T : unmanaged | |||
=> value switch | |||
{ | |||
bool val => new NDArray(val), | |||
byte val => new NDArray(val), | |||
int val => new NDArray(val), | |||
float val => new NDArray(val), | |||
double val => new NDArray(val), | |||
_ => throw new NotImplementedException("") | |||
}; | |||
void Init<T>(T value) where T : unmanaged | |||
{ | |||
_tensor = new EagerTensor(value); | |||
_tensor.SetReferencedByNDArray(); | |||
} | |||
void Init(Array value, Shape? shape = null) | |||
{ | |||
_tensor = new EagerTensor(value, shape ?? value.GetShape()); | |||
_tensor.SetReferencedByNDArray(); | |||
} | |||
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | |||
=> _tensor = new EagerTensor(shape, dtype: dtype); | |||
void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | |||
{ | |||
_tensor = new EagerTensor(shape, dtype: dtype); | |||
_tensor.SetReferencedByNDArray(); | |||
} | |||
public NDArray(Tensor value, Shape? shape = null) | |||
void Init(Tensor value, Shape? shape = null) | |||
{ | |||
if (shape is not null) | |||
_tensor = tf.reshape(value, shape); | |||
@@ -30,18 +57,8 @@ namespace Tensorflow.NumPy | |||
if (_tensor.TensorDataPointer == IntPtr.Zero) | |||
_tensor = tf.get_default_session().eval(_tensor); | |||
} | |||
public static NDArray Scalar<T>(T value) where T : unmanaged | |||
{ | |||
return value switch | |||
{ | |||
bool val => new NDArray(val), | |||
int val => new NDArray(val), | |||
float val => new NDArray(val), | |||
double val => new NDArray(val), | |||
_ => throw new NotImplementedException("") | |||
}; | |||
_tensor.SetReferencedByNDArray(); | |||
} | |||
} | |||
} |
@@ -21,6 +21,7 @@ using System.Linq; | |||
using System.Numerics; | |||
using System.Text; | |||
using static Tensorflow.c_api; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
@@ -31,7 +32,7 @@ namespace Tensorflow | |||
public Tensor() | |||
{ | |||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||
} | |||
/// <summary> | |||
@@ -41,60 +42,7 @@ namespace Tensorflow | |||
public Tensor(IntPtr handle) | |||
{ | |||
_handle = handle; | |||
//no need to set AllocationType = AllocationType.None; | |||
#if TRACK_TENSOR_LIFE | |||
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); | |||
#endif | |||
} | |||
unsafe internal Tensor(Shape shape, TF_DataType dtype) | |||
=> _handle = TF_NewTensor(shape, dtype, null); | |||
internal Tensor(Array array, Shape? shape = null) | |||
=> InitTensor(array, shape); | |||
unsafe void InitTensor(Array array, Shape? shape = null) | |||
{ | |||
shape = shape ?? array.GetShape(); | |||
var dtype = array.GetType().GetElementType().as_tf_dtype(); | |||
switch (array) | |||
{ | |||
case bool[] val: | |||
fixed (void* addr = &val[0]) | |||
_handle = TF_NewTensor(shape, dtype, addr); | |||
break; | |||
case int[] val: | |||
fixed (void* addr = &val[0]) | |||
_handle = TF_NewTensor(shape, dtype, addr); | |||
break; | |||
case int[,] val: | |||
fixed (void* addr = &val[0, 0]) | |||
_handle = TF_NewTensor(shape, dtype, addr); | |||
break; | |||
case long[] val: | |||
fixed (void* addr = &val[0]) | |||
_handle = TF_NewTensor(shape, dtype, addr); | |||
break; | |||
case float[] val: | |||
fixed (void* addr = &val[0]) | |||
_handle = TF_NewTensor(shape, dtype, addr); | |||
break; | |||
case float[,] val: | |||
fixed (void* addr = &val[0, 0]) | |||
_handle = TF_NewTensor(shape, dtype, addr); | |||
break; | |||
case double[] val: | |||
fixed (void* addr = &val[0]) | |||
_handle = TF_NewTensor(shape, dtype, addr); | |||
break; | |||
case double[,] val: | |||
fixed (void* addr = &val[0, 0]) | |||
_handle = TF_NewTensor(shape, dtype, addr); | |||
break; | |||
default: | |||
throw new NotImplementedException(""); | |||
} | |||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||
} | |||
/// <summary> | |||
@@ -109,22 +57,26 @@ namespace Tensorflow | |||
public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes) | |||
{ | |||
_handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (ulong)num_bytes); | |||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||
} | |||
public unsafe Tensor(NDArray nd) | |||
=> _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); | |||
{ | |||
_handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); | |||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||
} | |||
#region scala | |||
public Tensor(bool value) => _handle = TF_NewTensor(value); | |||
public Tensor(byte value) => _handle = TF_NewTensor(value); | |||
public Tensor(sbyte value) => _handle = TF_NewTensor(value); | |||
public Tensor(short value) => _handle = TF_NewTensor(value); | |||
public Tensor(int value) => _handle = TF_NewTensor(value); | |||
public Tensor(uint value) => _handle = TF_NewTensor(value); | |||
public Tensor(long value) => _handle = TF_NewTensor(value); | |||
public Tensor(ulong value) => _handle = TF_NewTensor(value); | |||
public Tensor(float value) => _handle = TF_NewTensor(value); | |||
public Tensor(double value) => _handle = TF_NewTensor(value); | |||
public Tensor(bool value) => InitTensor(value); | |||
public Tensor(byte value) => InitTensor(value); | |||
public Tensor(sbyte value) => InitTensor(value); | |||
public Tensor(short value) => InitTensor(value); | |||
public Tensor(int value) => InitTensor(value); | |||
public Tensor(uint value) => InitTensor(value); | |||
public Tensor(long value) => InitTensor(value); | |||
public Tensor(ulong value) => InitTensor(value); | |||
public Tensor(float value) => InitTensor(value); | |||
public Tensor(double value) => InitTensor(value); | |||
#endregion | |||
#region 1d array | |||
@@ -142,31 +94,74 @@ namespace Tensorflow | |||
public Tensor(Complex[] data, Shape? shape = null) => InitTensor(data, shape); | |||
#endregion | |||
/// <summary> | |||
/// Create a string Tensor from the given string | |||
/// </summary> | |||
public Tensor(string str) | |||
public Tensor(Operation op, int value_index, TF_DataType dtype) | |||
{ | |||
_op = op; | |||
_value_index = value_index; | |||
_override_dtype = dtype; | |||
_id = ops.uid(); | |||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||
} | |||
internal Tensor(Shape shape, TF_DataType dtype) => InitTensor(shape, dtype); | |||
internal Tensor(Array array, Shape? shape = null) => InitTensor(array, shape); | |||
internal Tensor(string value) => InitTensor(value); | |||
protected unsafe void InitTensor<T>(T data) where T : unmanaged | |||
{ | |||
_handle = StringTensor(new string[] { str }, TensorShape.Scalar); | |||
#if TRACK_TENSOR_LIFE | |||
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); | |||
#endif | |||
_handle = TF_NewTensor(data); | |||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||
} | |||
public Tensor(string[] strings) | |||
protected unsafe void InitTensor(Shape shape, TF_DataType dtype) | |||
{ | |||
_handle = StringTensor(strings, new TensorShape(strings.Length)); | |||
#if TRACK_TENSOR_LIFE | |||
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); | |||
#endif | |||
_handle = TF_NewTensor(shape, dtype, null); | |||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||
} | |||
public Tensor(Operation op, int value_index, TF_DataType dtype) | |||
protected void InitTensor(string value) | |||
{ | |||
_op = op; | |||
_value_index = value_index; | |||
_override_dtype = dtype; | |||
_id = ops.uid(); | |||
_handle = StringTensor(new[] { value }, TensorShape.Scalar); | |||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||
} | |||
protected unsafe void InitTensor(Array array, Shape? shape = null) | |||
{ | |||
shape = shape ?? array.GetShape(); | |||
var dtype = array.GetType().GetElementType().as_tf_dtype(); | |||
switch (array) | |||
{ | |||
case bool[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case bool[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case bool[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case bool[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case byte[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case byte[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case byte[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case byte[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case int[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case int[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case int[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case int[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case long[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case long[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case long[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case long[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case float[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case float[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case float[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case float[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case double[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case double[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case double[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case double[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
case string[] val: _handle = StringTensor(val, shape); break; | |||
default: | |||
throw new NotImplementedException(""); | |||
} | |||
isCreatedInGraphMode = !tf.executing_eagerly(); | |||
} | |||
} | |||
} |
@@ -23,7 +23,7 @@ namespace Tensorflow | |||
public IntPtr StringTensor(byte[][] buffer, TensorShape shape) | |||
{ | |||
var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, | |||
shape.ndim == 0 ? null : shape.dims.Select(x => (long)x).ToArray(), | |||
shape.ndim == 0 ? null : shape.dims, | |||
shape.ndim, | |||
(ulong)shape.size * TF_TSRING_SIZE); | |||
@@ -93,9 +93,13 @@ namespace Tensorflow | |||
/// TFE_TensorHandle | |||
/// </summary> | |||
public SafeTensorHandleHandle EagerTensorHandle { get; set; } | |||
protected bool _createdInGraphMode; | |||
public bool CreatedInGraphMode => _createdInGraphMode; | |||
public bool IsEagerTensor => this is EagerTensor; | |||
protected bool isReferencedByNDArray; | |||
public bool IsReferencedByNDArray => isReferencedByNDArray; | |||
protected bool isCreatedInGraphMode; | |||
public bool IsCreatedInGraphMode => isCreatedInGraphMode; | |||
public bool IsSparseTensor => this is SparseTensor; | |||
/// <summary> | |||
@@ -207,6 +211,8 @@ namespace Tensorflow | |||
return _tf_output.Value; | |||
} | |||
public void SetReferencedByNDArray() => isReferencedByNDArray = true; | |||
public Tensor MaybeMove() | |||
{ | |||
var tensor = c_api.TF_TensorMaybeMove(_handle); | |||
@@ -1,4 +1,5 @@ | |||
using Tensorflow.NumPy; | |||
using System.Linq; | |||
using Tensorflow.NumPy; | |||
namespace Tensorflow | |||
{ | |||
@@ -13,7 +14,7 @@ namespace Tensorflow | |||
public static implicit operator TensorShape(Shape shape) => new TensorShape((long[])shape.dims.Clone()); | |||
public static implicit operator Shape(TensorShape shape) => shape == null ? null : new Shape((long[])shape.dims.Clone()); | |||
public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes | |||
public static implicit operator int[](TensorShape shape) => shape == null ? null : shape.dims.Select(x => (int)x).ToArray(); //we clone to avoid any changes | |||
public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims); | |||
public static implicit operator long[](TensorShape shape) => shape == null ? null : (long[])shape.dims.Clone(); //we clone to avoid any changes | |||
@@ -21,7 +21,7 @@ namespace Tensorflow | |||
public TensorShape shape => items.First().TensorShape; | |||
public int rank => items.First().rank; | |||
public Graph graph => items.First().graph; | |||
public bool IsEagerTensor => items.First().IsEagerTensor; | |||
public bool IsCreatedInGraphMode => items.First().IsCreatedInGraphMode; | |||
public bool IsList { get; set; } | |||
public int Length => items.Count(); | |||
@@ -98,7 +98,6 @@ namespace Tensorflow | |||
attrs: attrs, | |||
name: name); | |||
var o = op.outputs; | |||
return op.outputs[0]; | |||
} | |||
@@ -167,9 +166,9 @@ namespace Tensorflow | |||
case TensorShape val: | |||
return new EagerTensor(val.dims, ctx.DeviceName); | |||
case string val: | |||
return new EagerTensor(val); | |||
return new EagerTensor(new[] { val }, Shape.Scalar); | |||
case string[] val: | |||
return new EagerTensor(val, ctx.DeviceName); | |||
return new EagerTensor(val, new Shape(val.Length)); | |||
case bool val: | |||
return new EagerTensor(new[] { val }, Shape.Scalar); | |||
case byte val: | |||
@@ -75,7 +75,7 @@ namespace Tensorflow | |||
case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | |||
return typeof(Complex); | |||
default: | |||
return null; | |||
throw new NotSupportedException($"Unable to convert {type} to a system data type."); | |||
} | |||
} | |||
@@ -83,24 +83,25 @@ namespace Tensorflow | |||
/// | |||
/// </summary> | |||
/// <param name="type"></param> | |||
/// <param name="dtype"></param> | |||
/// <returns></returns> | |||
/// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception> | |||
public static TF_DataType as_tf_dtype(this Type type, TF_DataType? dtype = null) | |||
public static TF_DataType as_tf_dtype(this Type type) | |||
{ | |||
while (type.IsArray) | |||
type = type.GetElementType(); | |||
TF_DataType dtype = TF_DataType.DtInvalid; | |||
switch (type.Name) | |||
{ | |||
case "Char": | |||
dtype = dtype ?? TF_DataType.TF_UINT8; | |||
dtype = TF_DataType.TF_UINT8; | |||
break; | |||
case "SByte": | |||
dtype = TF_DataType.TF_INT8; | |||
break; | |||
case "Byte": | |||
dtype = dtype ?? TF_DataType.TF_UINT8; | |||
dtype = TF_DataType.TF_UINT8; | |||
break; | |||
case "Int16": | |||
dtype = TF_DataType.TF_INT16; | |||
@@ -136,60 +137,32 @@ namespace Tensorflow | |||
dtype = TF_DataType.TF_BOOL; | |||
break; | |||
default: | |||
throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||
throw new NotSupportedException($"Unable to convert {type} to a TensorFlow data type."); | |||
} | |||
return dtype.Value; | |||
return dtype; | |||
} | |||
public static TF_DataType tf_dtype_from_name(string name) | |||
{ | |||
TF_DataType dtype = TF_DataType.DtInvalid; | |||
switch (name.ToLower()) | |||
TF_DataType dtype = name.ToLower() switch | |||
{ | |||
case "char": | |||
dtype = TF_DataType.TF_UINT8; | |||
break; | |||
case "boolean": | |||
dtype = TF_DataType.TF_BOOL; | |||
break; | |||
case "sbyte": | |||
dtype = TF_DataType.TF_INT8; | |||
break; | |||
case "byte": | |||
dtype = TF_DataType.TF_UINT8; | |||
break; | |||
case "int16": | |||
dtype = TF_DataType.TF_INT16; | |||
break; | |||
case "uint16": | |||
dtype = TF_DataType.TF_UINT16; | |||
break; | |||
case "int32": | |||
dtype = TF_DataType.TF_INT32; | |||
break; | |||
case "uint32": | |||
dtype = TF_DataType.TF_UINT32; | |||
break; | |||
case "int64": | |||
dtype = TF_DataType.TF_INT64; | |||
break; | |||
case "uint64": | |||
dtype = TF_DataType.TF_UINT64; | |||
break; | |||
case "single": | |||
dtype = TF_DataType.TF_FLOAT; | |||
break; | |||
case "double": | |||
dtype = TF_DataType.TF_DOUBLE; | |||
break; | |||
case "complex": | |||
dtype = TF_DataType.TF_COMPLEX128; | |||
break; | |||
case "string": | |||
dtype = TF_DataType.TF_STRING; | |||
break; | |||
} | |||
"char" => TF_DataType.TF_UINT8, | |||
"boolean" => TF_DataType.TF_BOOL, | |||
"sbyte" => TF_DataType.TF_INT8, | |||
"byte" => TF_DataType.TF_UINT8, | |||
"int16" => TF_DataType.TF_INT16, | |||
"uint16" => TF_DataType.TF_UINT16, | |||
"int32" => TF_DataType.TF_INT32, | |||
"uint32" => TF_DataType.TF_UINT32, | |||
"int64" => TF_DataType.TF_INT64, | |||
"uint64" => TF_DataType.TF_UINT64, | |||
"single" => TF_DataType.TF_FLOAT, | |||
"double" => TF_DataType.TF_DOUBLE, | |||
"complex" => TF_DataType.TF_COMPLEX128, | |||
"string" => TF_DataType.TF_STRING, | |||
_ => TF_DataType.DtInvalid | |||
}; | |||
return dtype; | |||
} | |||
@@ -108,7 +108,7 @@ namespace Tensorflow | |||
if (values is TensorProto tp) | |||
return tp; | |||
dtype = values.GetType().as_tf_dtype(); | |||
dtype = values.GetDataType(); | |||
shape = shape ?? values.GetShape(); | |||
var tensor_proto = new TensorProto | |||
{ | |||
@@ -117,7 +117,13 @@ namespace Tensorflow | |||
}; | |||
// scalar | |||
if (!values.GetType().IsArray) | |||
if (values is NDArray nd) | |||
{ | |||
var len = nd.dtypesize * nd.size; | |||
byte[] bytes = nd.ToByteArray(); | |||
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); | |||
} | |||
else if (!values.GetType().IsArray) | |||
{ | |||
switch (values) | |||
{ | |||
@@ -154,7 +160,7 @@ namespace Tensorflow | |||
else if (values is byte[] byte_values) | |||
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values); | |||
} | |||
else if(values is Array array) | |||
else if (values is Array array) | |||
{ | |||
// array | |||
var len = dtype.get_datatype_size() * (int)shape.size; | |||
@@ -68,7 +68,7 @@ namespace Tensorflow | |||
// when this object is garbage collected the deleter will be too. This | |||
// means ResourceVariables can be part of reference cycles without those | |||
// cycles being uncollectable. | |||
if (handle.IsEagerTensor) | |||
if (!handle.IsCreatedInGraphMode) | |||
{ | |||
_handle = handle.EagerTensorHandle.DangerousGetHandle(); | |||
eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); | |||
@@ -123,7 +123,7 @@ namespace Tensorflow | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = preferred_dtype; | |||
if (value is EagerTensor eager_tensor) | |||
if (value is EagerTensor eager_tensor && !eager_tensor.IsCreatedInGraphMode) | |||
{ | |||
if (tf.executing_eagerly()) | |||
{ | |||
@@ -140,7 +140,13 @@ namespace Tensorflow | |||
} | |||
} | |||
else if (value is NDArray nd) | |||
{ | |||
return nd; | |||
} | |||
else if (value is Tensor tensor && tensor.IsReferencedByNDArray) | |||
{ | |||
return tensor; | |||
} | |||
// graph mode | |||
Tensor ret = value switch | |||
@@ -115,7 +115,7 @@ namespace Tensorflow.Keras.Engine | |||
bool _in_functional_construction_mode(Tensors inputs) | |||
{ | |||
return tf.Context.executing_eagerly() | |||
&& inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); | |||
&& inputs.Count(x => x.IsCreatedInGraphMode) == inputs.Count(); | |||
} | |||
public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) | |||
@@ -177,7 +177,7 @@ namespace Tensorflow.Keras.Engine | |||
tf.init_scope(); | |||
bool need_restore_mode = false; | |||
if (inputs.IsEagerTensor || tf.Context.is_build_function()) | |||
if (!inputs.IsCreatedInGraphMode || tf.Context.is_build_function()) | |||
{ | |||
need_restore_mode = true; | |||
tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); | |||
@@ -148,10 +148,10 @@ namespace TensorFlowNET.UnitTest.Dataset | |||
{ | |||
var dataset = tf.data.Dataset.range(10); | |||
var cardinality = dataset.cardinality(); | |||
Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||
Assert.AreEqual(cardinality.numpy(), 10L); | |||
dataset = dataset.map(x => x[0] + 1); | |||
cardinality = dataset.cardinality(); | |||
Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||
Assert.AreEqual(cardinality.numpy(), 10L); | |||
} | |||
[TestMethod] | |||
@@ -160,7 +160,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||
var dataset = tf.data.Dataset.range(10); | |||
dataset = dataset.map(x => x, num_parallel_calls: -1); | |||
var cardinality = dataset.cardinality(); | |||
Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||
Assert.AreEqual(cardinality.numpy(), 10L); | |||
} | |||
[TestMethod] | |||
@@ -7,7 +7,7 @@ namespace TensorFlowNET.UnitTest | |||
[TestClass] | |||
public class MnistModelLoaderTest | |||
{ | |||
[TestMethod] | |||
[TestMethod, Ignore] | |||
public async Task TestLoad() | |||
{ | |||
var loader = new MnistModelLoader(); | |||