|
|
@@ -134,12 +134,41 @@ namespace Tensorflow |
|
|
|
TensorShape = shape.as_shape_proto() |
|
|
|
}; |
|
|
|
|
|
|
|
// scalar |
|
|
|
if (values is NDArray nd) |
|
|
|
{ |
|
|
|
var len = nd.dtypesize * nd.size; |
|
|
|
byte[] bytes = nd.ToByteArray(); |
|
|
|
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); |
|
|
|
// scalar |
|
|
|
if (nd.shape.IsScalar) |
|
|
|
{ |
|
|
|
switch (nd.dtype) |
|
|
|
{ |
|
|
|
case TF_DataType.TF_BOOL: |
|
|
|
tensor_proto.BoolVal.AddRange(nd.ToArray<bool>()); |
|
|
|
break; |
|
|
|
case TF_DataType.TF_UINT8: |
|
|
|
tensor_proto.IntVal.AddRange(nd.ToArray<byte>().Select(x => (int)x).ToArray()); |
|
|
|
break; |
|
|
|
case TF_DataType.TF_INT32: |
|
|
|
tensor_proto.IntVal.AddRange(nd.ToArray<int>()); |
|
|
|
break; |
|
|
|
case TF_DataType.TF_INT64: |
|
|
|
tensor_proto.Int64Val.AddRange(nd.ToArray<long>()); |
|
|
|
break; |
|
|
|
case TF_DataType.TF_FLOAT: |
|
|
|
tensor_proto.FloatVal.AddRange(nd.ToArray<float>()); |
|
|
|
break; |
|
|
|
case TF_DataType.TF_DOUBLE: |
|
|
|
tensor_proto.DoubleVal.AddRange(nd.ToArray<double>()); |
|
|
|
break; |
|
|
|
default: |
|
|
|
throw new Exception("make_tensor_proto Not Implemented"); |
|
|
|
} |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
var len = nd.dtypesize * nd.size; |
|
|
|
byte[] bytes = nd.ToByteArray(); |
|
|
|
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); |
|
|
|
} |
|
|
|
} |
|
|
|
else if (dtype == TF_DataType.TF_STRING && !(values is NDArray)) |
|
|
|
{ |
|
|
|