diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
index 8ab82278..316ee024 100644
--- a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
+++ b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
@@ -25,7 +25,7 @@ namespace Tensorflow.NumPy
{
get
{
- return _tensor[index.Select(x => new Slice(x, x + 1)).ToArray()];
+ return _tensor[index.Select(x => new Slice(x, x + 1)).ToArray()];
}
set
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
index cf230978..fdcb5aff 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
@@ -34,12 +34,11 @@ namespace Tensorflow.Operations.Initializers
if (args.DType == TF_DataType.DtInvalid)
args.DType = this.dtype;
- if (!args.VerifyShape.HasValue)
- args.VerifyShape = _verify_shape;
+ args.VerifyShape = _verify_shape;
- return constant_op._constant_impl(value, args.DType, args.Shape,
+ return constant_op.constant(value, args.DType, args.Shape,
name: "Const",
- verify_shape: args.VerifyShape.Value,
+ verify_shape: args.VerifyShape,
allow_broadcast: false);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs b/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs
index 10702ece..756f33a7 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs
@@ -5,11 +5,11 @@
public string Name { get; set; }
public TensorShape Shape { get; set; }
public TF_DataType DType { get; set; }
- public bool? VerifyShape { get; set; } = null;
+ public bool VerifyShape { get; set; }
public InitializerArgs(TensorShape shape,
TF_DataType dtype = TF_DataType.DtInvalid,
- bool? verify_shape = null,
+ bool verify_shape = false,
string name = null)
{
Shape = shape;
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index be10541e..9e7290ed 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -64,10 +64,10 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.DtInvalid,
int[] shape = null,
string name = "Const",
- bool verify_shape = false) => constant_op._constant_impl(value,
- dtype,
- shape,
- name,
+ bool verify_shape = false) => constant_op.constant(value,
+ dtype: dtype,
+ shape: shape,
+ name: name,
verify_shape: verify_shape,
allow_broadcast: false);
diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs
index a8870252..cf6c76a2 100644
--- a/src/TensorFlowNET.Core/Tensors/constant_op.cs
+++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs
@@ -37,68 +37,14 @@ namespace Tensorflow
/// Optional dimensions of resulting tensor.
/// Optional name for the tensor.
///
- public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, string name = "Const")
+ public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid,
+ int[] shape = null, bool verify_shape = false,
+ bool allow_broadcast = true, string name = "Const")
{
- return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true);
- }
-
- /// Boolean that enables verification of a shape of values.
- public static Tensor _constant_impl(object value,
- TF_DataType dtype,
- TensorShape shape,
- string name,
- bool verify_shape,
- bool allow_broadcast)
- {
- if (tf.Context.executing_eagerly())
- {
- var t = convert_to_eager_tensor(value, tf.Context, dtype: dtype);
- if (shape == null)
- return t;
-
- if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims))
- return t;
-
- if (verify_shape)
- throw new TypeError($"Expected Tensor's shape: {shape}, got {t.shape}.");
-
- var num_t = t.TensorShape.num_elements();
- if (num_t == shape.num_elements())
- return _eager_reshape(t, shape, tf.Context);
- if (num_t == 1)
- {
- if (t.dtype == dtypes.@bool)
- throw new NotImplementedException("");
- else
- return _eager_fill(shape, t, tf.Context);
- }
- }
-
- // graph mode
- Graph g = ops.get_default_graph();
- var tensor_value = new AttrValue();
- tensor_value.Tensor = tensor_util.make_tensor_proto(value,
- dtype: dtype,
- shape: shape,
- verify_shape: verify_shape,
- allow_broadcast: allow_broadcast);
-
- var dtype_value = new AttrValue
- {
- Type = tensor_value.Tensor.Dtype,
- };
-
- var attrs = new Dictionary();
- attrs["value"] = tensor_value;
- attrs["dtype"] = dtype_value;
-
- var op = g.create_op("Const",
- new Tensor[0],
- new TF_DataType[] { dtype_value.Type.as_tf_dtype() },
- attrs: attrs,
- name: name);
-
- return op.outputs[0];
+ if(tf.executing_eagerly())
+ return convert_to_eager_tensor(value, dtype, shape, name, verify_shape: verify_shape, allow_broadcast: allow_broadcast);
+ else
+ return convert_to_graph_tensor(value, dtype, shape, name, verify_shape: verify_shape, allow_broadcast: allow_broadcast);
}
private static Tensor _eager_reshape(Tensor tensor, int[] shape, Context ctx)
@@ -189,6 +135,70 @@ namespace Tensorflow
}
}
+ static Tensor convert_to_eager_tensor(object value,
+ TF_DataType dtype,
+ TensorShape shape,
+ string name,
+ bool verify_shape,
+ bool allow_broadcast)
+ {
+ var t = convert_to_eager_tensor(value, tf.Context, dtype: dtype);
+ if (shape == null)
+ return t;
+
+ if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims))
+ return t;
+
+ if (verify_shape)
+ throw new TypeError($"Expected Tensor's shape: {shape}, got {t.shape}.");
+
+ var num_t = t.TensorShape.num_elements();
+ if (num_t == shape.num_elements())
+ return _eager_reshape(t, shape, tf.Context);
+ if (num_t == 1)
+ {
+ if (t.dtype == dtypes.@bool)
+ throw new NotImplementedException("");
+ else
+ return _eager_fill(shape, t, tf.Context);
+ }
+
+ throw new NotImplementedException("");
+ }
+
+ static Tensor convert_to_graph_tensor(object value,
+ TF_DataType dtype,
+ TensorShape shape,
+ string name,
+ bool verify_shape,
+ bool allow_broadcast)
+ {
+ Graph g = ops.get_default_graph();
+ var tensor_value = new AttrValue();
+ tensor_value.Tensor = tensor_util.make_tensor_proto(value,
+ dtype: dtype,
+ shape: shape,
+ verify_shape: verify_shape,
+ allow_broadcast: allow_broadcast);
+
+ var dtype_value = new AttrValue
+ {
+ Type = tensor_value.Tensor.Dtype,
+ };
+
+ var attrs = new Dictionary();
+ attrs["value"] = tensor_value;
+ attrs["dtype"] = dtype_value;
+
+ var op = g.create_op("Const",
+ new Tensor[0],
+ new TF_DataType[] { dtype_value.Type.as_tf_dtype() },
+ attrs: attrs,
+ name: name);
+
+ return op.outputs[0];
+ }
+
///
/// Function to convert TensorShape to Tensor.
///
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index d97ea1da..5ad8bc9b 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -125,6 +125,12 @@ namespace Tensorflow
byte[] bytes = nd.ToByteArray();
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
}
+ else if (values is Tensor tensor && tensor.IsReferencedByNDArray)
+ {
+ var len = tensor.itemsize * tensor.size;
+ byte[] bytes = tensor.BufferToArray();
+ tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
+ }
else if (!values.GetType().IsArray)
{
switch (values)
diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/Tensors/tf.constant.cs
index 291e8d0c..3bf6614c 100644
--- a/src/TensorFlowNET.Core/Tensors/tf.constant.cs
+++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs
@@ -30,10 +30,10 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.DtInvalid,
TensorShape shape = null,
string name = "Const")
- => constant_op._constant_impl(value,
- dtype,
- shape,
- name,
+ => constant_op.constant(value,
+ dtype: dtype,
+ shape: shape,
+ name: name,
verify_shape: false,
allow_broadcast: true);
diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs
index 5e2e8287..07697d5f 100644
--- a/src/TensorFlowNET.Core/ops.cs
+++ b/src/TensorFlowNET.Core/ops.cs
@@ -145,7 +145,10 @@ namespace Tensorflow
}
else if (value is Tensor tensor && tensor.IsReferencedByNDArray)
{
- return tensor;
+ if (tf.executing_eagerly())
+ return tensor;
+ else
+ return constant_op.constant(tensor);
}
// graph mode
diff --git a/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
index 39a004f0..a53635d4 100644
--- a/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
+++ b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
@@ -82,15 +82,14 @@ namespace TensorFlowNET.UnitTest
var result = sess.run(cropped);
// check if cropped to 1x1 center was succesfull
- Assert.AreEqual(result.size, 1);
+ Assert.AreEqual(result.size, 1ul);
Assert.AreEqual(result[0, 0, 0, 0], 4f);
cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2);
result = sess.run(cropped);
// check if flipped and no cropping occured
- Assert.AreEqual(result.size, 16);
+ Assert.AreEqual(result.size, 16ul);
Assert.AreEqual(result[0, 0, 0, 0], 12f);
-
}
}
}