@@ -152,7 +152,7 @@ namespace Tensorflow | |||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <param name="conjugate"></param> | /// <param name="conjugate"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensor transpose<T1>(T1 a, int[] perm = null, string name = "transpose", bool conjugate = false) | |||||
public Tensor transpose<T1>(T1 a, TensorShape perm = null, string name = "transpose", bool conjugate = false) | |||||
=> array_ops.transpose(a, perm, name, conjugate); | => array_ops.transpose(a, perm, name, conjugate); | ||||
/// <summary> | /// <summary> | ||||
@@ -779,7 +779,22 @@ namespace Tensorflow | |||||
return gen_array_ops.gather_v2(@params, indices, axis, name: name); | return gen_array_ops.gather_v2(@params, indices, axis, name: name); | ||||
} | } | ||||
public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false) | |||||
public static Tensor transpose<T1>(T1 a, TensorShape perm, string name = "transpose", bool conjugate = false) | |||||
{ | |||||
return tf_with(ops.name_scope(name, "transpose", new { a }), scope => | |||||
{ | |||||
var a_tensor = ops.convert_to_tensor(a); | |||||
if(perm == null) | |||||
{ | |||||
var rank = a_tensor.rank; | |||||
perm = range(0, rank).OrderByDescending(x => x).ToArray(); | |||||
} | |||||
return gen_array_ops.transpose(a_tensor, perm, name: scope); | |||||
}); | |||||
} | |||||
public static Tensor transpose(Tensor a, Tensor perm, string name = "transpose", bool conjugate = false) | |||||
{ | { | ||||
return tf_with(ops.name_scope(name, "transpose", new { a }), scope => | return tf_with(ops.name_scope(name, "transpose", new { a }), scope => | ||||
{ | { | ||||
@@ -531,7 +531,7 @@ namespace Tensorflow | |||||
input, multiples).FirstOrDefault(), | input, multiples).FirstOrDefault(), | ||||
input); | input); | ||||
public static Tensor transpose<T1, T2>(T1 x, T2 perm, string name = null) | |||||
public static Tensor transpose<T1>(Tensor x, T1 perm, string name = null) | |||||
{ | { | ||||
if (tf.Context.executing_eagerly()) | if (tf.Context.executing_eagerly()) | ||||
{ | { | ||||
@@ -1,4 +1,4 @@ | |||||
/***************************************************************************** | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | ||||
Licensed under the Apache License, Version 2.0 (the "License"); | Licensed under the Apache License, Version 2.0 (the "License"); | ||||
@@ -619,6 +619,16 @@ namespace Tensorflow | |||||
public static Tensor squared_difference(Tensor x, Tensor y, string name = null) | public static Tensor squared_difference(Tensor x, Tensor y, string name = null) | ||||
{ | { | ||||
if (tf.Context.executing_eagerly()) | |||||
{ | |||||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
"SquaredDifference", name, | |||||
null, | |||||
x,y); | |||||
return results[0]; | |||||
} | |||||
var _op = tf.OpDefLib._apply_op_helper("SquaredDifference", name, args: new { x, y, name }); | var _op = tf.OpDefLib._apply_op_helper("SquaredDifference", name, args: new { x, y, name }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -1210,4 +1220,4 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
} | } | ||||
} | |||||
} |
@@ -1,4 +1,4 @@ | |||||
/***************************************************************************** | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | ||||
Licensed under the Apache License, Version 2.0 (the "License"); | Licensed under the Apache License, Version 2.0 (the "License"); | ||||
@@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Utils | |||||
public static Tensor compute_weighted_loss(Tensor losses, Tensor sample_weight = null, string reduction = null, string name = null) | public static Tensor compute_weighted_loss(Tensor losses, Tensor sample_weight = null, string reduction = null, string name = null) | ||||
{ | { | ||||
if (sample_weight == null) | if (sample_weight == null) | ||||
sample_weight = tf.constant(1.0f); | |||||
sample_weight = losses.dtype == TF_DataType.TF_DOUBLE ? tf.constant(1.0) : tf.constant(1.0f); | |||||
var weighted_losses = scale_losses_by_sample_weight(losses, sample_weight); | var weighted_losses = scale_losses_by_sample_weight(losses, sample_weight); | ||||
// Apply reduction function to the individual weighted losses. | // Apply reduction function to the individual weighted losses. | ||||
var loss = reduce_weighted_loss(weighted_losses, reduction); | var loss = reduce_weighted_loss(weighted_losses, reduction); | ||||
@@ -83,6 +83,26 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
Assert.AreEqual(nd[2], x[2].numpy()); | Assert.AreEqual(nd[2], x[2].numpy()); | ||||
} | } | ||||
[TestMethod, Ignore] | |||||
public void TypeMismatchedSliceAssign() | |||||
{ | |||||
NDArray intNd = new int[] | |||||
{ | |||||
1, -2, 3 | |||||
}; | |||||
NDArray doubleNd = new double[] | |||||
{ | |||||
-5, 6, -7 | |||||
}; | |||||
var x = tf.Variable(doubleNd); | |||||
var slice = x[":"]; | |||||
Assert.ThrowsException<System.Exception>( | |||||
// this statement exit without throwing any exception but the "test execution summary" seems not able to detect that. | |||||
() => slice.assign(intNd) | |||||
); | |||||
} | |||||
[TestMethod] | [TestMethod] | ||||
public void Accumulation() | public void Accumulation() | ||||
{ | { | ||||
@@ -11,14 +11,70 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
[TestMethod] | [TestMethod] | ||||
public void TransposeTest() | public void TransposeTest() | ||||
{ | { | ||||
var a = tf.constant(np.array(new[, , ,] { { { { 1, 11, 2, 22 } }, { { 3, 33, 4, 44 } } }, | |||||
{ { { 5, 55, 6, 66 } }, { { 7, 77, 8, 88 } } } })); | |||||
var b = tf.transpose(a, new[] { 3, 1, 2, 0 }); | |||||
var transpose_a = tf.constant(np.array(new[, , ,] { { { { 1, 5 } }, { { 3, 7 } } }, | |||||
{ { { 11, 55 } }, { { 33, 77 } } }, { { { 2, 6 } }, { { 4, 8 } } }, | |||||
{ { { 22, 66 } }, { { 44, 88 } } } })); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 4, 2, 1, 2 }, b.shape)); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(transpose_a.numpy().ToArray<int>(), b.numpy().ToArray<int>())); | |||||
// https://www.tensorflow.org/api_docs/python/tf/transpose#for_example_2 | |||||
var x = tf.constant(new int[,] | |||||
{ | |||||
{ 1, 2, 3 }, | |||||
{ 4, 5, 6 } | |||||
}); | |||||
var transpose_x = tf.transpose(x); | |||||
Assert.AreEqual(new[] { 1, 4 }, transpose_x[0].numpy()); | |||||
Assert.AreEqual(new[] { 2, 5 }, transpose_x[1].numpy()); | |||||
Assert.AreEqual(new[] { 3, 6 }, transpose_x[2].numpy()); | |||||
#region constant a | |||||
var a = tf.constant(np.array(new[, , ,] | |||||
{ | |||||
{ | |||||
{ | |||||
{ 1, 11, 2, 22 } | |||||
}, | |||||
{ | |||||
{ 3, 33, 4, 44 } | |||||
} | |||||
}, | |||||
{ | |||||
{ | |||||
{ 5, 55, 6, 66 } | |||||
}, | |||||
{ | |||||
{ 7, 77, 8, 88 } | |||||
} | |||||
} | |||||
})); | |||||
#endregion | |||||
var actual_transposed_a = tf.transpose(a, new[] { 3, 1, 2, 0 }); | |||||
#region constant transpose_a | |||||
var expected_transposed_a = tf.constant(np.array(new[, , ,] | |||||
{ | |||||
{ | |||||
{ { 1, 5 } }, { { 3, 7 } } | |||||
}, | |||||
{ | |||||
{ { 11, 55 } }, { { 33, 77 } } | |||||
}, | |||||
{ | |||||
{ | |||||
{ 2, 6 } | |||||
}, | |||||
{ | |||||
{ 4, 8 } | |||||
} | |||||
}, | |||||
{ | |||||
{ | |||||
{ 22, 66 } | |||||
}, | |||||
{ | |||||
{ 44, 88 } | |||||
} | |||||
} | |||||
})); | |||||
#endregion | |||||
Assert.AreEqual((4, 2, 1, 2 ), actual_transposed_a.TensorShape); | |||||
Assert.AreEqual(expected_transposed_a.numpy(), actual_transposed_a.numpy()); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||