Browse Source

fix ndarray index.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
d3e212488f
9 changed files with 97 additions and 80 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
  2. +3
    -4
      src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
  3. +2
    -2
      src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs
  4. +4
    -4
      src/TensorFlowNET.Core/Operations/array_ops.cs
  5. +71
    -61
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  6. +6
    -0
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  7. +4
    -4
      src/TensorFlowNET.Core/Tensors/tf.constant.cs
  8. +4
    -1
      src/TensorFlowNET.Core/ops.cs
  9. +2
    -3
      test/TensorFlowNET.Graph.UnitTest/ImageTest.cs

+ 1
- 1
src/TensorFlowNET.Core/NumPy/NDArray.Index.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow.NumPy
{
get
{
return _tensor[index.Select(x => new Slice(x, x + 1)).ToArray()];
return _tensor[index.Select(x => new Slice(x, x + 1)).ToArray()];
}

set


+ 3
- 4
src/TensorFlowNET.Core/Operations/Initializers/Constant.cs View File

@@ -34,12 +34,11 @@ namespace Tensorflow.Operations.Initializers
if (args.DType == TF_DataType.DtInvalid)
args.DType = this.dtype;

if (!args.VerifyShape.HasValue)
args.VerifyShape = _verify_shape;
args.VerifyShape = _verify_shape;

return constant_op._constant_impl(value, args.DType, args.Shape,
return constant_op.constant(value, args.DType, args.Shape,
name: "Const",
verify_shape: args.VerifyShape.Value,
verify_shape: args.VerifyShape,
allow_broadcast: false);
}
}


+ 2
- 2
src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs View File

@@ -5,11 +5,11 @@
public string Name { get; set; }
public TensorShape Shape { get; set; }
public TF_DataType DType { get; set; }
public bool? VerifyShape { get; set; } = null;
public bool VerifyShape { get; set; }

public InitializerArgs(TensorShape shape,
TF_DataType dtype = TF_DataType.DtInvalid,
bool? verify_shape = null,
bool verify_shape = false,
string name = null)
{
Shape = shape;


+ 4
- 4
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -64,10 +64,10 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.DtInvalid,
int[] shape = null,
string name = "Const",
bool verify_shape = false) => constant_op._constant_impl(value,
dtype,
shape,
name,
bool verify_shape = false) => constant_op.constant(value,
dtype: dtype,
shape: shape,
name: name,
verify_shape: verify_shape,
allow_broadcast: false);



+ 71
- 61
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -37,68 +37,14 @@ namespace Tensorflow
/// <param name="shape">Optional dimensions of resulting tensor.</param>
/// <param name="name">Optional name for the tensor.</param>
/// <returns></returns>
public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, string name = "Const")
public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid,
int[] shape = null, bool verify_shape = false,
bool allow_broadcast = true, string name = "Const")
{
return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true);
}

/// <param name="verify_shape">Boolean that enables verification of a shape of values.</param>
public static Tensor _constant_impl(object value,
TF_DataType dtype,
TensorShape shape,
string name,
bool verify_shape,
bool allow_broadcast)
{
if (tf.Context.executing_eagerly())
{
var t = convert_to_eager_tensor(value, tf.Context, dtype: dtype);
if (shape == null)
return t;

if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims))
return t;

if (verify_shape)
throw new TypeError($"Expected Tensor's shape: {shape}, got {t.shape}.");

var num_t = t.TensorShape.num_elements();
if (num_t == shape.num_elements())
return _eager_reshape(t, shape, tf.Context);
if (num_t == 1)
{
if (t.dtype == dtypes.@bool)
throw new NotImplementedException("");
else
return _eager_fill(shape, t, tf.Context);
}
}

// graph mode
Graph g = ops.get_default_graph();
var tensor_value = new AttrValue();
tensor_value.Tensor = tensor_util.make_tensor_proto(value,
dtype: dtype,
shape: shape,
verify_shape: verify_shape,
allow_broadcast: allow_broadcast);

var dtype_value = new AttrValue
{
Type = tensor_value.Tensor.Dtype,
};

var attrs = new Dictionary<string, AttrValue>();
attrs["value"] = tensor_value;
attrs["dtype"] = dtype_value;

var op = g.create_op("Const",
new Tensor[0],
new TF_DataType[] { dtype_value.Type.as_tf_dtype() },
attrs: attrs,
name: name);

return op.outputs[0];
if(tf.executing_eagerly())
return convert_to_eager_tensor(value, dtype, shape, name, verify_shape: verify_shape, allow_broadcast: allow_broadcast);
else
return convert_to_graph_tensor(value, dtype, shape, name, verify_shape: verify_shape, allow_broadcast: allow_broadcast);
}

private static Tensor _eager_reshape(Tensor tensor, int[] shape, Context ctx)
@@ -189,6 +135,70 @@ namespace Tensorflow
}
}

static Tensor convert_to_eager_tensor(object value,
TF_DataType dtype,
TensorShape shape,
string name,
bool verify_shape,
bool allow_broadcast)
{
var t = convert_to_eager_tensor(value, tf.Context, dtype: dtype);
if (shape == null)
return t;

if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims))
return t;

if (verify_shape)
throw new TypeError($"Expected Tensor's shape: {shape}, got {t.shape}.");

var num_t = t.TensorShape.num_elements();
if (num_t == shape.num_elements())
return _eager_reshape(t, shape, tf.Context);
if (num_t == 1)
{
if (t.dtype == dtypes.@bool)
throw new NotImplementedException("");
else
return _eager_fill(shape, t, tf.Context);
}

throw new NotImplementedException("");
}

static Tensor convert_to_graph_tensor(object value,
TF_DataType dtype,
TensorShape shape,
string name,
bool verify_shape,
bool allow_broadcast)
{
Graph g = ops.get_default_graph();
var tensor_value = new AttrValue();
tensor_value.Tensor = tensor_util.make_tensor_proto(value,
dtype: dtype,
shape: shape,
verify_shape: verify_shape,
allow_broadcast: allow_broadcast);

var dtype_value = new AttrValue
{
Type = tensor_value.Tensor.Dtype,
};

var attrs = new Dictionary<string, AttrValue>();
attrs["value"] = tensor_value;
attrs["dtype"] = dtype_value;

var op = g.create_op("Const",
new Tensor[0],
new TF_DataType[] { dtype_value.Type.as_tf_dtype() },
attrs: attrs,
name: name);

return op.outputs[0];
}

/// <summary>
/// Function to convert TensorShape to Tensor.
/// </summary>


+ 6
- 0
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -125,6 +125,12 @@ namespace Tensorflow
byte[] bytes = nd.ToByteArray();
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
}
else if (values is Tensor tensor && tensor.IsReferencedByNDArray)
{
var len = tensor.itemsize * tensor.size;
byte[] bytes = tensor.BufferToArray();
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
}
else if (!values.GetType().IsArray)
{
switch (values)


+ 4
- 4
src/TensorFlowNET.Core/Tensors/tf.constant.cs View File

@@ -30,10 +30,10 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.DtInvalid,
TensorShape shape = null,
string name = "Const")
=> constant_op._constant_impl(value,
dtype,
shape,
name,
=> constant_op.constant(value,
dtype: dtype,
shape: shape,
name: name,
verify_shape: false,
allow_broadcast: true);



+ 4
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -145,7 +145,10 @@ namespace Tensorflow
}
else if (value is Tensor tensor && tensor.IsReferencedByNDArray)
{
return tensor;
if (tf.executing_eagerly())
return tensor;
else
return constant_op.constant(tensor);
}

// graph mode


+ 2
- 3
test/TensorFlowNET.Graph.UnitTest/ImageTest.cs View File

@@ -82,15 +82,14 @@ namespace TensorFlowNET.UnitTest

var result = sess.run(cropped);
// check if cropped to 1x1 center was succesfull
Assert.AreEqual(result.size, 1);
Assert.AreEqual(result.size, 1ul);
Assert.AreEqual(result[0, 0, 0, 0], 4f);

cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2);
result = sess.run(cropped);
// check if flipped and no cropping occured
Assert.AreEqual(result.size, 16);
Assert.AreEqual(result.size, 16ul);
Assert.AreEqual(result[0, 0, 0, 0], 12f);

}
}
}


Loading…
Cancel
Save