Browse Source

fix GetDataType.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
3c7207c251
18 changed files with 196 additions and 185 deletions
  1. +12
    -1
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Contexts/Context.cs
  3. +0
    -3
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  5. +39
    -22
      src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
  6. +79
    -84
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.String.cs
  8. +9
    -3
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  9. +3
    -2
      src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  11. +2
    -3
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  12. +25
    -52
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  13. +9
    -3
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  14. +1
    -1
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  15. +7
    -1
      src/TensorFlowNET.Core/ops.cs
  16. +2
    -2
      src/TensorFlowNET.Keras/Engine/Layer.cs
  17. +3
    -3
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
  18. +1
    -1
      test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs

+ 12
- 1
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -526,8 +526,19 @@ namespace Tensorflow
var type = data.GetType();
switch (data)
{
case Shape shape:
case TensorShape:
case Shape:
return TF_DataType.TF_INT64;
case Axis:
return TF_DataType.TF_INT32;
case NDArray nd:
return nd.dtype;
case Tensor tensor:
return tensor.dtype;
case Tensor[] tensor:
return tensor[0].dtype;
case ResourceVariable variable:
return variable.dtype;
default:
return type.as_tf_dtype();
}


+ 1
- 1
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -142,7 +142,7 @@ namespace Tensorflow.Contexts
bool has_graph_arg = !tf.Context.executing_eagerly();
foreach (var el in flatten_args)
{
if (el is Tensor tensor && !tensor.IsEagerTensor)
if (el is Tensor tensor && tensor.IsCreatedInGraphMode)
{
has_graph_arg = true;
break;


+ 0
- 3
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -50,9 +50,6 @@ namespace Tensorflow.Eager
public EagerTensor(Shape shape, TF_DataType dtype) : base(shape, dtype)
=> NewEagerTensorHandle(_handle);

internal unsafe EagerTensor(string value) : base(value)
=> NewEagerTensorHandle(_handle);

internal unsafe EagerTensor(Array array, Shape shape) : base(array, shape)
=> NewEagerTensorHandle(_handle);



+ 1
- 1
src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs View File

@@ -141,7 +141,7 @@ namespace Tensorflow.Functions
src_graph: _func_graph);

var captures_from_forward = backwards_graph.external_captures
.Where(x => !x.IsEagerTensor && x.graph == _func_graph)
.Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph)
.ToArray();
foreach(var capture in captures_from_forward)
{


+ 39
- 22
src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs View File

@@ -8,20 +8,47 @@ namespace Tensorflow.NumPy
{
public partial class NDArray
{
public NDArray(bool value) => _tensor = new EagerTensor(value);
public NDArray(byte value) => _tensor = new EagerTensor(value);
public NDArray(short value) => _tensor = new EagerTensor(value);
public NDArray(int value) => _tensor = new EagerTensor(value);
public NDArray(long value) => _tensor = new EagerTensor(value);
public NDArray(float value) => _tensor = new EagerTensor(value);
public NDArray(double value) => _tensor = new EagerTensor(value);
public NDArray(bool value) => Init(value);
public NDArray(byte value) => Init(value);
public NDArray(short value) => Init(value);
public NDArray(int value) => Init(value);
public NDArray(long value) => Init(value);
public NDArray(float value) => Init(value);
public NDArray(double value) => Init(value);
public NDArray(Array value, Shape? shape = null) => Init(value, shape);
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => Init(shape, dtype: dtype);
public NDArray(Tensor value, Shape? shape = null) => Init(value, shape);

public NDArray(Array value, Shape? shape = null) => _tensor = new EagerTensor(value, shape);
public static NDArray Scalar<T>(T value) where T : unmanaged
=> value switch
{
bool val => new NDArray(val),
byte val => new NDArray(val),
int val => new NDArray(val),
float val => new NDArray(val),
double val => new NDArray(val),
_ => throw new NotImplementedException("")
};

void Init<T>(T value) where T : unmanaged
{
_tensor = new EagerTensor(value);
_tensor.SetReferencedByNDArray();
}

void Init(Array value, Shape? shape = null)
{
_tensor = new EagerTensor(value, shape ?? value.GetShape());
_tensor.SetReferencedByNDArray();
}

public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
=> _tensor = new EagerTensor(shape, dtype: dtype);
void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
{
_tensor = new EagerTensor(shape, dtype: dtype);
_tensor.SetReferencedByNDArray();
}

public NDArray(Tensor value, Shape? shape = null)
void Init(Tensor value, Shape? shape = null)
{
if (shape is not null)
_tensor = tf.reshape(value, shape);
@@ -30,18 +57,8 @@ namespace Tensorflow.NumPy

if (_tensor.TensorDataPointer == IntPtr.Zero)
_tensor = tf.get_default_session().eval(_tensor);
}

public static NDArray Scalar<T>(T value) where T : unmanaged
{
return value switch
{
bool val => new NDArray(val),
int val => new NDArray(val),
float val => new NDArray(val),
double val => new NDArray(val),
_ => throw new NotImplementedException("")
};
_tensor.SetReferencedByNDArray();
}
}
}

+ 79
- 84
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -21,6 +21,7 @@ using System.Linq;
using System.Numerics;
using System.Text;
using static Tensorflow.c_api;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -31,7 +32,7 @@ namespace Tensorflow

public Tensor()
{
isCreatedInGraphMode = !tf.executing_eagerly();
}

/// <summary>
@@ -41,60 +42,7 @@ namespace Tensorflow
public Tensor(IntPtr handle)
{
_handle = handle;
//no need to set AllocationType = AllocationType.None;
#if TRACK_TENSOR_LIFE
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}");
#endif
}

unsafe internal Tensor(Shape shape, TF_DataType dtype)
=> _handle = TF_NewTensor(shape, dtype, null);

internal Tensor(Array array, Shape? shape = null)
=> InitTensor(array, shape);

unsafe void InitTensor(Array array, Shape? shape = null)
{
shape = shape ?? array.GetShape();
var dtype = array.GetType().GetElementType().as_tf_dtype();

switch (array)
{
case bool[] val:
fixed (void* addr = &val[0])
_handle = TF_NewTensor(shape, dtype, addr);
break;
case int[] val:
fixed (void* addr = &val[0])
_handle = TF_NewTensor(shape, dtype, addr);
break;
case int[,] val:
fixed (void* addr = &val[0, 0])
_handle = TF_NewTensor(shape, dtype, addr);
break;
case long[] val:
fixed (void* addr = &val[0])
_handle = TF_NewTensor(shape, dtype, addr);
break;
case float[] val:
fixed (void* addr = &val[0])
_handle = TF_NewTensor(shape, dtype, addr);
break;
case float[,] val:
fixed (void* addr = &val[0, 0])
_handle = TF_NewTensor(shape, dtype, addr);
break;
case double[] val:
fixed (void* addr = &val[0])
_handle = TF_NewTensor(shape, dtype, addr);
break;
case double[,] val:
fixed (void* addr = &val[0, 0])
_handle = TF_NewTensor(shape, dtype, addr);
break;
default:
throw new NotImplementedException("");
}
isCreatedInGraphMode = !tf.executing_eagerly();
}

/// <summary>
@@ -109,22 +57,26 @@ namespace Tensorflow
public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes)
{
_handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (ulong)num_bytes);
isCreatedInGraphMode = !tf.executing_eagerly();
}

public unsafe Tensor(NDArray nd)
=> _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer());
{
_handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer());
isCreatedInGraphMode = !tf.executing_eagerly();
}

#region scala
public Tensor(bool value) => _handle = TF_NewTensor(value);
public Tensor(byte value) => _handle = TF_NewTensor(value);
public Tensor(sbyte value) => _handle = TF_NewTensor(value);
public Tensor(short value) => _handle = TF_NewTensor(value);
public Tensor(int value) => _handle = TF_NewTensor(value);
public Tensor(uint value) => _handle = TF_NewTensor(value);
public Tensor(long value) => _handle = TF_NewTensor(value);
public Tensor(ulong value) => _handle = TF_NewTensor(value);
public Tensor(float value) => _handle = TF_NewTensor(value);
public Tensor(double value) => _handle = TF_NewTensor(value);
public Tensor(bool value) => InitTensor(value);
public Tensor(byte value) => InitTensor(value);
public Tensor(sbyte value) => InitTensor(value);
public Tensor(short value) => InitTensor(value);
public Tensor(int value) => InitTensor(value);
public Tensor(uint value) => InitTensor(value);
public Tensor(long value) => InitTensor(value);
public Tensor(ulong value) => InitTensor(value);
public Tensor(float value) => InitTensor(value);
public Tensor(double value) => InitTensor(value);
#endregion

#region 1d array
@@ -142,31 +94,74 @@ namespace Tensorflow
public Tensor(Complex[] data, Shape? shape = null) => InitTensor(data, shape);
#endregion

/// <summary>
/// Create a string Tensor from the given string
/// </summary>
public Tensor(string str)
public Tensor(Operation op, int value_index, TF_DataType dtype)
{
_op = op;
_value_index = value_index;
_override_dtype = dtype;
_id = ops.uid();
isCreatedInGraphMode = !tf.executing_eagerly();
}

internal Tensor(Shape shape, TF_DataType dtype) => InitTensor(shape, dtype);
internal Tensor(Array array, Shape? shape = null) => InitTensor(array, shape);
internal Tensor(string value) => InitTensor(value);

protected unsafe void InitTensor<T>(T data) where T : unmanaged
{
_handle = StringTensor(new string[] { str }, TensorShape.Scalar);
#if TRACK_TENSOR_LIFE
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}");
#endif
_handle = TF_NewTensor(data);
isCreatedInGraphMode = !tf.executing_eagerly();
}

public Tensor(string[] strings)
protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
{
_handle = StringTensor(strings, new TensorShape(strings.Length));
#if TRACK_TENSOR_LIFE
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}");
#endif
_handle = TF_NewTensor(shape, dtype, null);
isCreatedInGraphMode = !tf.executing_eagerly();
}

public Tensor(Operation op, int value_index, TF_DataType dtype)
protected void InitTensor(string value)
{
_op = op;
_value_index = value_index;
_override_dtype = dtype;
_id = ops.uid();
_handle = StringTensor(new[] { value }, TensorShape.Scalar);
isCreatedInGraphMode = !tf.executing_eagerly();
}

protected unsafe void InitTensor(Array array, Shape? shape = null)
{
shape = shape ?? array.GetShape();
var dtype = array.GetType().GetElementType().as_tf_dtype();

switch (array)
{
case bool[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case bool[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case bool[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case bool[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case byte[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case byte[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case byte[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case byte[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case int[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case int[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case int[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case int[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case long[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case long[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case long[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case long[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case float[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case float[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case float[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case float[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case double[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case double[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case double[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case double[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
case string[] val: _handle = StringTensor(val, shape); break;
default:
throw new NotImplementedException("");
}

isCreatedInGraphMode = !tf.executing_eagerly();
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.String.cs View File

@@ -23,7 +23,7 @@ namespace Tensorflow
public IntPtr StringTensor(byte[][] buffer, TensorShape shape)
{
var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING,
shape.ndim == 0 ? null : shape.dims.Select(x => (long)x).ToArray(),
shape.ndim == 0 ? null : shape.dims,
shape.ndim,
(ulong)shape.size * TF_TSRING_SIZE);



+ 9
- 3
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -93,9 +93,13 @@ namespace Tensorflow
/// TFE_TensorHandle
/// </summary>
public SafeTensorHandleHandle EagerTensorHandle { get; set; }
protected bool _createdInGraphMode;
public bool CreatedInGraphMode => _createdInGraphMode;
public bool IsEagerTensor => this is EagerTensor;

protected bool isReferencedByNDArray;
public bool IsReferencedByNDArray => isReferencedByNDArray;

protected bool isCreatedInGraphMode;
public bool IsCreatedInGraphMode => isCreatedInGraphMode;
public bool IsSparseTensor => this is SparseTensor;

/// <summary>
@@ -207,6 +211,8 @@ namespace Tensorflow
return _tf_output.Value;
}

public void SetReferencedByNDArray() => isReferencedByNDArray = true;

public Tensor MaybeMove()
{
var tensor = c_api.TF_TensorMaybeMove(_handle);


+ 3
- 2
src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs View File

@@ -1,4 +1,5 @@
using Tensorflow.NumPy;
using System.Linq;
using Tensorflow.NumPy;

namespace Tensorflow
{
@@ -13,7 +14,7 @@ namespace Tensorflow
public static implicit operator TensorShape(Shape shape) => new TensorShape((long[])shape.dims.Clone());
public static implicit operator Shape(TensorShape shape) => shape == null ? null : new Shape((long[])shape.dims.Clone());

public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes
public static implicit operator int[](TensorShape shape) => shape == null ? null : shape.dims.Select(x => (int)x).ToArray(); //we clone to avoid any changes
public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims);

public static implicit operator long[](TensorShape shape) => shape == null ? null : (long[])shape.dims.Clone(); //we clone to avoid any changes


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow
public TensorShape shape => items.First().TensorShape;
public int rank => items.First().rank;
public Graph graph => items.First().graph;
public bool IsEagerTensor => items.First().IsEagerTensor;
public bool IsCreatedInGraphMode => items.First().IsCreatedInGraphMode;
public bool IsList { get; set; }
public int Length => items.Count();



+ 2
- 3
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -98,7 +98,6 @@ namespace Tensorflow
attrs: attrs,
name: name);

var o = op.outputs;
return op.outputs[0];
}

@@ -167,9 +166,9 @@ namespace Tensorflow
case TensorShape val:
return new EagerTensor(val.dims, ctx.DeviceName);
case string val:
return new EagerTensor(val);
return new EagerTensor(new[] { val }, Shape.Scalar);
case string[] val:
return new EagerTensor(val, ctx.DeviceName);
return new EagerTensor(val, new Shape(val.Length));
case bool val:
return new EagerTensor(new[] { val }, Shape.Scalar);
case byte val:


+ 25
- 52
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -75,7 +75,7 @@ namespace Tensorflow
case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX
return typeof(Complex);
default:
return null;
throw new NotSupportedException($"Unable to convert {type} to a system data type.");
}
}

@@ -83,24 +83,25 @@ namespace Tensorflow
///
/// </summary>
/// <param name="type"></param>
/// <param name="dtype"></param>
/// <returns></returns>
/// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception>
public static TF_DataType as_tf_dtype(this Type type, TF_DataType? dtype = null)
public static TF_DataType as_tf_dtype(this Type type)
{
while (type.IsArray)
type = type.GetElementType();

TF_DataType dtype = TF_DataType.DtInvalid;

switch (type.Name)
{
case "Char":
dtype = dtype ?? TF_DataType.TF_UINT8;
dtype = TF_DataType.TF_UINT8;
break;
case "SByte":
dtype = TF_DataType.TF_INT8;
break;
case "Byte":
dtype = dtype ?? TF_DataType.TF_UINT8;
dtype = TF_DataType.TF_UINT8;
break;
case "Int16":
dtype = TF_DataType.TF_INT16;
@@ -136,60 +137,32 @@ namespace Tensorflow
dtype = TF_DataType.TF_BOOL;
break;
default:
throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode.");
throw new NotSupportedException($"Unable to convert {type} to a TensorFlow data type.");
}

return dtype.Value;
return dtype;
}

public static TF_DataType tf_dtype_from_name(string name)
{
TF_DataType dtype = TF_DataType.DtInvalid;
switch (name.ToLower())
TF_DataType dtype = name.ToLower() switch
{
case "char":
dtype = TF_DataType.TF_UINT8;
break;
case "boolean":
dtype = TF_DataType.TF_BOOL;
break;
case "sbyte":
dtype = TF_DataType.TF_INT8;
break;
case "byte":
dtype = TF_DataType.TF_UINT8;
break;
case "int16":
dtype = TF_DataType.TF_INT16;
break;
case "uint16":
dtype = TF_DataType.TF_UINT16;
break;
case "int32":
dtype = TF_DataType.TF_INT32;
break;
case "uint32":
dtype = TF_DataType.TF_UINT32;
break;
case "int64":
dtype = TF_DataType.TF_INT64;
break;
case "uint64":
dtype = TF_DataType.TF_UINT64;
break;
case "single":
dtype = TF_DataType.TF_FLOAT;
break;
case "double":
dtype = TF_DataType.TF_DOUBLE;
break;
case "complex":
dtype = TF_DataType.TF_COMPLEX128;
break;
case "string":
dtype = TF_DataType.TF_STRING;
break;
}
"char" => TF_DataType.TF_UINT8,
"boolean" => TF_DataType.TF_BOOL,
"sbyte" => TF_DataType.TF_INT8,
"byte" => TF_DataType.TF_UINT8,
"int16" => TF_DataType.TF_INT16,
"uint16" => TF_DataType.TF_UINT16,
"int32" => TF_DataType.TF_INT32,
"uint32" => TF_DataType.TF_UINT32,
"int64" => TF_DataType.TF_INT64,
"uint64" => TF_DataType.TF_UINT64,
"single" => TF_DataType.TF_FLOAT,
"double" => TF_DataType.TF_DOUBLE,
"complex" => TF_DataType.TF_COMPLEX128,
"string" => TF_DataType.TF_STRING,
_ => TF_DataType.DtInvalid
};

return dtype;
}


+ 9
- 3
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -108,7 +108,7 @@ namespace Tensorflow
if (values is TensorProto tp)
return tp;

dtype = values.GetType().as_tf_dtype();
dtype = values.GetDataType();
shape = shape ?? values.GetShape();
var tensor_proto = new TensorProto
{
@@ -117,7 +117,13 @@ namespace Tensorflow
};

// scalar
if (!values.GetType().IsArray)
if (values is NDArray nd)
{
var len = nd.dtypesize * nd.size;
byte[] bytes = nd.ToByteArray();
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
}
else if (!values.GetType().IsArray)
{
switch (values)
{
@@ -154,7 +160,7 @@ namespace Tensorflow
else if (values is byte[] byte_values)
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values);
}
else if(values is Array array)
else if (values is Array array)
{
// array
var len = dtype.get_datatype_size() * (int)shape.size;


+ 1
- 1
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -68,7 +68,7 @@ namespace Tensorflow
// when this object is garbage collected the deleter will be too. This
// means ResourceVariables can be part of reference cycles without those
// cycles being uncollectable.
if (handle.IsEagerTensor)
if (!handle.IsCreatedInGraphMode)
{
_handle = handle.EagerTensorHandle.DangerousGetHandle();
eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device);


+ 7
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -123,7 +123,7 @@ namespace Tensorflow
if (dtype == TF_DataType.DtInvalid)
dtype = preferred_dtype;

if (value is EagerTensor eager_tensor)
if (value is EagerTensor eager_tensor && !eager_tensor.IsCreatedInGraphMode)
{
if (tf.executing_eagerly())
{
@@ -140,7 +140,13 @@ namespace Tensorflow
}
}
else if (value is NDArray nd)
{
return nd;
}
else if (value is Tensor tensor && tensor.IsReferencedByNDArray)
{
return tensor;
}

// graph mode
Tensor ret = value switch


+ 2
- 2
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -115,7 +115,7 @@ namespace Tensorflow.Keras.Engine
bool _in_functional_construction_mode(Tensors inputs)
{
return tf.Context.executing_eagerly()
&& inputs.Count(x => !x.IsEagerTensor) == inputs.Count();
&& inputs.Count(x => x.IsCreatedInGraphMode) == inputs.Count();
}

public void SetConnectivityMetadata(Tensors inputs, Tensors outputs)
@@ -177,7 +177,7 @@ namespace Tensorflow.Keras.Engine
tf.init_scope();

bool need_restore_mode = false;
if (inputs.IsEagerTensor || tf.Context.is_build_function())
if (!inputs.IsCreatedInGraphMode || tf.Context.is_build_function())
{
need_restore_mode = true;
tf.Context.eager_mode(isFunc: tf.Context.is_build_function());


+ 3
- 3
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -148,10 +148,10 @@ namespace TensorFlowNET.UnitTest.Dataset
{
var dataset = tf.data.Dataset.range(10);
var cardinality = dataset.cardinality();
Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
Assert.AreEqual(cardinality.numpy(), 10L);
dataset = dataset.map(x => x[0] + 1);
cardinality = dataset.cardinality();
Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
Assert.AreEqual(cardinality.numpy(), 10L);
}

[TestMethod]
@@ -160,7 +160,7 @@ namespace TensorFlowNET.UnitTest.Dataset
var dataset = tf.data.Dataset.range(10);
dataset = dataset.map(x => x, num_parallel_calls: -1);
var cardinality = dataset.cardinality();
Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
Assert.AreEqual(cardinality.numpy(), 10L);
}

[TestMethod]


+ 1
- 1
test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs View File

@@ -7,7 +7,7 @@ namespace TensorFlowNET.UnitTest
[TestClass]
public class MnistModelLoaderTest
{
[TestMethod]
[TestMethod, Ignore]
public async Task TestLoad()
{
var loader = new MnistModelLoader();


Loading…
Cancel
Save