@@ -19,6 +19,7 @@ using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Operations; | |||
namespace Tensorflow | |||
{ | |||
@@ -309,5 +310,27 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public Tensor stop_gradient(Tensor x, string name = null) | |||
=> gen_array_ops.stop_gradient(x, name: name); | |||
public TensorArray TensorArray(TF_DataType dtype, int size = 0, bool dynamic_size = false, | |||
bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true, | |||
bool infer_shape = true) | |||
=> tf.executing_eagerly() ? | |||
new _EagerTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size, | |||
clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, | |||
colocate_with_first_write_call: colocate_with_first_write_call) : | |||
new _GraphTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size, | |||
clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, | |||
colocate_with_first_write_call: colocate_with_first_write_call); | |||
public TensorArray TensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false, | |||
bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true, | |||
bool infer_shape = true) | |||
=> tf.executing_eagerly() ? | |||
new _EagerTensorArray(dtype, size: size, dynamic_size: dynamic_size, | |||
clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, | |||
colocate_with_first_write_call: colocate_with_first_write_call) : | |||
new _GraphTensorArray(dtype, size: size, dynamic_size: dynamic_size, | |||
clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, | |||
colocate_with_first_write_call: colocate_with_first_write_call); | |||
} | |||
} |
@@ -294,10 +294,9 @@ namespace Tensorflow.Operations | |||
Func<string, Shape, TF_DataType, TensorArray> _create_ta = (name, element_shape, dtype_) => | |||
{ | |||
var ta = new TensorArray(dtype: dtype_, | |||
var ta = tf.TensorArray(dtype: dtype_, | |||
size: time_steps, | |||
element_shape: element_shape, | |||
tensor_array_name: base_name + name); | |||
element_shape: element_shape); | |||
return ta; | |||
}; | |||
@@ -0,0 +1,184 @@ | |||
/***************************************************************************** | |||
Copyright 2022 Haiping Chen. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Framework; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Operations | |||
{ | |||
public class _EagerTensorArray : TensorArray | |||
{ | |||
TF_DataType _dtype; | |||
public override TF_DataType dtype => _dtype; | |||
/// <summary> | |||
/// Used to keep track of what tensors the TensorArray should be | |||
/// colocated with. We choose to colocate the TensorArray with the | |||
/// first tensor written to it. | |||
/// </summary> | |||
bool _colocate_with_first_write_call; | |||
public override bool colocate_with_first_write_call => _colocate_with_first_write_call; | |||
bool _infer_shape; | |||
public override bool infer_shape => _infer_shape; | |||
public bool _dynamic_size; | |||
public Shape _element_shape; | |||
public List<Tensor> _colocate_with; | |||
Tensor _handle; | |||
public override Tensor handle => _handle; | |||
Tensor _flow; | |||
public override Tensor flow => _flow; | |||
bool _clear_after_read; | |||
List<Tensor> _tensor_array; | |||
public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false, | |||
bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
bool infer_shape = true, Shape? element_shape = null, | |||
bool colocate_with_first_write_call = true, string name = null) | |||
{ | |||
_flow = constant_op.constant(0); | |||
_infer_shape = infer_shape; | |||
_element_shape = element_shape ?? Shape.Null; | |||
_colocate_with_first_write_call = colocate_with_first_write_call; | |||
_dtype = dtype.as_base_dtype(); | |||
_dynamic_size = dynamic_size; | |||
_clear_after_read = clear_after_read; | |||
_tensor_array = new List<Tensor>(); | |||
} | |||
public override TensorArray unstack(Tensor value, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate | |||
{ | |||
var num_elements = array_ops.shape(value)[0]; | |||
return scatter(indices: math_ops.range(0, num_elements), value: value, name: name); | |||
}); | |||
} | |||
public TensorArray scatter(Tensor indices, Tensor value, string name = null) | |||
{ | |||
/*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate | |||
{ | |||
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
if (_infer_shape) | |||
{ | |||
var shape = new Shape(value.shape.dims.Skip(1).ToArray()); | |||
_merge_element_shape(shape); | |||
} | |||
_maybe_colocate_with(value); | |||
var flow_out = gen_data_flow_ops.tensor_array_scatter_v3( | |||
handle: _handle, | |||
indices: indices, | |||
value: value, | |||
flow_in: _flow, | |||
name: name); | |||
var ta = new _EagerTensorArray(_dtype, | |||
infer_shape: _infer_shape, | |||
element_shape: _element_shape[0], | |||
dynamic_size: _dynamic_size, | |||
handle: _handle, | |||
flow: flow_out, | |||
colocate_with_first_write_call: _colocate_with_first_write_call); | |||
return ta; | |||
});*/ | |||
throw new NotImplementedException(""); | |||
} | |||
public void _merge_element_shape(Shape shape) | |||
{ | |||
_element_shape.concatenate(shape); | |||
} | |||
public void _maybe_colocate_with(Tensor value) | |||
{ | |||
_colocate_with.Add(value); | |||
} | |||
public override Tensor read<T>(T index, string name = null) | |||
{ | |||
int index_int = -1; | |||
if (index is int int_index) | |||
index_int = int_index; | |||
else if (index is Tensor tensor_index) | |||
index_int = tensor_index.numpy(); | |||
else | |||
throw new ValueError(""); | |||
if (_clear_after_read) | |||
{ | |||
_tensor_array[index_int] = null; | |||
} | |||
return _tensor_array[index_int]; | |||
} | |||
public override TensorArray write(Tensor index, Tensor value, string name = null) | |||
{ | |||
if (_infer_shape) | |||
_element_shape = _element_shape.merge_with(value.shape); | |||
_tensor_array.add(value); | |||
return this; | |||
} | |||
public override TensorArray write<T>(int index, T value, string name = null) | |||
{ | |||
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
var index_tensor = ops.convert_to_tensor(index, name: "index"); | |||
return write(index_tensor, value_tensor, name: name); | |||
} | |||
private Tensor size(string name = null) | |||
{ | |||
return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name); | |||
} | |||
public override Tensor stack(string name = null) | |||
{ | |||
ops.colocate_with(_handle); | |||
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate | |||
{ | |||
return gather(math_ops.range(0, size()), name: name); | |||
}); | |||
} | |||
public override Tensor gather(Tensor indices, string name = null) | |||
{ | |||
var element_shape = Shape.Null; | |||
var value = gen_data_flow_ops.tensor_array_gather_v3( | |||
handle: _handle, | |||
indices: indices, | |||
flow_in: _flow, | |||
dtype: _dtype, | |||
name: name, | |||
element_shape: element_shape); | |||
//if (element_shape != null) | |||
//value.set_shape(-1, element_shape.dims); | |||
return value; | |||
} | |||
} | |||
} |
@@ -21,7 +21,7 @@ using static Tensorflow.Binding; | |||
namespace Tensorflow.Operations | |||
{ | |||
public class _GraphTensorArray | |||
public class _GraphTensorArray : TensorArray | |||
{ | |||
internal TF_DataType _dtype; | |||
public TF_DataType dtype => _dtype; | |||
@@ -47,7 +47,7 @@ namespace Tensorflow.Operations | |||
public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | |||
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
bool infer_shape = true, Shape element_shape = null, | |||
bool infer_shape = true, Shape? element_shape = null, | |||
bool colocate_with_first_write_call = true, string name = null) | |||
{ | |||
clear_after_read = clear_after_read ?? true; | |||
@@ -108,7 +108,7 @@ namespace Tensorflow.Operations | |||
}); | |||
} | |||
public TensorArray unstack(Tensor value, string name = null) | |||
public override TensorArray unstack(Tensor value, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate | |||
{ | |||
@@ -119,7 +119,7 @@ namespace Tensorflow.Operations | |||
public TensorArray scatter(Tensor indices, Tensor value, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate | |||
/*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate | |||
{ | |||
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
if (_infer_shape) | |||
@@ -136,7 +136,7 @@ namespace Tensorflow.Operations | |||
flow_in: _flow, | |||
name: name); | |||
var ta = new TensorArray(_dtype, | |||
var ta = new _GraphTensorArray(_dtype, | |||
infer_shape: _infer_shape, | |||
element_shape: _element_shape[0], | |||
dynamic_size: _dynamic_size, | |||
@@ -144,9 +144,9 @@ namespace Tensorflow.Operations | |||
flow: flow_out, | |||
colocate_with_first_write_call: _colocate_with_first_write_call); | |||
return ta; | |||
}); | |||
});*/ | |||
throw new NotImplementedException(""); | |||
} | |||
public void _merge_element_shape(Shape shape) | |||
@@ -159,11 +159,11 @@ namespace Tensorflow.Operations | |||
_colocate_with.Add(value); | |||
} | |||
public Tensor read(Tensor index, string name = null) | |||
public override Tensor read<T>(T index, string name = null) | |||
{ | |||
var value = gen_data_flow_ops.tensor_array_read_v3( | |||
handle: _handle, | |||
index: index, | |||
index: constant_op.constant(index), | |||
flow_in: _flow, | |||
dtype: _dtype, | |||
name: name); | |||
@@ -174,11 +174,10 @@ namespace Tensorflow.Operations | |||
return value; | |||
} | |||
public TensorArray write(Tensor index, Tensor value, string name = null) | |||
public override TensorArray write(Tensor index, Tensor value, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate | |||
{ | |||
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
_maybe_colocate_with(value); | |||
var flow_out = gen_data_flow_ops.tensor_array_write_v3( | |||
handle: _handle, | |||
@@ -191,12 +190,19 @@ namespace Tensorflow.Operations | |||
}); | |||
} | |||
public override TensorArray write<T>(int index, T value, string name = null) | |||
{ | |||
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
var index_tensor = ops.convert_to_tensor(index, name: "index"); | |||
return write(index_tensor, value_tensor); | |||
} | |||
private Tensor size(string name = null) | |||
{ | |||
return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name); | |||
} | |||
public Tensor stack(string name = null) | |||
public override Tensor stack(string name = null) | |||
{ | |||
ops.colocate_with(_handle); | |||
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate | |||
@@ -205,7 +211,7 @@ namespace Tensorflow.Operations | |||
}); | |||
} | |||
public Tensor gather(Tensor indices, string name = null) | |||
public override Tensor gather(Tensor indices, string name = null) | |||
{ | |||
var element_shape = Shape.Null; | |||
@@ -87,9 +87,9 @@ namespace Tensorflow | |||
// n = array_ops.shape(elems_flat[0])[0]; | |||
//} | |||
var elems_ta = elems_flat.Select(elem => new TensorArray( | |||
var elems_ta = elems_flat.Select(elem => tf.TensorArray( | |||
elem.dtype, | |||
size: tf.constant(n), | |||
size: n, | |||
dynamic_size: false, | |||
element_shape: elem.shape.dims.Skip(1).ToArray(), | |||
infer_shape: true)).ToList(); | |||
@@ -113,9 +113,9 @@ namespace Tensorflow | |||
i = 0; | |||
} | |||
var accs_ta = a_flat.Select(init => new TensorArray( | |||
var accs_ta = a_flat.Select(init => tf.TensorArray( | |||
dtype: init.dtype, | |||
size: tf.constant(n), | |||
size: n, | |||
element_shape: infer_shape ? init.shape : null, | |||
dynamic_size: false, | |||
infer_shape: infer_shape)).ToArray(); | |||
@@ -124,7 +124,7 @@ namespace Tensorflow | |||
{ | |||
for (int index = 0; index < accs_ta.Length; index++) | |||
{ | |||
accs_ta[index].write(tf.constant(reverse ? n - 1 : 0), a_flat[index]); | |||
accs_ta[index].write(reverse ? n - 1 : 0, a_flat[index]); | |||
} | |||
} | |||
@@ -78,8 +78,8 @@ namespace Tensorflow | |||
var n = static_shape[0]; | |||
// TensorArrays are always flat | |||
var elems_ta = elems_flat.Select(elem => new TensorArray(dtype: elem.dtype, | |||
size: ops.convert_to_tensor(n), | |||
var elems_ta = elems_flat.Select(elem => tf.TensorArray(dtype: elem.dtype, | |||
size: Convert.ToInt32(n), | |||
dynamic_size: false, | |||
infer_shape: true)).ToArray(); | |||
@@ -92,8 +92,8 @@ namespace Tensorflow | |||
var i = constant_op.constant(0); | |||
var accs_ta = dtype_flat.Select(dt => new TensorArray(dtype: dt, | |||
size: ops.convert_to_tensor(n), | |||
var accs_ta = dtype_flat.Select(dt => tf.TensorArray(dtype: dt, | |||
size: Convert.ToInt32(n), | |||
dynamic_size: false, | |||
infer_shape: infer_shape)).ToArray(); | |||
@@ -1,4 +1,5 @@ | |||
using Tensorflow.Operations; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
@@ -12,37 +13,21 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow) | |||
{ | |||
var impl = old_ta._implementation; | |||
var new_ta = tf.TensorArray( | |||
dtype: old_ta.dtype, | |||
infer_shape: old_ta.infer_shape, | |||
colocate_with_first_write_call: old_ta.colocate_with_first_write_call); | |||
var new_ta = new TensorArray( | |||
dtype: impl.dtype, | |||
handle: impl.handle, | |||
flow: flow, | |||
infer_shape: impl.infer_shape, | |||
colocate_with_first_write_call: impl.colocate_with_first_write_call); | |||
var new_impl = new_ta._implementation; | |||
new_impl._dynamic_size = impl._dynamic_size; | |||
new_impl._colocate_with = impl._colocate_with; | |||
new_impl._element_shape = impl._element_shape; | |||
return new_ta; | |||
} | |||
public static TensorArray build_ta_with_new_flow(_GraphTensorArray old_ta, Tensor flow) | |||
{ | |||
var impl = old_ta; | |||
var new_ta = new TensorArray( | |||
dtype: impl.dtype, | |||
handle: impl.handle, | |||
flow: flow, | |||
infer_shape: impl.infer_shape, | |||
colocate_with_first_write_call: impl.colocate_with_first_write_call); | |||
var new_ta = tf.TensorArray( | |||
dtype: old_ta.dtype, | |||
infer_shape: old_ta.infer_shape, | |||
colocate_with_first_write_call: old_ta.colocate_with_first_write_call); | |||
var new_impl = new_ta._implementation; | |||
new_impl._dynamic_size = impl._dynamic_size; | |||
new_impl._colocate_with = impl._colocate_with; | |||
new_impl._element_shape = impl._element_shape; | |||
return new_ta; | |||
} | |||
} | |||
@@ -27,42 +27,22 @@ namespace Tensorflow | |||
/// `while_loop` and `map_fn`. It supports gradient back-propagation via special | |||
/// "flow" control flow dependencies. | |||
/// </summary> | |||
public class TensorArray : ITensorOrTensorArray | |||
public abstract class TensorArray : ITensorOrTensorArray | |||
{ | |||
internal _GraphTensorArray _implementation; | |||
public virtual TF_DataType dtype { get; } | |||
public virtual Tensor handle { get; } | |||
public virtual Tensor flow { get; } | |||
public virtual bool infer_shape { get; } | |||
public virtual bool colocate_with_first_write_call { get; } | |||
public TF_DataType dtype => _implementation._dtype; | |||
public Tensor handle => _implementation._handle; | |||
public Tensor flow => _implementation._flow; | |||
public abstract TensorArray unstack(Tensor value, string name = null); | |||
public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, | |||
string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
bool infer_shape = true, Shape element_shape = null, | |||
bool colocate_with_first_write_call = true, string name = null) | |||
{ | |||
_implementation = new _GraphTensorArray(dtype, | |||
size: size, | |||
dynamic_size: dynamic_size, | |||
clear_after_read: clear_after_read, | |||
tensor_array_name: tensor_array_name, | |||
handle: handle, | |||
flow: flow, | |||
infer_shape: infer_shape, | |||
element_shape: element_shape, | |||
colocate_with_first_write_call: colocate_with_first_write_call, | |||
name: name); | |||
} | |||
public abstract Tensor read<T>(T index, string name = null); | |||
public TensorArray unstack(Tensor value, string name = null) | |||
=> _implementation.unstack(value, name: name); | |||
public abstract TensorArray write<T>(int index, T value, string name = null); | |||
public abstract TensorArray write(Tensor index, Tensor value, string name = null); | |||
public Tensor read(Tensor index, string name = null) | |||
=> _implementation.read(index, name: name); | |||
public TensorArray write(Tensor index, Tensor value, string name = null) | |||
=> _implementation.write(index, value, name: name); | |||
public Tensor stack(string name = null) | |||
=> _implementation.stack(name: name); | |||
public abstract Tensor stack(string name = null); | |||
public abstract Tensor gather(Tensor indices, string name = null); | |||
} | |||
} |
@@ -77,5 +77,20 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
var r3 = tf.gather(p2, i2, axis: 1); | |||
Assert.AreEqual(new Shape(4,1,2), r3.shape); | |||
} | |||
/// <summary> | |||
/// https://www.tensorflow.org/api_docs/python/tf/TensorArray | |||
/// </summary> | |||
[TestMethod] | |||
public void TensorArray() | |||
{ | |||
var ta = tf.TensorArray(tf.float32, size: 0, dynamic_size: true, clear_after_read: false); | |||
ta.write(0, 10); | |||
ta.write(1, 20); | |||
ta.write(2, 30); | |||
Assert.AreEqual(ta.read(0).numpy(), 10f); | |||
Assert.AreEqual(ta.read(1).numpy(), 20f); | |||
Assert.AreEqual(ta.read(2).numpy(), 30f); | |||
} | |||
} | |||
} |