Browse Source

add more type case for tensor.zeros

tags/v0.110.0-LSTM-Model
lingbai-kong 2 years ago
parent
commit
fcd10447ab
1 changed files with 11 additions and 0 deletions
  1. +11
    -0
      src/TensorFlowNET.Core/Operations/array_ops.cs

+ 11
- 0
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -84,8 +84,13 @@ namespace Tensorflow
// var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape);
Tensor zeros = dtype switch
{
TF_DataType.TF_BOOL => constant(false),
TF_DataType.TF_DOUBLE => constant(0d),
TF_DataType.TF_FLOAT => constant(0f),
TF_DataType.TF_INT64 => constant(0L),
TF_DataType.TF_UINT64 => constant((ulong)0),
TF_DataType.TF_INT32 => constant(0),
TF_DataType.TF_UINT32 => constant((uint)0),
TF_DataType.TF_INT8 => constant((sbyte)0),
TF_DataType.TF_UINT8 => constant((byte)0),
_ => constant(0)
@@ -108,9 +113,15 @@ namespace Tensorflow
return _constant_if_small(0.0F, shape, dtype, name);
case TF_DataType.TF_INT64:
return _constant_if_small(0L, shape, dtype, name);
case TF_DataType.TF_UINT64:
return _constant_if_small<ulong>(0, shape, dtype, name);
case TF_DataType.TF_INT32:
return _constant_if_small(0, shape, dtype, name);
case TF_DataType.TF_UINT32:
return _constant_if_small<uint>(0, shape, dtype, name);
case TF_DataType.TF_INT8:
return _constant_if_small<sbyte>(0, shape, dtype, name);
case TF_DataType.TF_UINT8:
return _constant_if_small<byte>(0, shape, dtype, name);
default:
throw new TypeError("can't find type for zeros");


Loading…
Cancel
Save