|
|
@@ -37,68 +37,14 @@ namespace Tensorflow |
|
|
|
/// <param name="shape">Optional dimensions of resulting tensor.</param> |
|
|
|
/// <param name="name">Optional name for the tensor.</param> |
|
|
|
/// <returns></returns> |
|
|
|
public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, string name = "Const") |
|
|
|
public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, |
|
|
|
int[] shape = null, bool verify_shape = false, |
|
|
|
bool allow_broadcast = true, string name = "Const") |
|
|
|
{ |
|
|
|
return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true); |
|
|
|
} |
|
|
|
|
|
|
|
/// <param name="verify_shape">Boolean that enables verification of a shape of values.</param> |
|
|
|
public static Tensor _constant_impl(object value, |
|
|
|
TF_DataType dtype, |
|
|
|
TensorShape shape, |
|
|
|
string name, |
|
|
|
bool verify_shape, |
|
|
|
bool allow_broadcast) |
|
|
|
{ |
|
|
|
if (tf.Context.executing_eagerly()) |
|
|
|
{ |
|
|
|
var t = convert_to_eager_tensor(value, tf.Context, dtype: dtype); |
|
|
|
if (shape == null) |
|
|
|
return t; |
|
|
|
|
|
|
|
if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims)) |
|
|
|
return t; |
|
|
|
|
|
|
|
if (verify_shape) |
|
|
|
throw new TypeError($"Expected Tensor's shape: {shape}, got {t.shape}."); |
|
|
|
|
|
|
|
var num_t = t.TensorShape.num_elements(); |
|
|
|
if (num_t == shape.num_elements()) |
|
|
|
return _eager_reshape(t, shape, tf.Context); |
|
|
|
if (num_t == 1) |
|
|
|
{ |
|
|
|
if (t.dtype == dtypes.@bool) |
|
|
|
throw new NotImplementedException(""); |
|
|
|
else |
|
|
|
return _eager_fill(shape, t, tf.Context); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// graph mode |
|
|
|
Graph g = ops.get_default_graph(); |
|
|
|
var tensor_value = new AttrValue(); |
|
|
|
tensor_value.Tensor = tensor_util.make_tensor_proto(value, |
|
|
|
dtype: dtype, |
|
|
|
shape: shape, |
|
|
|
verify_shape: verify_shape, |
|
|
|
allow_broadcast: allow_broadcast); |
|
|
|
|
|
|
|
var dtype_value = new AttrValue |
|
|
|
{ |
|
|
|
Type = tensor_value.Tensor.Dtype, |
|
|
|
}; |
|
|
|
|
|
|
|
var attrs = new Dictionary<string, AttrValue>(); |
|
|
|
attrs["value"] = tensor_value; |
|
|
|
attrs["dtype"] = dtype_value; |
|
|
|
|
|
|
|
var op = g.create_op("Const", |
|
|
|
new Tensor[0], |
|
|
|
new TF_DataType[] { dtype_value.Type.as_tf_dtype() }, |
|
|
|
attrs: attrs, |
|
|
|
name: name); |
|
|
|
|
|
|
|
return op.outputs[0]; |
|
|
|
if(tf.executing_eagerly()) |
|
|
|
return convert_to_eager_tensor(value, dtype, shape, name, verify_shape: verify_shape, allow_broadcast: allow_broadcast); |
|
|
|
else |
|
|
|
return convert_to_graph_tensor(value, dtype, shape, name, verify_shape: verify_shape, allow_broadcast: allow_broadcast); |
|
|
|
} |
|
|
|
|
|
|
|
private static Tensor _eager_reshape(Tensor tensor, int[] shape, Context ctx) |
|
|
@@ -189,6 +135,70 @@ namespace Tensorflow |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
static Tensor convert_to_eager_tensor(object value, |
|
|
|
TF_DataType dtype, |
|
|
|
TensorShape shape, |
|
|
|
string name, |
|
|
|
bool verify_shape, |
|
|
|
bool allow_broadcast) |
|
|
|
{ |
|
|
|
var t = convert_to_eager_tensor(value, tf.Context, dtype: dtype); |
|
|
|
if (shape == null) |
|
|
|
return t; |
|
|
|
|
|
|
|
if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims)) |
|
|
|
return t; |
|
|
|
|
|
|
|
if (verify_shape) |
|
|
|
throw new TypeError($"Expected Tensor's shape: {shape}, got {t.shape}."); |
|
|
|
|
|
|
|
var num_t = t.TensorShape.num_elements(); |
|
|
|
if (num_t == shape.num_elements()) |
|
|
|
return _eager_reshape(t, shape, tf.Context); |
|
|
|
if (num_t == 1) |
|
|
|
{ |
|
|
|
if (t.dtype == dtypes.@bool) |
|
|
|
throw new NotImplementedException(""); |
|
|
|
else |
|
|
|
return _eager_fill(shape, t, tf.Context); |
|
|
|
} |
|
|
|
|
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
|
|
|
|
static Tensor convert_to_graph_tensor(object value, |
|
|
|
TF_DataType dtype, |
|
|
|
TensorShape shape, |
|
|
|
string name, |
|
|
|
bool verify_shape, |
|
|
|
bool allow_broadcast) |
|
|
|
{ |
|
|
|
Graph g = ops.get_default_graph(); |
|
|
|
var tensor_value = new AttrValue(); |
|
|
|
tensor_value.Tensor = tensor_util.make_tensor_proto(value, |
|
|
|
dtype: dtype, |
|
|
|
shape: shape, |
|
|
|
verify_shape: verify_shape, |
|
|
|
allow_broadcast: allow_broadcast); |
|
|
|
|
|
|
|
var dtype_value = new AttrValue |
|
|
|
{ |
|
|
|
Type = tensor_value.Tensor.Dtype, |
|
|
|
}; |
|
|
|
|
|
|
|
var attrs = new Dictionary<string, AttrValue>(); |
|
|
|
attrs["value"] = tensor_value; |
|
|
|
attrs["dtype"] = dtype_value; |
|
|
|
|
|
|
|
var op = g.create_op("Const", |
|
|
|
new Tensor[0], |
|
|
|
new TF_DataType[] { dtype_value.Type.as_tf_dtype() }, |
|
|
|
attrs: attrs, |
|
|
|
name: name); |
|
|
|
|
|
|
|
return op.outputs[0]; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Function to convert TensorShape to Tensor. |
|
|
|
/// </summary> |
|
|
|