Browse Source

fix string data.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
9264a19002
3 changed files with 10 additions and 3 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
  2. +6
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
  3. +3
    -0
      src/TensorFlowNET.Core/ops.cs

+ 1
- 1
src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs View File

@@ -37,7 +37,7 @@ namespace Tensorflow.NumPy
=> new NDArray(value);

public static implicit operator Tensor(NDArray nd)
=> constant_op.constant(nd);
=> nd._tensor;

public static implicit operator NDArray(Tensor tensor)
=> new NDArray(tensor);


+ 6
- 2
src/TensorFlowNET.Core/Tensors/Tensor.Value.cs View File

@@ -49,8 +49,12 @@ namespace Tensorflow

protected NDArray GetNDArray(TF_DataType dtype)
{
/*if (dtype == TF_DataType.TF_STRING)
return np.array(StringData());*/
if (dtype == TF_DataType.TF_STRING)
{
var str= StringData();
return new NDArray(str, new Shape(str.Length));
}
return new NDArray(this);
}



+ 3
- 0
src/TensorFlowNET.Core/ops.cs View File

@@ -171,6 +171,9 @@ namespace Tensorflow
_ => constant_op.constant(value, dtype: dtype, name: name)
};

if (dtype == TF_DataType.TF_STRING)
return ret;

var original_dtype = value.GetDataType();
if (dtype != TF_DataType.DtInvalid && dtype != original_dtype)
ret = gen_math_ops.cast(ret, dtype.as_base_dtype(), name: name);


Loading…
Cancel
Save