Browse Source

resove conflict.

tags/v0.12
Oceania2018 6 years ago
parent
commit
8a3a16b72f
6 changed files with 143 additions and 15 deletions
  1. +30
    -0
      src/TensorFlowNET.Core/APIs/tf.ops.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  3. +9
    -2
      src/TensorFlowNET.Core/Operations/TensorArray.cs
  4. +66
    -9
      src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
  5. +20
    -3
      src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
  6. +17
    -0
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 30
- 0
src/TensorFlowNET.Core/APIs/tf.ops.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;

namespace Tensorflow
@@ -61,5 +62,34 @@ namespace Tensorflow
/// <returns></returns>
public Tensor no_op(string name = null)
=> gen_control_flow_ops.no_op(name: name);

/// <summary>
/// map on the list of tensors unpacked from `elems` on dimension 0.
/// </summary>
/// <param name="fn"></param>
/// <param name="elems"></param>
/// <param name="dtype"></param>
/// <param name="parallel_iterations"></param>
/// <param name="back_prop"></param>
/// <param name="swap_memory"></param>
/// <param name="infer_shape"></param>
/// <param name="name"></param>
/// <returns>A tensor or (possibly nested) sequence of tensors.</returns>
public Tensor map_fn(Func<Tensor, Tensor> fn,
Tensor elems,
TF_DataType dtype = TF_DataType.DtInvalid,
int parallel_iterations = -1,
bool back_prop = true,
bool swap_memory = false,
bool infer_shape = true,
string name = null)
=> Operation.map_fn(fn,
elems,
dtype,
parallel_iterations: parallel_iterations,
back_prop: back_prop,
swap_memory: swap_memory,
infer_shape: infer_shape,
name: name);
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -145,7 +145,7 @@ namespace Tensorflow.Operations
{
var ta = new TensorArray(dtype: dtype_,
size: time_steps,
element_shape: element_shape,
element_shape: new[] { element_shape },
tensor_array_name: base_name + name);
return ta;
};


+ 9
- 2
src/TensorFlowNET.Core/Operations/TensorArray.cs View File

@@ -33,9 +33,13 @@ namespace Tensorflow.Operations
{
_GraphTensorArray _implementation;

public TensorArray(TF_DataType dtype, Tensor size = null, bool? clear_after_read = null, bool? dynamic_size = null,
public TF_DataType dtype => _implementation._dtype;
public Tensor handle => _implementation._handle;
public Tensor flow => _implementation._flow;

public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null,
string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, TensorShape element_shape = null,
bool infer_shape = true, TensorShape[] element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
_implementation = new _GraphTensorArray(dtype,
@@ -50,5 +54,8 @@ namespace Tensorflow.Operations
colocate_with_first_write_call: colocate_with_first_write_call,
name: name);
}

public TensorArray unstack(Tensor value, string name = null)
=> _implementation.unstack(value, name: name);
}
}

+ 66
- 9
src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs View File

@@ -16,6 +16,7 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.Binding;

@@ -23,7 +24,7 @@ namespace Tensorflow.Operations
{
internal class _GraphTensorArray
{
TF_DataType _dtype;
internal TF_DataType _dtype;

/// <summary>
/// Used to keep track of what tensors the TensorArray should be
@@ -33,23 +34,27 @@ namespace Tensorflow.Operations
bool _colocate_with_first_write_call;

bool _infer_shape;
bool _dynamic_size;
List<TensorShape> _element_shape;

object _colocate_with;
List<Tensor> _colocate_with;

internal Tensor _handle;
internal Tensor _flow;

public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, TensorShape element_shape = null,
bool infer_shape = true, TensorShape[] element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
clear_after_read = clear_after_read ?? true;
dynamic_size = dynamic_size ?? false;
_dynamic_size = dynamic_size.Value;
_dtype = dtype;

_colocate_with_first_write_call = colocate_with_first_write_call;
if (colocate_with_first_write_call)
_colocate_with = new Tensor[0];
_colocate_with = new List<Tensor>();

// Record the current static shape for the array elements. The element
// shape is defined either by `element_shape` or the shape of the tensor
@@ -66,11 +71,12 @@ namespace Tensorflow.Operations
_element_shape = new List<TensorShape> { };
}

tf_with(ops.name_scope(name, "", new { handle, size, flow }), scope =>
tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope =>
{
if(handle != null)
{

_handle = handle;
_flow = flow;
}
else
{
@@ -89,14 +95,65 @@ namespace Tensorflow.Operations
if (colocate_with_first_write_call)
{
ops.colocate_with(ignore_existing: true);
create();
(_handle, _flow) = create();
}
else
{
(_handle, _flow) = create();
}
}
});
}

public TensorArray unstack(Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate
{
var num_elements = array_ops.shape(value)[0];
return scatter(indices: math_ops.range(0, num_elements), value: value, name: name);
});
}

public TensorArray scatter(Tensor indices, Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
if (_infer_shape)
{
var shape = new TensorShape(value.TensorShape.dims.Skip(1).ToArray());
_merge_element_shape(shape);
}

_maybe_colocate_with(value);
var flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
handle: _handle,
indices: indices,
value: value,
flow_in: _flow,
name: name);

var ta = new TensorArray(_dtype,
infer_shape:_infer_shape,
element_shape: _element_shape.ToArray(),
dynamic_size: _dynamic_size,
handle: _handle,
flow: flow_out,
colocate_with_first_write_call: _colocate_with_first_write_call);

return ta;
});
}

public void _merge_element_shape(TensorShape shape)
{
_element_shape.Add(shape);
}

public void _maybe_colocate_with(Tensor value)
{
_colocate_with.Add(value);
}
}
}

+ 20
- 3
src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs View File

@@ -27,10 +27,13 @@ namespace Tensorflow
return _op.output;
}

public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid,
int[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true,
public static (Tensor, Tensor) tensor_array_v3<T>(T size, TF_DataType dtype = TF_DataType.DtInvalid,
TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true,
bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null)
{
if (tensor_array_name == null)
tensor_array_name = string.Empty;

var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new
{
size,
@@ -42,7 +45,21 @@ namespace Tensorflow
tensor_array_name
});

return (null, null);
return (_op.outputs[0], _op.outputs[1]);
}

public static Tensor tensor_array_scatter_v3(Tensor handle, Tensor indices, Tensor value,
Tensor flow_in, string name = null)
{
var _op = _op_def_lib._apply_op_helper("TensorArrayScatterV3", name, new
{
handle,
indices,
value,
flow_in
});

return _op.output;
}

public static Tensor padding_fifo_queue_v2(TF_DataType[] component_types, TensorShape[] shapes,


+ 17
- 0
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -1494,5 +1494,22 @@ namespace TensorFlowNET.UnitTest
}
#endregion
}

[TestMethod]
public void map_fn()
{
var a = tf.constant(new[] { 1, 2, 3, 4 });
var b = tf.constant(new[] { 17, 12, 11, 10 });
var ab = tf.stack(new[] { a, b }, 1);

Func<Tensor, Tensor> map_operation = (value_ab) =>
{
var value_a = value_ab[0];
var value_b = value_ab[1];
return value_a + value_b;
};

var map_result = tf.map_fn(map_operation, ab);
}
}
}

Loading…
Cancel
Save