@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections.Generic; | |||
namespace Tensorflow | |||
@@ -61,5 +62,34 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public Tensor no_op(string name = null) | |||
=> gen_control_flow_ops.no_op(name: name); | |||
/// <summary> | |||
/// map on the list of tensors unpacked from `elems` on dimension 0. | |||
/// </summary> | |||
/// <param name="fn"></param> | |||
/// <param name="elems"></param> | |||
/// <param name="dtype"></param> | |||
/// <param name="parallel_iterations"></param> | |||
/// <param name="back_prop"></param> | |||
/// <param name="swap_memory"></param> | |||
/// <param name="infer_shape"></param> | |||
/// <param name="name"></param> | |||
/// <returns>A tensor or (possibly nested) sequence of tensors.</returns> | |||
public Tensor map_fn(Func<Tensor, Tensor> fn, | |||
Tensor elems, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
int parallel_iterations = -1, | |||
bool back_prop = true, | |||
bool swap_memory = false, | |||
bool infer_shape = true, | |||
string name = null) | |||
=> Operation.map_fn(fn, | |||
elems, | |||
dtype, | |||
parallel_iterations: parallel_iterations, | |||
back_prop: back_prop, | |||
swap_memory: swap_memory, | |||
infer_shape: infer_shape, | |||
name: name); | |||
} | |||
} |
@@ -145,7 +145,7 @@ namespace Tensorflow.Operations | |||
{ | |||
var ta = new TensorArray(dtype: dtype_, | |||
size: time_steps, | |||
element_shape: element_shape, | |||
element_shape: new[] { element_shape }, | |||
tensor_array_name: base_name + name); | |||
return ta; | |||
}; | |||
@@ -33,9 +33,13 @@ namespace Tensorflow.Operations | |||
{ | |||
_GraphTensorArray _implementation; | |||
public TensorArray(TF_DataType dtype, Tensor size = null, bool? clear_after_read = null, bool? dynamic_size = null, | |||
public TF_DataType dtype => _implementation._dtype; | |||
public Tensor handle => _implementation._handle; | |||
public Tensor flow => _implementation._flow; | |||
public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, | |||
string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
bool infer_shape = true, TensorShape element_shape = null, | |||
bool infer_shape = true, TensorShape[] element_shape = null, | |||
bool colocate_with_first_write_call = true, string name = null) | |||
{ | |||
_implementation = new _GraphTensorArray(dtype, | |||
@@ -50,5 +54,8 @@ namespace Tensorflow.Operations | |||
colocate_with_first_write_call: colocate_with_first_write_call, | |||
name: name); | |||
} | |||
public TensorArray unstack(Tensor value, string name = null) | |||
=> _implementation.unstack(value, name: name); | |||
} | |||
} |
@@ -16,6 +16,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
@@ -23,7 +24,7 @@ namespace Tensorflow.Operations | |||
{ | |||
internal class _GraphTensorArray | |||
{ | |||
TF_DataType _dtype; | |||
internal TF_DataType _dtype; | |||
/// <summary> | |||
/// Used to keep track of what tensors the TensorArray should be | |||
@@ -33,23 +34,27 @@ namespace Tensorflow.Operations | |||
bool _colocate_with_first_write_call; | |||
bool _infer_shape; | |||
bool _dynamic_size; | |||
List<TensorShape> _element_shape; | |||
object _colocate_with; | |||
List<Tensor> _colocate_with; | |||
internal Tensor _handle; | |||
internal Tensor _flow; | |||
public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | |||
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
bool infer_shape = true, TensorShape element_shape = null, | |||
bool infer_shape = true, TensorShape[] element_shape = null, | |||
bool colocate_with_first_write_call = true, string name = null) | |||
{ | |||
clear_after_read = clear_after_read ?? true; | |||
dynamic_size = dynamic_size ?? false; | |||
_dynamic_size = dynamic_size.Value; | |||
_dtype = dtype; | |||
_colocate_with_first_write_call = colocate_with_first_write_call; | |||
if (colocate_with_first_write_call) | |||
_colocate_with = new Tensor[0]; | |||
_colocate_with = new List<Tensor>(); | |||
// Record the current static shape for the array elements. The element | |||
// shape is defined either by `element_shape` or the shape of the tensor | |||
@@ -66,11 +71,12 @@ namespace Tensorflow.Operations | |||
_element_shape = new List<TensorShape> { }; | |||
} | |||
tf_with(ops.name_scope(name, "", new { handle, size, flow }), scope => | |||
tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope => | |||
{ | |||
if(handle != null) | |||
{ | |||
_handle = handle; | |||
_flow = flow; | |||
} | |||
else | |||
{ | |||
@@ -89,14 +95,65 @@ namespace Tensorflow.Operations | |||
if (colocate_with_first_write_call) | |||
{ | |||
ops.colocate_with(ignore_existing: true); | |||
create(); | |||
(_handle, _flow) = create(); | |||
} | |||
else | |||
{ | |||
(_handle, _flow) = create(); | |||
} | |||
} | |||
}); | |||
} | |||
public TensorArray unstack(Tensor value, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate | |||
{ | |||
var num_elements = array_ops.shape(value)[0]; | |||
return scatter(indices: math_ops.range(0, num_elements), value: value, name: name); | |||
}); | |||
} | |||
public TensorArray scatter(Tensor indices, Tensor value, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate | |||
{ | |||
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
if (_infer_shape) | |||
{ | |||
var shape = new TensorShape(value.TensorShape.dims.Skip(1).ToArray()); | |||
_merge_element_shape(shape); | |||
} | |||
_maybe_colocate_with(value); | |||
var flow_out = gen_data_flow_ops.tensor_array_scatter_v3( | |||
handle: _handle, | |||
indices: indices, | |||
value: value, | |||
flow_in: _flow, | |||
name: name); | |||
var ta = new TensorArray(_dtype, | |||
infer_shape:_infer_shape, | |||
element_shape: _element_shape.ToArray(), | |||
dynamic_size: _dynamic_size, | |||
handle: _handle, | |||
flow: flow_out, | |||
colocate_with_first_write_call: _colocate_with_first_write_call); | |||
return ta; | |||
}); | |||
} | |||
public void _merge_element_shape(TensorShape shape) | |||
{ | |||
_element_shape.Add(shape); | |||
} | |||
public void _maybe_colocate_with(Tensor value) | |||
{ | |||
_colocate_with.Add(value); | |||
} | |||
} | |||
} |
@@ -27,10 +27,13 @@ namespace Tensorflow | |||
return _op.output; | |||
} | |||
public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid, | |||
int[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, | |||
public static (Tensor, Tensor) tensor_array_v3<T>(T size, TF_DataType dtype = TF_DataType.DtInvalid, | |||
TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, | |||
bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null) | |||
{ | |||
if (tensor_array_name == null) | |||
tensor_array_name = string.Empty; | |||
var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new | |||
{ | |||
size, | |||
@@ -42,7 +45,21 @@ namespace Tensorflow | |||
tensor_array_name | |||
}); | |||
return (null, null); | |||
return (_op.outputs[0], _op.outputs[1]); | |||
} | |||
public static Tensor tensor_array_scatter_v3(Tensor handle, Tensor indices, Tensor value, | |||
Tensor flow_in, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("TensorArrayScatterV3", name, new | |||
{ | |||
handle, | |||
indices, | |||
value, | |||
flow_in | |||
}); | |||
return _op.output; | |||
} | |||
public static Tensor padding_fifo_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, | |||
@@ -1494,5 +1494,22 @@ namespace TensorFlowNET.UnitTest | |||
} | |||
#endregion | |||
} | |||
[TestMethod] | |||
public void map_fn() | |||
{ | |||
var a = tf.constant(new[] { 1, 2, 3, 4 }); | |||
var b = tf.constant(new[] { 17, 12, 11, 10 }); | |||
var ab = tf.stack(new[] { a, b }, 1); | |||
Func<Tensor, Tensor> map_operation = (value_ab) => | |||
{ | |||
var value_a = value_ab[0]; | |||
var value_b = value_ab[1]; | |||
return value_a + value_b; | |||
}; | |||
var map_result = tf.map_fn(map_operation, ab); | |||
} | |||
} | |||
} |