@@ -25,7 +25,7 @@ namespace Tensorflow.NumPy | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
return _tensor[index.Select(x => new Slice(x, x + 1)).ToArray()]; | |||||
return _tensor[index.Select(x => new Slice(x, x + 1)).ToArray()]; | |||||
} | } | ||||
set | set | ||||
@@ -34,12 +34,11 @@ namespace Tensorflow.Operations.Initializers | |||||
if (args.DType == TF_DataType.DtInvalid) | if (args.DType == TF_DataType.DtInvalid) | ||||
args.DType = this.dtype; | args.DType = this.dtype; | ||||
if (!args.VerifyShape.HasValue) | |||||
args.VerifyShape = _verify_shape; | |||||
args.VerifyShape = _verify_shape; | |||||
return constant_op._constant_impl(value, args.DType, args.Shape, | |||||
return constant_op.constant(value, args.DType, args.Shape, | |||||
name: "Const", | name: "Const", | ||||
verify_shape: args.VerifyShape.Value, | |||||
verify_shape: args.VerifyShape, | |||||
allow_broadcast: false); | allow_broadcast: false); | ||||
} | } | ||||
} | } | ||||
@@ -5,11 +5,11 @@ | |||||
public string Name { get; set; } | public string Name { get; set; } | ||||
public TensorShape Shape { get; set; } | public TensorShape Shape { get; set; } | ||||
public TF_DataType DType { get; set; } | public TF_DataType DType { get; set; } | ||||
public bool? VerifyShape { get; set; } = null; | |||||
public bool VerifyShape { get; set; } | |||||
public InitializerArgs(TensorShape shape, | public InitializerArgs(TensorShape shape, | ||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
bool? verify_shape = null, | |||||
bool verify_shape = false, | |||||
string name = null) | string name = null) | ||||
{ | { | ||||
Shape = shape; | Shape = shape; | ||||
@@ -64,10 +64,10 @@ namespace Tensorflow | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
int[] shape = null, | int[] shape = null, | ||||
string name = "Const", | string name = "Const", | ||||
bool verify_shape = false) => constant_op._constant_impl(value, | |||||
dtype, | |||||
shape, | |||||
name, | |||||
bool verify_shape = false) => constant_op.constant(value, | |||||
dtype: dtype, | |||||
shape: shape, | |||||
name: name, | |||||
verify_shape: verify_shape, | verify_shape: verify_shape, | ||||
allow_broadcast: false); | allow_broadcast: false); | ||||
@@ -37,68 +37,14 @@ namespace Tensorflow | |||||
/// <param name="shape">Optional dimensions of resulting tensor.</param> | /// <param name="shape">Optional dimensions of resulting tensor.</param> | ||||
/// <param name="name">Optional name for the tensor.</param> | /// <param name="name">Optional name for the tensor.</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, string name = "Const") | |||||
public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, | |||||
int[] shape = null, bool verify_shape = false, | |||||
bool allow_broadcast = true, string name = "Const") | |||||
{ | { | ||||
return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true); | |||||
} | |||||
/// <param name="verify_shape">Boolean that enables verification of a shape of values.</param> | |||||
public static Tensor _constant_impl(object value, | |||||
TF_DataType dtype, | |||||
TensorShape shape, | |||||
string name, | |||||
bool verify_shape, | |||||
bool allow_broadcast) | |||||
{ | |||||
if (tf.Context.executing_eagerly()) | |||||
{ | |||||
var t = convert_to_eager_tensor(value, tf.Context, dtype: dtype); | |||||
if (shape == null) | |||||
return t; | |||||
if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims)) | |||||
return t; | |||||
if (verify_shape) | |||||
throw new TypeError($"Expected Tensor's shape: {shape}, got {t.shape}."); | |||||
var num_t = t.TensorShape.num_elements(); | |||||
if (num_t == shape.num_elements()) | |||||
return _eager_reshape(t, shape, tf.Context); | |||||
if (num_t == 1) | |||||
{ | |||||
if (t.dtype == dtypes.@bool) | |||||
throw new NotImplementedException(""); | |||||
else | |||||
return _eager_fill(shape, t, tf.Context); | |||||
} | |||||
} | |||||
// graph mode | |||||
Graph g = ops.get_default_graph(); | |||||
var tensor_value = new AttrValue(); | |||||
tensor_value.Tensor = tensor_util.make_tensor_proto(value, | |||||
dtype: dtype, | |||||
shape: shape, | |||||
verify_shape: verify_shape, | |||||
allow_broadcast: allow_broadcast); | |||||
var dtype_value = new AttrValue | |||||
{ | |||||
Type = tensor_value.Tensor.Dtype, | |||||
}; | |||||
var attrs = new Dictionary<string, AttrValue>(); | |||||
attrs["value"] = tensor_value; | |||||
attrs["dtype"] = dtype_value; | |||||
var op = g.create_op("Const", | |||||
new Tensor[0], | |||||
new TF_DataType[] { dtype_value.Type.as_tf_dtype() }, | |||||
attrs: attrs, | |||||
name: name); | |||||
return op.outputs[0]; | |||||
if(tf.executing_eagerly()) | |||||
return convert_to_eager_tensor(value, dtype, shape, name, verify_shape: verify_shape, allow_broadcast: allow_broadcast); | |||||
else | |||||
return convert_to_graph_tensor(value, dtype, shape, name, verify_shape: verify_shape, allow_broadcast: allow_broadcast); | |||||
} | } | ||||
private static Tensor _eager_reshape(Tensor tensor, int[] shape, Context ctx) | private static Tensor _eager_reshape(Tensor tensor, int[] shape, Context ctx) | ||||
@@ -189,6 +135,70 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
static Tensor convert_to_eager_tensor(object value, | |||||
TF_DataType dtype, | |||||
TensorShape shape, | |||||
string name, | |||||
bool verify_shape, | |||||
bool allow_broadcast) | |||||
{ | |||||
var t = convert_to_eager_tensor(value, tf.Context, dtype: dtype); | |||||
if (shape == null) | |||||
return t; | |||||
if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims)) | |||||
return t; | |||||
if (verify_shape) | |||||
throw new TypeError($"Expected Tensor's shape: {shape}, got {t.shape}."); | |||||
var num_t = t.TensorShape.num_elements(); | |||||
if (num_t == shape.num_elements()) | |||||
return _eager_reshape(t, shape, tf.Context); | |||||
if (num_t == 1) | |||||
{ | |||||
if (t.dtype == dtypes.@bool) | |||||
throw new NotImplementedException(""); | |||||
else | |||||
return _eager_fill(shape, t, tf.Context); | |||||
} | |||||
throw new NotImplementedException(""); | |||||
} | |||||
static Tensor convert_to_graph_tensor(object value, | |||||
TF_DataType dtype, | |||||
TensorShape shape, | |||||
string name, | |||||
bool verify_shape, | |||||
bool allow_broadcast) | |||||
{ | |||||
Graph g = ops.get_default_graph(); | |||||
var tensor_value = new AttrValue(); | |||||
tensor_value.Tensor = tensor_util.make_tensor_proto(value, | |||||
dtype: dtype, | |||||
shape: shape, | |||||
verify_shape: verify_shape, | |||||
allow_broadcast: allow_broadcast); | |||||
var dtype_value = new AttrValue | |||||
{ | |||||
Type = tensor_value.Tensor.Dtype, | |||||
}; | |||||
var attrs = new Dictionary<string, AttrValue>(); | |||||
attrs["value"] = tensor_value; | |||||
attrs["dtype"] = dtype_value; | |||||
var op = g.create_op("Const", | |||||
new Tensor[0], | |||||
new TF_DataType[] { dtype_value.Type.as_tf_dtype() }, | |||||
attrs: attrs, | |||||
name: name); | |||||
return op.outputs[0]; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Function to convert TensorShape to Tensor. | /// Function to convert TensorShape to Tensor. | ||||
/// </summary> | /// </summary> | ||||
@@ -125,6 +125,12 @@ namespace Tensorflow | |||||
byte[] bytes = nd.ToByteArray(); | byte[] bytes = nd.ToByteArray(); | ||||
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); | tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); | ||||
} | } | ||||
else if (values is Tensor tensor && tensor.IsReferencedByNDArray) | |||||
{ | |||||
var len = tensor.itemsize * tensor.size; | |||||
byte[] bytes = tensor.BufferToArray(); | |||||
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); | |||||
} | |||||
else if (!values.GetType().IsArray) | else if (!values.GetType().IsArray) | ||||
{ | { | ||||
switch (values) | switch (values) | ||||
@@ -30,10 +30,10 @@ namespace Tensorflow | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
TensorShape shape = null, | TensorShape shape = null, | ||||
string name = "Const") | string name = "Const") | ||||
=> constant_op._constant_impl(value, | |||||
dtype, | |||||
shape, | |||||
name, | |||||
=> constant_op.constant(value, | |||||
dtype: dtype, | |||||
shape: shape, | |||||
name: name, | |||||
verify_shape: false, | verify_shape: false, | ||||
allow_broadcast: true); | allow_broadcast: true); | ||||
@@ -145,7 +145,10 @@ namespace Tensorflow | |||||
} | } | ||||
else if (value is Tensor tensor && tensor.IsReferencedByNDArray) | else if (value is Tensor tensor && tensor.IsReferencedByNDArray) | ||||
{ | { | ||||
return tensor; | |||||
if (tf.executing_eagerly()) | |||||
return tensor; | |||||
else | |||||
return constant_op.constant(tensor); | |||||
} | } | ||||
// graph mode | // graph mode | ||||
@@ -82,15 +82,14 @@ namespace TensorFlowNET.UnitTest | |||||
var result = sess.run(cropped); | var result = sess.run(cropped); | ||||
// check if cropped to 1x1 center was succesfull | // check if cropped to 1x1 center was succesfull | ||||
Assert.AreEqual(result.size, 1); | |||||
Assert.AreEqual(result.size, 1ul); | |||||
Assert.AreEqual(result[0, 0, 0, 0], 4f); | Assert.AreEqual(result[0, 0, 0, 0], 4f); | ||||
cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2); | cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2); | ||||
result = sess.run(cropped); | result = sess.run(cropped); | ||||
// check if flipped and no cropping occured | // check if flipped and no cropping occured | ||||
Assert.AreEqual(result.size, 16); | |||||
Assert.AreEqual(result.size, 16ul); | |||||
Assert.AreEqual(result[0, 0, 0, 0], 12f); | Assert.AreEqual(result[0, 0, 0, 0], 12f); | ||||
} | } | ||||
} | } | ||||
} | } | ||||