diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 54931059..b6403193 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -526,8 +526,19 @@ namespace Tensorflow var type = data.GetType(); switch (data) { - case Shape shape: + case TensorShape: + case Shape: return TF_DataType.TF_INT64; + case Axis: + return TF_DataType.TF_INT32; + case NDArray nd: + return nd.dtype; + case Tensor tensor: + return tensor.dtype; + case Tensor[] tensor: + return tensor[0].dtype; + case ResourceVariable variable: + return variable.dtype; default: return type.as_tf_dtype(); } diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 95f75a94..6c09c91d 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -142,7 +142,7 @@ namespace Tensorflow.Contexts bool has_graph_arg = !tf.Context.executing_eagerly(); foreach (var el in flatten_args) { - if (el is Tensor tensor && !tensor.IsEagerTensor) + if (el is Tensor tensor && tensor.IsCreatedInGraphMode) { has_graph_arg = true; break; diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index d5b0f152..d1789aae 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -50,9 +50,6 @@ namespace Tensorflow.Eager public EagerTensor(Shape shape, TF_DataType dtype) : base(shape, dtype) => NewEagerTensorHandle(_handle); - internal unsafe EagerTensor(string value) : base(value) - => NewEagerTensorHandle(_handle); - internal unsafe EagerTensor(Array array, Shape shape) : base(array, shape) => NewEagerTensorHandle(_handle); diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index b4356107..c803b2b3 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -141,7 +141,7 @@ namespace Tensorflow.Functions src_graph: _func_graph); var captures_from_forward = backwards_graph.external_captures - .Where(x => !x.IsEagerTensor && x.graph == _func_graph) + .Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph) .ToArray(); foreach(var capture in captures_from_forward) { diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs index 43fdde55..12213857 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs @@ -8,20 +8,47 @@ namespace Tensorflow.NumPy { public partial class NDArray { - public NDArray(bool value) => _tensor = new EagerTensor(value); - public NDArray(byte value) => _tensor = new EagerTensor(value); - public NDArray(short value) => _tensor = new EagerTensor(value); - public NDArray(int value) => _tensor = new EagerTensor(value); - public NDArray(long value) => _tensor = new EagerTensor(value); - public NDArray(float value) => _tensor = new EagerTensor(value); - public NDArray(double value) => _tensor = new EagerTensor(value); + public NDArray(bool value) => Init(value); + public NDArray(byte value) => Init(value); + public NDArray(short value) => Init(value); + public NDArray(int value) => Init(value); + public NDArray(long value) => Init(value); + public NDArray(float value) => Init(value); + public NDArray(double value) => Init(value); + public NDArray(Array value, Shape? shape = null) => Init(value, shape); + public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => Init(shape, dtype: dtype); + public NDArray(Tensor value, Shape? shape = null) => Init(value, shape); - public NDArray(Array value, Shape? shape = null) => _tensor = new EagerTensor(value, shape); + public static NDArray Scalar(T value) where T : unmanaged + => value switch + { + bool val => new NDArray(val), + byte val => new NDArray(val), + int val => new NDArray(val), + float val => new NDArray(val), + double val => new NDArray(val), + _ => throw new NotImplementedException("") + }; + + void Init(T value) where T : unmanaged + { + _tensor = new EagerTensor(value); + _tensor.SetReferencedByNDArray(); + } + + void Init(Array value, Shape? shape = null) + { + _tensor = new EagerTensor(value, shape ?? value.GetShape()); + _tensor.SetReferencedByNDArray(); + } - public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) - => _tensor = new EagerTensor(shape, dtype: dtype); + void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) + { + _tensor = new EagerTensor(shape, dtype: dtype); + _tensor.SetReferencedByNDArray(); + } - public NDArray(Tensor value, Shape? shape = null) + void Init(Tensor value, Shape? shape = null) { if (shape is not null) _tensor = tf.reshape(value, shape); @@ -30,18 +57,8 @@ namespace Tensorflow.NumPy if (_tensor.TensorDataPointer == IntPtr.Zero) _tensor = tf.get_default_session().eval(_tensor); - } - public static NDArray Scalar(T value) where T : unmanaged - { - return value switch - { - bool val => new NDArray(val), - int val => new NDArray(val), - float val => new NDArray(val), - double val => new NDArray(val), - _ => throw new NotImplementedException("") - }; + _tensor.SetReferencedByNDArray(); } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 3cfcb7d0..991c6a51 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -21,6 +21,7 @@ using System.Linq; using System.Numerics; using System.Text; using static Tensorflow.c_api; +using static Tensorflow.Binding; namespace Tensorflow { @@ -31,7 +32,7 @@ namespace Tensorflow public Tensor() { - + isCreatedInGraphMode = !tf.executing_eagerly(); } /// @@ -41,60 +42,7 @@ namespace Tensorflow public Tensor(IntPtr handle) { _handle = handle; - //no need to set AllocationType = AllocationType.None; -#if TRACK_TENSOR_LIFE - print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); -#endif - } - - unsafe internal Tensor(Shape shape, TF_DataType dtype) - => _handle = TF_NewTensor(shape, dtype, null); - - internal Tensor(Array array, Shape? shape = null) - => InitTensor(array, shape); - - unsafe void InitTensor(Array array, Shape? shape = null) - { - shape = shape ?? array.GetShape(); - var dtype = array.GetType().GetElementType().as_tf_dtype(); - - switch (array) - { - case bool[] val: - fixed (void* addr = &val[0]) - _handle = TF_NewTensor(shape, dtype, addr); - break; - case int[] val: - fixed (void* addr = &val[0]) - _handle = TF_NewTensor(shape, dtype, addr); - break; - case int[,] val: - fixed (void* addr = &val[0, 0]) - _handle = TF_NewTensor(shape, dtype, addr); - break; - case long[] val: - fixed (void* addr = &val[0]) - _handle = TF_NewTensor(shape, dtype, addr); - break; - case float[] val: - fixed (void* addr = &val[0]) - _handle = TF_NewTensor(shape, dtype, addr); - break; - case float[,] val: - fixed (void* addr = &val[0, 0]) - _handle = TF_NewTensor(shape, dtype, addr); - break; - case double[] val: - fixed (void* addr = &val[0]) - _handle = TF_NewTensor(shape, dtype, addr); - break; - case double[,] val: - fixed (void* addr = &val[0, 0]) - _handle = TF_NewTensor(shape, dtype, addr); - break; - default: - throw new NotImplementedException(""); - } + isCreatedInGraphMode = !tf.executing_eagerly(); } /// @@ -109,22 +57,26 @@ namespace Tensorflow public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes) { _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (ulong)num_bytes); + isCreatedInGraphMode = !tf.executing_eagerly(); } public unsafe Tensor(NDArray nd) - => _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); + { + _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); + isCreatedInGraphMode = !tf.executing_eagerly(); + } #region scala - public Tensor(bool value) => _handle = TF_NewTensor(value); - public Tensor(byte value) => _handle = TF_NewTensor(value); - public Tensor(sbyte value) => _handle = TF_NewTensor(value); - public Tensor(short value) => _handle = TF_NewTensor(value); - public Tensor(int value) => _handle = TF_NewTensor(value); - public Tensor(uint value) => _handle = TF_NewTensor(value); - public Tensor(long value) => _handle = TF_NewTensor(value); - public Tensor(ulong value) => _handle = TF_NewTensor(value); - public Tensor(float value) => _handle = TF_NewTensor(value); - public Tensor(double value) => _handle = TF_NewTensor(value); + public Tensor(bool value) => InitTensor(value); + public Tensor(byte value) => InitTensor(value); + public Tensor(sbyte value) => InitTensor(value); + public Tensor(short value) => InitTensor(value); + public Tensor(int value) => InitTensor(value); + public Tensor(uint value) => InitTensor(value); + public Tensor(long value) => InitTensor(value); + public Tensor(ulong value) => InitTensor(value); + public Tensor(float value) => InitTensor(value); + public Tensor(double value) => InitTensor(value); #endregion #region 1d array @@ -142,31 +94,74 @@ namespace Tensorflow public Tensor(Complex[] data, Shape? shape = null) => InitTensor(data, shape); #endregion - /// - /// Create a string Tensor from the given string - /// - public Tensor(string str) + public Tensor(Operation op, int value_index, TF_DataType dtype) + { + _op = op; + _value_index = value_index; + _override_dtype = dtype; + _id = ops.uid(); + isCreatedInGraphMode = !tf.executing_eagerly(); + } + + internal Tensor(Shape shape, TF_DataType dtype) => InitTensor(shape, dtype); + internal Tensor(Array array, Shape? shape = null) => InitTensor(array, shape); + internal Tensor(string value) => InitTensor(value); + + protected unsafe void InitTensor(T data) where T : unmanaged { - _handle = StringTensor(new string[] { str }, TensorShape.Scalar); -#if TRACK_TENSOR_LIFE - print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); -#endif + _handle = TF_NewTensor(data); + isCreatedInGraphMode = !tf.executing_eagerly(); } - public Tensor(string[] strings) + protected unsafe void InitTensor(Shape shape, TF_DataType dtype) { - _handle = StringTensor(strings, new TensorShape(strings.Length)); -#if TRACK_TENSOR_LIFE - print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); -#endif + _handle = TF_NewTensor(shape, dtype, null); + isCreatedInGraphMode = !tf.executing_eagerly(); } - public Tensor(Operation op, int value_index, TF_DataType dtype) + protected void InitTensor(string value) { - _op = op; - _value_index = value_index; - _override_dtype = dtype; - _id = ops.uid(); + _handle = StringTensor(new[] { value }, TensorShape.Scalar); + isCreatedInGraphMode = !tf.executing_eagerly(); + } + + protected unsafe void InitTensor(Array array, Shape? shape = null) + { + shape = shape ?? array.GetShape(); + var dtype = array.GetType().GetElementType().as_tf_dtype(); + + switch (array) + { + case bool[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case bool[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case bool[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case bool[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case byte[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case byte[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case byte[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case byte[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case int[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case int[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case int[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case int[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case long[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case long[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case long[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case long[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case float[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case float[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case float[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case float[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case double[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case double[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case double[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case double[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; + case string[] val: _handle = StringTensor(val, shape); break; + default: + throw new NotImplementedException(""); + } + + isCreatedInGraphMode = !tf.executing_eagerly(); } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs index 8f08716b..2c5a5038 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs @@ -23,7 +23,7 @@ namespace Tensorflow public IntPtr StringTensor(byte[][] buffer, TensorShape shape) { var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, - shape.ndim == 0 ? null : shape.dims.Select(x => (long)x).ToArray(), + shape.ndim == 0 ? null : shape.dims, shape.ndim, (ulong)shape.size * TF_TSRING_SIZE); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 05655f92..f0dd4274 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -93,9 +93,13 @@ namespace Tensorflow /// TFE_TensorHandle /// public SafeTensorHandleHandle EagerTensorHandle { get; set; } - protected bool _createdInGraphMode; - public bool CreatedInGraphMode => _createdInGraphMode; - public bool IsEagerTensor => this is EagerTensor; + + protected bool isReferencedByNDArray; + public bool IsReferencedByNDArray => isReferencedByNDArray; + + protected bool isCreatedInGraphMode; + + public bool IsCreatedInGraphMode => isCreatedInGraphMode; public bool IsSparseTensor => this is SparseTensor; /// @@ -207,6 +211,8 @@ namespace Tensorflow return _tf_output.Value; } + public void SetReferencedByNDArray() => isReferencedByNDArray = true; + public Tensor MaybeMove() { var tensor = c_api.TF_TensorMaybeMove(_handle); diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs index ecb273a0..fee26f00 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs @@ -1,4 +1,5 @@ -using Tensorflow.NumPy; +using System.Linq; +using Tensorflow.NumPy; namespace Tensorflow { @@ -13,7 +14,7 @@ namespace Tensorflow public static implicit operator TensorShape(Shape shape) => new TensorShape((long[])shape.dims.Clone()); public static implicit operator Shape(TensorShape shape) => shape == null ? null : new Shape((long[])shape.dims.Clone()); - public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes + public static implicit operator int[](TensorShape shape) => shape == null ? null : shape.dims.Select(x => (int)x).ToArray(); //we clone to avoid any changes public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims); public static implicit operator long[](TensorShape shape) => shape == null ? null : (long[])shape.dims.Clone(); //we clone to avoid any changes diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index 4b191aa0..a9bfe159 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -21,7 +21,7 @@ namespace Tensorflow public TensorShape shape => items.First().TensorShape; public int rank => items.First().rank; public Graph graph => items.First().graph; - public bool IsEagerTensor => items.First().IsEagerTensor; + public bool IsCreatedInGraphMode => items.First().IsCreatedInGraphMode; public bool IsList { get; set; } public int Length => items.Count(); diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 574bffc1..185fd8a5 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -98,7 +98,6 @@ namespace Tensorflow attrs: attrs, name: name); - var o = op.outputs; return op.outputs[0]; } @@ -167,9 +166,9 @@ namespace Tensorflow case TensorShape val: return new EagerTensor(val.dims, ctx.DeviceName); case string val: - return new EagerTensor(val); + return new EagerTensor(new[] { val }, Shape.Scalar); case string[] val: - return new EagerTensor(val, ctx.DeviceName); + return new EagerTensor(val, new Shape(val.Length)); case bool val: return new EagerTensor(new[] { val }, Shape.Scalar); case byte val: diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index a33f3fb8..68d21305 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -75,7 +75,7 @@ namespace Tensorflow case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX return typeof(Complex); default: - return null; + throw new NotSupportedException($"Unable to convert {type} to a system data type."); } } @@ -83,24 +83,25 @@ namespace Tensorflow /// /// /// - /// /// /// When has no equivalent - public static TF_DataType as_tf_dtype(this Type type, TF_DataType? dtype = null) + public static TF_DataType as_tf_dtype(this Type type) { while (type.IsArray) type = type.GetElementType(); + TF_DataType dtype = TF_DataType.DtInvalid; + switch (type.Name) { case "Char": - dtype = dtype ?? TF_DataType.TF_UINT8; + dtype = TF_DataType.TF_UINT8; break; case "SByte": dtype = TF_DataType.TF_INT8; break; case "Byte": - dtype = dtype ?? TF_DataType.TF_UINT8; + dtype = TF_DataType.TF_UINT8; break; case "Int16": dtype = TF_DataType.TF_INT16; @@ -136,60 +137,32 @@ namespace Tensorflow dtype = TF_DataType.TF_BOOL; break; default: - throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); + throw new NotSupportedException($"Unable to convert {type} to a TensorFlow data type."); } - return dtype.Value; + return dtype; } public static TF_DataType tf_dtype_from_name(string name) { - TF_DataType dtype = TF_DataType.DtInvalid; - switch (name.ToLower()) + TF_DataType dtype = name.ToLower() switch { - case "char": - dtype = TF_DataType.TF_UINT8; - break; - case "boolean": - dtype = TF_DataType.TF_BOOL; - break; - case "sbyte": - dtype = TF_DataType.TF_INT8; - break; - case "byte": - dtype = TF_DataType.TF_UINT8; - break; - case "int16": - dtype = TF_DataType.TF_INT16; - break; - case "uint16": - dtype = TF_DataType.TF_UINT16; - break; - case "int32": - dtype = TF_DataType.TF_INT32; - break; - case "uint32": - dtype = TF_DataType.TF_UINT32; - break; - case "int64": - dtype = TF_DataType.TF_INT64; - break; - case "uint64": - dtype = TF_DataType.TF_UINT64; - break; - case "single": - dtype = TF_DataType.TF_FLOAT; - break; - case "double": - dtype = TF_DataType.TF_DOUBLE; - break; - case "complex": - dtype = TF_DataType.TF_COMPLEX128; - break; - case "string": - dtype = TF_DataType.TF_STRING; - break; - } + "char" => TF_DataType.TF_UINT8, + "boolean" => TF_DataType.TF_BOOL, + "sbyte" => TF_DataType.TF_INT8, + "byte" => TF_DataType.TF_UINT8, + "int16" => TF_DataType.TF_INT16, + "uint16" => TF_DataType.TF_UINT16, + "int32" => TF_DataType.TF_INT32, + "uint32" => TF_DataType.TF_UINT32, + "int64" => TF_DataType.TF_INT64, + "uint64" => TF_DataType.TF_UINT64, + "single" => TF_DataType.TF_FLOAT, + "double" => TF_DataType.TF_DOUBLE, + "complex" => TF_DataType.TF_COMPLEX128, + "string" => TF_DataType.TF_STRING, + _ => TF_DataType.DtInvalid + }; return dtype; } diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 58aa455e..0f168904 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -108,7 +108,7 @@ namespace Tensorflow if (values is TensorProto tp) return tp; - dtype = values.GetType().as_tf_dtype(); + dtype = values.GetDataType(); shape = shape ?? values.GetShape(); var tensor_proto = new TensorProto { @@ -117,7 +117,13 @@ namespace Tensorflow }; // scalar - if (!values.GetType().IsArray) + if (values is NDArray nd) + { + var len = nd.dtypesize * nd.size; + byte[] bytes = nd.ToByteArray(); + tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); + } + else if (!values.GetType().IsArray) { switch (values) { @@ -154,7 +160,7 @@ namespace Tensorflow else if (values is byte[] byte_values) tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values); } - else if(values is Array array) + else if (values is Array array) { // array var len = dtype.get_datatype_size() * (int)shape.size; diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index accea30f..a898fed5 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -68,7 +68,7 @@ namespace Tensorflow // when this object is garbage collected the deleter will be too. This // means ResourceVariables can be part of reference cycles without those // cycles being uncollectable. - if (handle.IsEagerTensor) + if (!handle.IsCreatedInGraphMode) { _handle = handle.EagerTensorHandle.DangerousGetHandle(); eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 0b13a2aa..5e2e8287 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -123,7 +123,7 @@ namespace Tensorflow if (dtype == TF_DataType.DtInvalid) dtype = preferred_dtype; - if (value is EagerTensor eager_tensor) + if (value is EagerTensor eager_tensor && !eager_tensor.IsCreatedInGraphMode) { if (tf.executing_eagerly()) { @@ -140,7 +140,13 @@ namespace Tensorflow } } else if (value is NDArray nd) + { return nd; + } + else if (value is Tensor tensor && tensor.IsReferencedByNDArray) + { + return tensor; + } // graph mode Tensor ret = value switch diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 4a605553..3c936a8b 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -115,7 +115,7 @@ namespace Tensorflow.Keras.Engine bool _in_functional_construction_mode(Tensors inputs) { return tf.Context.executing_eagerly() - && inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); + && inputs.Count(x => x.IsCreatedInGraphMode) == inputs.Count(); } public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) @@ -177,7 +177,7 @@ namespace Tensorflow.Keras.Engine tf.init_scope(); bool need_restore_mode = false; - if (inputs.IsEagerTensor || tf.Context.is_build_function()) + if (!inputs.IsCreatedInGraphMode || tf.Context.is_build_function()) { need_restore_mode = true; tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index b705284b..8317346e 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -148,10 +148,10 @@ namespace TensorFlowNET.UnitTest.Dataset { var dataset = tf.data.Dataset.range(10); var cardinality = dataset.cardinality(); - Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); + Assert.AreEqual(cardinality.numpy(), 10L); dataset = dataset.map(x => x[0] + 1); cardinality = dataset.cardinality(); - Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); + Assert.AreEqual(cardinality.numpy(), 10L); } [TestMethod] @@ -160,7 +160,7 @@ namespace TensorFlowNET.UnitTest.Dataset var dataset = tf.data.Dataset.range(10); dataset = dataset.map(x => x, num_parallel_calls: -1); var cardinality = dataset.cardinality(); - Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); + Assert.AreEqual(cardinality.numpy(), 10L); } [TestMethod] diff --git a/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs b/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs index e2fc0c89..b16a5f3d 100644 --- a/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs +++ b/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs @@ -7,7 +7,7 @@ namespace TensorFlowNET.UnitTest [TestClass] public class MnistModelLoaderTest { - [TestMethod] + [TestMethod, Ignore] public async Task TestLoad() { var loader = new MnistModelLoader();