Browse Source

Implement map_fn and other fixes

tags/v0.20
Brendan Mulcahy Haiping Chen 5 years ago
parent
commit
67bfd365f4
8 changed files with 191 additions and 62 deletions
  1. +7
    -0
      src/TensorFlowNET.Core/Interfaces/IFromMergeVars.cs
  2. +5
    -7
      src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
  3. +13
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  5. +51
    -39
      src/TensorFlowNET.Core/Operations/functional_ops.cs
  6. +110
    -10
      src/TensorFlowNET.Core/Operations/map_fn.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  8. +3
    -3
      test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs

+ 7
- 0
src/TensorFlowNET.Core/Interfaces/IFromMergeVars.cs View File

@@ -0,0 +1,7 @@
namespace Tensorflow
{
public interface IFromMergeVars<T>
{
T FromMergeVars(ITensorOrTensorArray[] mergeVars);
}
}

+ 5
- 7
src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs View File

@@ -118,7 +118,7 @@ namespace Tensorflow.Operations
Func<LoopVar<TItem>, LoopVar<TItem>> body,
LoopVar<TItem> loop_vars,
TensorShape[] shape_invariants,
bool return_same_structure)
bool return_same_structure) where TItem : IFromMergeVars<TItem>, new()
{
// Keep original_loop_vars to identify which are TensorArrays
var original_loop_vars = loop_vars;
@@ -178,7 +178,7 @@ namespace Tensorflow.Operations
Func<LoopVar<TItem>, LoopVar<TItem>> body,
LoopVar<TItem> original_loop_vars,
Tensor[] loop_vars,
TensorShape[] shape_invariants)
TensorShape[] shape_invariants) where TItem : IFromMergeVars<TItem>, new()
{
var flat_loop_vars = nest.flatten2(original_loop_vars)
.Select(x => (ITensorOrTensorArray)x)
@@ -235,11 +235,9 @@ namespace Tensorflow.Operations

// Build the graph for pred.
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
//var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays, expand_composites: true);
var packed_vars = new LoopVar<TItem>((Tensor)merge_vars_with_tensor_arrays[0],
(TItem)(object)new BodyItemInRnnWhileLoop((Tensor)merge_vars_with_tensor_arrays[1],
new[] { (TensorArray)merge_vars_with_tensor_arrays[2] },
(Tensor)merge_vars_with_tensor_arrays[3]));
var packed_vars = new LoopVar<TItem>(
(Tensor) merge_vars_with_tensor_arrays[0],
new TItem().FromMergeVars(merge_vars_with_tensor_arrays));
var pp = pred(packed_vars);
var c = ops.convert_to_tensor(pp);
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");


+ 13
- 1
src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs View File

@@ -4,7 +4,7 @@ using System.Text;

namespace Tensorflow.Operations
{
internal class BodyItemInRnnWhileLoop : ICanBeFlattened, IPackable<BodyItemInRnnWhileLoop>
internal class BodyItemInRnnWhileLoop : ICanBeFlattened, IPackable<BodyItemInRnnWhileLoop>, IFromMergeVars<BodyItemInRnnWhileLoop>
{
/// <summary>
/// int32 scalar Tensor.
@@ -19,6 +19,10 @@ namespace Tensorflow.Operations
/// </summary>
public Tensor state { get; set; }

public BodyItemInRnnWhileLoop()
{
}

public BodyItemInRnnWhileLoop(Tensor time, TensorArray[] output_ta_t, Tensor state)
{
this.time = time;
@@ -45,5 +49,13 @@ namespace Tensorflow.Operations

return new BodyItemInRnnWhileLoop(time, output_ta_t, state);
}

public BodyItemInRnnWhileLoop FromMergeVars(ITensorOrTensorArray[] mergeVars)
{
time = (Tensor) mergeVars[1];
output_ta_t = new[] {(TensorArray) mergeVars[2]};
state = (Tensor)mergeVars[3];
return this;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -625,7 +625,7 @@ namespace Tensorflow
bool swap_memory = false,
string name = null,
Tensor maximum_iterations = null,
bool return_same_structure = false)
bool return_same_structure = false) where TItem : IFromMergeVars<TItem>, new()
{
return tf_with(ops.name_scope(name, "while", loop_vars), scope =>
{


+ 51
- 39
src/TensorFlowNET.Core/Operations/functional_ops.cs View File

@@ -39,12 +39,12 @@ namespace Tensorflow
{
bool input_is_sequence = nest.is_sequence(elems);

List<Tensor> input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x) : new List<Tensor> {x};
Tensor input_pack(List<Tensor> x) => input_is_sequence ? (Tensor)nest.pack_sequence_as(elems, x) : x[0];
Tensor[] input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x).ToArray() : new [] {x};
Tensor input_pack(Tensor[] x) => input_is_sequence ? (Tensor)nest.pack_sequence_as(elems, x) : x[0];

bool output_is_sequence;
Func<Tensor, List<Tensor>> output_flatten;
Func<List<Tensor>, Tensor> output_pack;
Func<Tensor, Tensor[]> output_flatten;
Func<Tensor[], Tensor> output_pack;
if (initializer == null)
{
output_is_sequence = input_is_sequence;
@@ -54,31 +54,31 @@ namespace Tensorflow
else
{
output_is_sequence = nest.is_sequence(initializer);
output_flatten = (x) => output_is_sequence ? nest.flatten(x) : new List<Tensor> {x};
output_flatten = (x) => output_is_sequence ? nest.flatten(x).ToArray() : new [] {x};
output_pack = (x) => output_is_sequence ? (Tensor)nest.pack_sequence_as(initializer, x) : x[0];
}

var elems_flat = input_flatten(elems);

bool in_graph_mode = true; // todo !context.executing_eagerly()
bool in_graph_mode = tf.context.executing_eagerly();

return tf_with(ops.name_scope(name, "scan", new { elems_flat }), scope =>
{
// todo tf.net doesn't expose .caching_device
//if (in_graph_mode)
//{
// // Any get_variable calls in fn will cache the first call locally
// // and not issue repeated network I/O requests for each iteration.
// var varscope = variable_scope.get_variable_scope();
// bool varscope_caching_device_was_none = false;
// if (varscope.caching_device = null)
// {
// // varscope.set_caching_device(lambda op: op.device)
// // varscope_caching_device_was_none = True
// }
//}
if (in_graph_mode)
{
// todo tf.net doesn't expose .caching_device
//// Any get_variable calls in fn will cache the first call locally
//// and not issue repeated network I/O requests for each iteration.
//var varscope = variable_scope.get_variable_scope();
//bool varscope_caching_device_was_none = false;
//if (varscope.caching_device = null)
//{
// // varscope.set_caching_device(lambda op: op.device)
// // varscope_caching_device_was_none = True
//}
}

elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")).ToList();
elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")).ToArray();

var n = tensor_shape.dimension_value(elems_flat[0].shape[0]);

@@ -100,17 +100,17 @@ namespace Tensorflow
elems_ta[index].unstack(elems_flat[index]);
}

List<Tensor> a_flat;
Tensor[] a_flat;
int i;
if (initializer == null)
{
a_flat = elems_ta.Select(elem => elem.read(tf.constant(reverse ? n - 1 : 0))).ToList();
a_flat = elems_ta.Select(elem => elem.read(tf.constant(reverse ? n - 1 : 0))).ToArray();
i = 1;
}
else
{
List<Tensor> initializer_flat = output_flatten(initializer);
a_flat = initializer_flat.Select(init => ops.convert_to_tensor(init)).ToList();
Tensor[] initializer_flat = output_flatten(initializer);
a_flat = initializer_flat.Select(init => ops.convert_to_tensor(init)).ToArray();
i = 0;
}

@@ -119,11 +119,11 @@ namespace Tensorflow
size: tf.constant(n),
element_shape: infer_shape ? init.shape : null,
dynamic_size: false,
infer_shape: infer_shape)).ToList();
infer_shape: infer_shape)).ToArray();

if (initializer == null)
{
for (int index = 0; index < accs_ta.Count; index++)
for (int index = 0; index < accs_ta.Length; index++)
{
accs_ta[index].write(tf.constant(reverse ? n - 1 : 0), a_flat[index]);
}
@@ -131,14 +131,14 @@ namespace Tensorflow

BodyItem compute(BodyItem item)
{
var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(item.I))).ToList());
var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(item.I)).ToArray());
var packed_a = output_pack(item.A_Flat);
var a_out = fn(packed_a, packed_elems);

var flat_a_out = output_flatten(a_out);
for (int j = 0; j < item.Accs_ta.Count; j++)
for (int j = 0; j < item.Accs_ta.Length; j++)
{
item.Accs_ta[j].write(tf.constant(i), flat_a_out[j]);
item.Accs_ta[j].write(item.I, flat_a_out[j]);
}

var next_i = reverse ? item.I - 1 : item.I + 1;
@@ -150,12 +150,12 @@ namespace Tensorflow
if (reverse)
{
initial_i = n - 1 - i;
condition = x => tf.constant(x.I >= 0);
condition = x => x.I >= 0;
}
else
{
initial_i = i;
condition = x => tf.constant(x.I < n);
condition = x => x.I < n;
}

BodyItem bodyItem =
@@ -168,7 +168,7 @@ namespace Tensorflow
swap_memory: swap_memory,
maximum_iterations: tf.constant(n));

var results_flat = bodyItem.Accs_ta.Select(r => r.stack()).ToList();
var results_flat = bodyItem.Accs_ta.Select(r => r.stack()).ToArray();

var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].TensorShape.with_rank_at_least(1).dims[0]));
@@ -179,7 +179,7 @@ namespace Tensorflow

foreach (Tensor r in results_flat)
{
r.set_shape(new TensorShape(n_static).concatenate(r.TensorShape[new Slice("1:")]));
r.set_shape(new TensorShape(n_static).concatenate(r.dims.Skip(1).ToArray()));
}

// todo get working when the above caching_device is fixed
@@ -191,13 +191,17 @@ namespace Tensorflow
});
}

internal class BodyItem : ICanBeFlattened, IPackable<BodyItem>
internal class BodyItem : ICanBeFlattened, IPackable<BodyItem>, IFromMergeVars<BodyItem>
{
public Tensor I { get; set; }
public List<Tensor> A_Flat { get; set; }
public List<TensorArray> Accs_ta { get; set; }
public Tensor[] A_Flat { get; set; }
public TensorArray[] Accs_ta { get; set; }

public BodyItem()
{
}

public BodyItem(Tensor i, List<Tensor> a_flat, List<TensorArray> accs_ta)
public BodyItem(Tensor i, Tensor[] a_flat, TensorArray[] accs_ta)
{
I = i;
A_Flat = a_flat;
@@ -215,11 +219,19 @@ namespace Tensorflow
public BodyItem Pack(object[] sequences)
{
I = sequences[0] as Tensor;
A_Flat = new List<Tensor> { sequences[1] as Tensor };
Accs_ta = new List<TensorArray> { sequences[2] as TensorArray };
A_Flat = new [] { sequences[1] as Tensor };
Accs_ta = new [] { sequences[2] as TensorArray };
return new BodyItem(I, A_Flat, Accs_ta);
}

public BodyItem FromMergeVars(ITensorOrTensorArray[] merge_vars)
{
I = (Tensor)merge_vars[1];
A_Flat = new [] {(Tensor) merge_vars[2]};
Accs_ta = new [] {(TensorArray) merge_vars[3]};
return this;
}
}
}
}


+ 110
- 10
src/TensorFlowNET.Core/Operations/map_fn.cs View File

@@ -2,7 +2,10 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using NumSharp;
using Tensorflow.Framework;
using Tensorflow.Operations;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -30,10 +33,40 @@ namespace Tensorflow
bool infer_shape = true,
string name = null)
{
var elems_flat = new[] { elems };
tf_with(ops.name_scope(name, "map", elems_flat), delegate
bool input_is_sequence = nest.is_sequence(elems);
Tensor[] input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x).ToArray() : new [] {x};
Tensor input_pack(Tensor[] x) => input_is_sequence ? (Tensor)nest.pack_sequence_as(elems, x) : x[0];

bool output_is_sequence;
Func<Tensor, Tensor[]> output_flatten;
Func<Tensor[], Tensor> output_pack;
if (dtype == TF_DataType.DtInvalid)
{
output_is_sequence = input_is_sequence;
output_flatten = input_flatten;
output_pack = input_pack;
}
else
{
output_is_sequence = nest.is_sequence(dtype);
output_flatten = (x) => output_is_sequence ? nest.flatten(x).ToArray() : new [] {x};
output_pack = (x) => output_is_sequence ? (Tensor)nest.pack_sequence_as(dtype, x) : x[0];
}

var elems_flat = input_flatten(elems);
return tf_with(ops.name_scope(name, "map", elems_flat), delegate
{
var varscope = tf.get_variable_scope();
//if in_graph_mode:
//# Any get_variable calls in fn will cache the first call locally
//# and not issue repeated network I/O requests for each iteration.
//varscope = vs.get_variable_scope()
//varscope_caching_device_was_none = False
//if varscope.caching_device is None:
// # TODO(ebrevdo): Change to using colocate_with here and in other
// # methods.
// varscope.set_caching_device(lambda op: op.device)
// varscope_caching_device_was_none = True

elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem"))
.ToArray();

@@ -65,22 +98,89 @@ namespace Tensorflow
dynamic_size: false,
infer_shape: infer_shape)).ToArray();

/*Func<Tensor, TensorArray> compute = (i, tas) =>

BodyItem compute(BodyItem item)
{
throw new NotImplementedException("");
};
var packed_values = input_pack(elems_ta.Select(elem_ta => elem_ta.read(item.I)).ToArray());
var packed_fn_values = fn(packed_values);
//nest.assert_same_structure(dtype or elems, packed_fn_values)

var flat_fn_values = output_flatten(packed_fn_values);
for (int j = 0; j < item.Accs_ta.Length; j++)
{
item.Accs_ta[j].write(item.I, flat_fn_values[j]);
}

return new BodyItem(item.I + 1, item.Accs_ta);
}

var r_a = control_flow_ops.while_loop(
(i, _) => i < n,
(x) => x.I < n,
compute,
new[] { i, accs_ta },
new BodyItem(i, accs_ta),
parallel_iterations: parallel_iterations,
back_prop: back_prop,
swap_memory: swap_memory,
maximum_iterations: n);*/
maximum_iterations: tf.constant(n));
var results_flat = r_a.Accs_ta.Select(r => r.stack()).ToArray();

var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].TensorShape.with_rank_at_least(1).dims[0]));
foreach (var elem in elems_flat.Skip(1))
{
n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.TensorShape.with_rank_at_least(1).dims[0])));
}

foreach (Tensor r in results_flat)
{
r.set_shape(new TensorShape(n_static).concatenate(r.dims.Skip(1).ToArray()));
}

// todo get working when the above caching_device is fixed
//if (in_graph_mode && varscope_caching_device_was_none) {
// varscope.set_caching_device(None);
//}

return output_pack(results_flat);
});
}

internal class BodyItem : ICanBeFlattened, IPackable<BodyItem>, IFromMergeVars<BodyItem>
{
public Tensor I { get; set; }
public TensorArray[] Accs_ta { get; set; }

throw new NotImplementedException("");
public BodyItem()
{
}

public BodyItem(Tensor i, TensorArray[] accs_ta)
{
I = i;
Accs_ta = accs_ta;
}

public object[] Flatten()
{
var elements = new List<object> { I };
elements.AddRange(Accs_ta);
return elements.ToArray();
}

public BodyItem Pack(object[] sequences)
{
I = sequences[0] as Tensor;
Accs_ta = new [] { sequences[1] as TensorArray };
return new BodyItem(I, Accs_ta);
}

public BodyItem FromMergeVars(ITensorOrTensorArray[] merge_vars)
{
I = (Tensor)merge_vars[1];
Accs_ta = new [] {(TensorArray) merge_vars[2]};
return this;
}
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -154,7 +154,7 @@ namespace Tensorflow
[SuppressMessage("ReSharper", "ParameterHidesMember")]
public TensorShape with_rank_at_least(int rank)
{
if (rank != ndim)
if (ndim < rank)
throw new ValueError($"Shape {this} must have rank at least {rank}");
else
return this;


+ 3
- 3
test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs View File

@@ -18,7 +18,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var i = constant_op.constant(0, name: "i");
var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c"));
var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c"));
var r = control_flow_ops.while_loop(c, b, i);
//var r = control_flow_ops.while_loop(c, b, i);
}
private void _testWhileContextHelper(int maximum_iterations)
@@ -29,8 +29,8 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
var i = constant_op.constant(0, name: "i");
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c"));
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c"));
control_flow_ops.while_loop(
c, b, i , maximum_iterations: tf.constant(maximum_iterations));
//control_flow_ops.while_loop(
// c, b, i , maximum_iterations: tf.constant(maximum_iterations));
foreach (Operation op in sess.graph.get_operations())
{
var control_flow_context = op._get_control_flow_context();


Loading…
Cancel
Save