Browse Source

refactor TensorArray #903

tags/v0.70.2-NET6
Oceania2018 3 years ago
parent
commit
e27145e427
9 changed files with 273 additions and 81 deletions
  1. +23
    -0
      src/TensorFlowNET.Core/APIs/tf.array.cs
  2. +2
    -3
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  3. +184
    -0
      src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
  4. +19
    -13
      src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
  5. +5
    -5
      src/TensorFlowNET.Core/Operations/functional_ops.cs
  6. +4
    -4
      src/TensorFlowNET.Core/Operations/map_fn.cs
  7. +9
    -24
      src/TensorFlowNET.Core/Operations/tensor_array_ops.cs
  8. +12
    -32
      src/TensorFlowNET.Core/Tensors/TensorArray.cs
  9. +15
    -0
      test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

+ 23
- 0
src/TensorFlowNET.Core/APIs/tf.array.cs View File

@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using static Tensorflow.Binding;
using Tensorflow.Operations;

namespace Tensorflow
{
@@ -309,5 +310,27 @@ namespace Tensorflow
/// <returns></returns>
public Tensor stop_gradient(Tensor x, string name = null)
=> gen_array_ops.stop_gradient(x, name: name);

public TensorArray TensorArray(TF_DataType dtype, int size = 0, bool dynamic_size = false,
bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true,
bool infer_shape = true)
=> tf.executing_eagerly() ?
new _EagerTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size,
clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape,
colocate_with_first_write_call: colocate_with_first_write_call) :
new _GraphTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size,
clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape,
colocate_with_first_write_call: colocate_with_first_write_call);

public TensorArray TensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false,
bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true,
bool infer_shape = true)
=> tf.executing_eagerly() ?
new _EagerTensorArray(dtype, size: size, dynamic_size: dynamic_size,
clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape,
colocate_with_first_write_call: colocate_with_first_write_call) :
new _GraphTensorArray(dtype, size: size, dynamic_size: dynamic_size,
clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape,
colocate_with_first_write_call: colocate_with_first_write_call);
}
}

+ 2
- 3
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -294,10 +294,9 @@ namespace Tensorflow.Operations

Func<string, Shape, TF_DataType, TensorArray> _create_ta = (name, element_shape, dtype_) =>
{
var ta = new TensorArray(dtype: dtype_,
var ta = tf.TensorArray(dtype: dtype_,
size: time_steps,
element_shape: element_shape,
tensor_array_name: base_name + name);
element_shape: element_shape);
return ta;
};



+ 184
- 0
src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs View File

@@ -0,0 +1,184 @@
/*****************************************************************************
Copyright 2022 Haiping Chen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
{
public class _EagerTensorArray : TensorArray
{
TF_DataType _dtype;
public override TF_DataType dtype => _dtype;

/// <summary>
/// Used to keep track of what tensors the TensorArray should be
/// colocated with. We choose to colocate the TensorArray with the
/// first tensor written to it.
/// </summary>
bool _colocate_with_first_write_call;
public override bool colocate_with_first_write_call => _colocate_with_first_write_call;

bool _infer_shape;
public override bool infer_shape => _infer_shape;
public bool _dynamic_size;
public Shape _element_shape;

public List<Tensor> _colocate_with;

Tensor _handle;
public override Tensor handle => _handle;
Tensor _flow;
public override Tensor flow => _flow;
bool _clear_after_read;
List<Tensor> _tensor_array;

public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false,
bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, Shape? element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
_flow = constant_op.constant(0);
_infer_shape = infer_shape;
_element_shape = element_shape ?? Shape.Null;
_colocate_with_first_write_call = colocate_with_first_write_call;
_dtype = dtype.as_base_dtype();
_dynamic_size = dynamic_size;
_clear_after_read = clear_after_read;
_tensor_array = new List<Tensor>();
}

public override TensorArray unstack(Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate
{
var num_elements = array_ops.shape(value)[0];
return scatter(indices: math_ops.range(0, num_elements), value: value, name: name);
});
}

public TensorArray scatter(Tensor indices, Tensor value, string name = null)
{
/*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
if (_infer_shape)
{
var shape = new Shape(value.shape.dims.Skip(1).ToArray());
_merge_element_shape(shape);
}

_maybe_colocate_with(value);
var flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
handle: _handle,
indices: indices,
value: value,
flow_in: _flow,
name: name);

var ta = new _EagerTensorArray(_dtype,
infer_shape: _infer_shape,
element_shape: _element_shape[0],
dynamic_size: _dynamic_size,
handle: _handle,
flow: flow_out,
colocate_with_first_write_call: _colocate_with_first_write_call);


return ta;
});*/
throw new NotImplementedException("");
}

public void _merge_element_shape(Shape shape)
{
_element_shape.concatenate(shape);
}

public void _maybe_colocate_with(Tensor value)
{
_colocate_with.Add(value);
}

public override Tensor read<T>(T index, string name = null)
{
int index_int = -1;
if (index is int int_index)
index_int = int_index;
else if (index is Tensor tensor_index)
index_int = tensor_index.numpy();
else
throw new ValueError("");

if (_clear_after_read)
{
_tensor_array[index_int] = null;
}

return _tensor_array[index_int];
}

public override TensorArray write(Tensor index, Tensor value, string name = null)
{
if (_infer_shape)
_element_shape = _element_shape.merge_with(value.shape);
_tensor_array.add(value);
return this;
}

public override TensorArray write<T>(int index, T value, string name = null)
{
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
var index_tensor = ops.convert_to_tensor(index, name: "index");
return write(index_tensor, value_tensor, name: name);
}

private Tensor size(string name = null)
{
return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name);
}

public override Tensor stack(string name = null)
{
ops.colocate_with(_handle);
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
{
return gather(math_ops.range(0, size()), name: name);
});
}

public override Tensor gather(Tensor indices, string name = null)
{
var element_shape = Shape.Null;

var value = gen_data_flow_ops.tensor_array_gather_v3(
handle: _handle,
indices: indices,
flow_in: _flow,
dtype: _dtype,
name: name,
element_shape: element_shape);

//if (element_shape != null)
//value.set_shape(-1, element_shape.dims);

return value;
}
}
}

+ 19
- 13
src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs View File

@@ -21,7 +21,7 @@ using static Tensorflow.Binding;

namespace Tensorflow.Operations
{
public class _GraphTensorArray
public class _GraphTensorArray : TensorArray
{
internal TF_DataType _dtype;
public TF_DataType dtype => _dtype;
@@ -47,7 +47,7 @@ namespace Tensorflow.Operations

public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, Shape element_shape = null,
bool infer_shape = true, Shape? element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
clear_after_read = clear_after_read ?? true;
@@ -108,7 +108,7 @@ namespace Tensorflow.Operations
});
}

public TensorArray unstack(Tensor value, string name = null)
public override TensorArray unstack(Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate
{
@@ -119,7 +119,7 @@ namespace Tensorflow.Operations

public TensorArray scatter(Tensor indices, Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate
/*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
if (_infer_shape)
@@ -136,7 +136,7 @@ namespace Tensorflow.Operations
flow_in: _flow,
name: name);

var ta = new TensorArray(_dtype,
var ta = new _GraphTensorArray(_dtype,
infer_shape: _infer_shape,
element_shape: _element_shape[0],
dynamic_size: _dynamic_size,
@@ -144,9 +144,9 @@ namespace Tensorflow.Operations
flow: flow_out,
colocate_with_first_write_call: _colocate_with_first_write_call);


return ta;
});
});*/
throw new NotImplementedException("");
}

public void _merge_element_shape(Shape shape)
@@ -159,11 +159,11 @@ namespace Tensorflow.Operations
_colocate_with.Add(value);
}

public Tensor read(Tensor index, string name = null)
public override Tensor read<T>(T index, string name = null)
{
var value = gen_data_flow_ops.tensor_array_read_v3(
handle: _handle,
index: index,
index: constant_op.constant(index),
flow_in: _flow,
dtype: _dtype,
name: name);
@@ -174,11 +174,10 @@ namespace Tensorflow.Operations
return value;
}

public TensorArray write(Tensor index, Tensor value, string name = null)
public override TensorArray write(Tensor index, Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
_maybe_colocate_with(value);
var flow_out = gen_data_flow_ops.tensor_array_write_v3(
handle: _handle,
@@ -191,12 +190,19 @@ namespace Tensorflow.Operations
});
}

public override TensorArray write<T>(int index, T value, string name = null)
{
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
var index_tensor = ops.convert_to_tensor(index, name: "index");
return write(index_tensor, value_tensor);
}

private Tensor size(string name = null)
{
return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name);
}

public Tensor stack(string name = null)
public override Tensor stack(string name = null)
{
ops.colocate_with(_handle);
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
@@ -205,7 +211,7 @@ namespace Tensorflow.Operations
});
}

public Tensor gather(Tensor indices, string name = null)
public override Tensor gather(Tensor indices, string name = null)
{
var element_shape = Shape.Null;



+ 5
- 5
src/TensorFlowNET.Core/Operations/functional_ops.cs View File

@@ -87,9 +87,9 @@ namespace Tensorflow
// n = array_ops.shape(elems_flat[0])[0];
//}

var elems_ta = elems_flat.Select(elem => new TensorArray(
var elems_ta = elems_flat.Select(elem => tf.TensorArray(
elem.dtype,
size: tf.constant(n),
size: n,
dynamic_size: false,
element_shape: elem.shape.dims.Skip(1).ToArray(),
infer_shape: true)).ToList();
@@ -113,9 +113,9 @@ namespace Tensorflow
i = 0;
}

var accs_ta = a_flat.Select(init => new TensorArray(
var accs_ta = a_flat.Select(init => tf.TensorArray(
dtype: init.dtype,
size: tf.constant(n),
size: n,
element_shape: infer_shape ? init.shape : null,
dynamic_size: false,
infer_shape: infer_shape)).ToArray();
@@ -124,7 +124,7 @@ namespace Tensorflow
{
for (int index = 0; index < accs_ta.Length; index++)
{
accs_ta[index].write(tf.constant(reverse ? n - 1 : 0), a_flat[index]);
accs_ta[index].write(reverse ? n - 1 : 0, a_flat[index]);
}
}



+ 4
- 4
src/TensorFlowNET.Core/Operations/map_fn.cs View File

@@ -78,8 +78,8 @@ namespace Tensorflow
var n = static_shape[0];

// TensorArrays are always flat
var elems_ta = elems_flat.Select(elem => new TensorArray(dtype: elem.dtype,
size: ops.convert_to_tensor(n),
var elems_ta = elems_flat.Select(elem => tf.TensorArray(dtype: elem.dtype,
size: Convert.ToInt32(n),
dynamic_size: false,
infer_shape: true)).ToArray();

@@ -92,8 +92,8 @@ namespace Tensorflow

var i = constant_op.constant(0);

var accs_ta = dtype_flat.Select(dt => new TensorArray(dtype: dt,
size: ops.convert_to_tensor(n),
var accs_ta = dtype_flat.Select(dt => tf.TensorArray(dtype: dt,
size: Convert.ToInt32(n),
dynamic_size: false,
infer_shape: infer_shape)).ToArray();



+ 9
- 24
src/TensorFlowNET.Core/Operations/tensor_array_ops.cs View File

@@ -1,4 +1,5 @@
using Tensorflow.Operations;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -12,37 +13,21 @@ namespace Tensorflow
/// <returns></returns>
public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow)
{
var impl = old_ta._implementation;
var new_ta = tf.TensorArray(
dtype: old_ta.dtype,
infer_shape: old_ta.infer_shape,
colocate_with_first_write_call: old_ta.colocate_with_first_write_call);

var new_ta = new TensorArray(
dtype: impl.dtype,
handle: impl.handle,
flow: flow,
infer_shape: impl.infer_shape,
colocate_with_first_write_call: impl.colocate_with_first_write_call);

var new_impl = new_ta._implementation;
new_impl._dynamic_size = impl._dynamic_size;
new_impl._colocate_with = impl._colocate_with;
new_impl._element_shape = impl._element_shape;
return new_ta;
}

public static TensorArray build_ta_with_new_flow(_GraphTensorArray old_ta, Tensor flow)
{
var impl = old_ta;

var new_ta = new TensorArray(
dtype: impl.dtype,
handle: impl.handle,
flow: flow,
infer_shape: impl.infer_shape,
colocate_with_first_write_call: impl.colocate_with_first_write_call);
var new_ta = tf.TensorArray(
dtype: old_ta.dtype,
infer_shape: old_ta.infer_shape,
colocate_with_first_write_call: old_ta.colocate_with_first_write_call);

var new_impl = new_ta._implementation;
new_impl._dynamic_size = impl._dynamic_size;
new_impl._colocate_with = impl._colocate_with;
new_impl._element_shape = impl._element_shape;
return new_ta;
}
}


+ 12
- 32
src/TensorFlowNET.Core/Tensors/TensorArray.cs View File

@@ -27,42 +27,22 @@ namespace Tensorflow
/// `while_loop` and `map_fn`. It supports gradient back-propagation via special
/// "flow" control flow dependencies.
/// </summary>
public class TensorArray : ITensorOrTensorArray
public abstract class TensorArray : ITensorOrTensorArray
{
internal _GraphTensorArray _implementation;
public virtual TF_DataType dtype { get; }
public virtual Tensor handle { get; }
public virtual Tensor flow { get; }
public virtual bool infer_shape { get; }
public virtual bool colocate_with_first_write_call { get; }

public TF_DataType dtype => _implementation._dtype;
public Tensor handle => _implementation._handle;
public Tensor flow => _implementation._flow;
public abstract TensorArray unstack(Tensor value, string name = null);

public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null,
string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, Shape element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
_implementation = new _GraphTensorArray(dtype,
size: size,
dynamic_size: dynamic_size,
clear_after_read: clear_after_read,
tensor_array_name: tensor_array_name,
handle: handle,
flow: flow,
infer_shape: infer_shape,
element_shape: element_shape,
colocate_with_first_write_call: colocate_with_first_write_call,
name: name);
}
public abstract Tensor read<T>(T index, string name = null);

public TensorArray unstack(Tensor value, string name = null)
=> _implementation.unstack(value, name: name);
public abstract TensorArray write<T>(int index, T value, string name = null);
public abstract TensorArray write(Tensor index, Tensor value, string name = null);

public Tensor read(Tensor index, string name = null)
=> _implementation.read(index, name: name);

public TensorArray write(Tensor index, Tensor value, string name = null)
=> _implementation.write(index, value, name: name);

public Tensor stack(string name = null)
=> _implementation.stack(name: name);
public abstract Tensor stack(string name = null);
public abstract Tensor gather(Tensor indices, string name = null);
}
}

+ 15
- 0
test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs View File

@@ -77,5 +77,20 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
var r3 = tf.gather(p2, i2, axis: 1);
Assert.AreEqual(new Shape(4,1,2), r3.shape);
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/TensorArray
/// </summary>
[TestMethod]
public void TensorArray()
{
var ta = tf.TensorArray(tf.float32, size: 0, dynamic_size: true, clear_after_read: false);
ta.write(0, 10);
ta.write(1, 20);
ta.write(2, 30);
Assert.AreEqual(ta.read(0).numpy(), 10f);
Assert.AreEqual(ta.read(1).numpy(), 20f);
Assert.AreEqual(ta.read(2).numpy(), 30f);
}
}
}

Loading…
Cancel
Save