Browse Source

built BasicLstmCell

tags/v0.13
Oceania2018 5 years ago
parent
commit
a2e7c9fef7
13 changed files with 244 additions and 74 deletions
  1. +4
    -2
      src/TensorFlowNET.Core/Layers/Layer.cs
  2. +121
    -6
      src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
  4. +8
    -14
      src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs
  5. +18
    -14
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  6. +63
    -9
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  7. +15
    -0
      src/TensorFlowNET.Core/Operations/array_ops.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  9. +7
    -5
      src/TensorFlowNET.Core/TensorFlow.Binding.csproj
  10. +0
    -14
      src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs
  11. +0
    -8
      src/TensorFlowNET.Core/Util/nest.py.cs
  12. +5
    -0
      src/TensorFlowNET.Core/Variables/VariableScope.cs
  13. +1
    -0
      src/TensorFlowNET.Core/Variables/_ReuseMode.cs

+ 4
- 2
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -65,7 +65,9 @@ namespace Tensorflow.Layers
variable_scope scope_context_manager = null; variable_scope scope_context_manager = null;
if (built) if (built)
{ {

scope_context_manager = tf.variable_scope(_scope,
reuse: true,
auxiliary_name_scope: false);
} }
else else
{ {
@@ -181,7 +183,7 @@ namespace Tensorflow.Layers
return _current_scope.original_name_scope; return _current_scope.original_name_scope;
} }


private void _set_scope(VariableScope scope = null)
protected void _set_scope(VariableScope scope = null)
{ {
if (_scope == null) if (_scope == null)
{ {


+ 121
- 6
src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs View File

@@ -14,12 +14,17 @@ namespace Tensorflow
/// Basic LSTM recurrent network cell. /// Basic LSTM recurrent network cell.
/// The implementation is based on: http://arxiv.org/abs/1409.2329. /// The implementation is based on: http://arxiv.org/abs/1409.2329.
/// </summary> /// </summary>
public class BasicLSTMCell : LayerRnnCell
public class BasicLstmCell : LayerRnnCell
{ {
int _num_units; int _num_units;
float _forget_bias; float _forget_bias;
bool _state_is_tuple; bool _state_is_tuple;
IActivation _activation; IActivation _activation;
LSTMStateTuple _state;
VariableV1 _kernel;
VariableV1 _bias;
string _WEIGHTS_VARIABLE_NAME = "kernel";
string _BIAS_VARIABLE_NAME = "bias";


/// <summary> /// <summary>
/// Initialize the basic LSTM cell. /// Initialize the basic LSTM cell.
@@ -31,7 +36,7 @@ namespace Tensorflow
/// <param name="reuse"></param> /// <param name="reuse"></param>
/// <param name="name"></param> /// <param name="name"></param>
/// <param name="dtype"></param> /// <param name="dtype"></param>
public BasicLSTMCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true,
public BasicLstmCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true,
IActivation activation = null, bool? reuse = null, string name = null, IActivation activation = null, bool? reuse = null, string name = null,
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, name: name, dtype: dtype) TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, name: name, dtype: dtype)
{ {
@@ -44,13 +49,123 @@ namespace Tensorflow
_activation = tf.nn.tanh(); _activation = tf.nn.tanh();
} }


public LSTMStateTuple state_size
protected override void build(TensorShape input_shape)
{
var input_depth = input_shape.dims.Last();
var h_depth = _num_units;
_kernel = add_weight(_WEIGHTS_VARIABLE_NAME,
shape: new[] { input_depth + h_depth, 4 * _num_units });
_bias = add_weight(_BIAS_VARIABLE_NAME,
shape: new[] { 4 * _num_units },
initializer: tf.zeros_initializer);
built = true;
}

public Tensor[] __call__(Tensor inputs, LSTMStateTuple state)
{
_state = state;
return base.__call__(inputs);
}

/// <summary>
/// Long short-term memory cell (LSTM).
/// </summary>
/// <param name="inputs"></param>
/// <param name="training"></param>
/// <param name="state"></param>
/// <returns></returns>
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{
var one = constant_op.constant(1, dtype: dtypes.int32);
// Parameters of gates are concatenated into one multiply for efficiency.
Tensor c = null;
Tensor h = null;
if(_state_is_tuple)
(c, h) = ((Tensor)_state.c, (Tensor)_state.h);
else
{
// array_ops.split(value: state, num_or_size_splits: 2, axis: one);
throw new NotImplementedException("BasicLstmCell call");
}
var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel as RefVariable);
gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable);

// i = input_gate, j = new_input, f = forget_gate, o = output_gate
var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one);
var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]);

var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype);
// Note that using `add` and `multiply` instead of `+` and `*` gives a
// performance improvement. So using those at the cost of readability.
var new_c = gen_math_ops.add(
math_ops.multiply(c, math_ops.sigmoid(gen_math_ops.add(f, forget_bias_tensor))),
math_ops.multiply(math_ops.sigmoid(i), _activation.Activate(j)));

var new_h = math_ops.multiply(_activation.Activate(new_c), math_ops.sigmoid(o));


if (_state_is_tuple)
return new[] { new_c, new_h };
else
return new[] { array_ops.concat(new[] { new_c, new_h }, 1) };
}

public override object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
if (inputs != null)
throw new NotImplementedException("get_initial_state input is not null");

return zero_state(batch_size, dtype);
}

/// <summary>
/// Return zero-filled state tensor(s).
/// </summary>
/// <param name="batch_size"></param>
/// <param name="dtype"></param>
/// <returns></returns>
private LSTMStateTuple zero_state(Tensor batch_size, TF_DataType dtype)
{
LSTMStateTuple output = null;
tf_with(ops.name_scope($"{GetType().Name}ZeroState", values: new { batch_size }), delegate
{
output = _zero_state_tensors(state_size, batch_size, dtype);
});

return output;
}

private LSTMStateTuple _zero_state_tensors(object state_size, Tensor batch_size, TF_DataType dtype)
{
if (state_size is LSTMStateTuple state_size_tuple)
{
var outputs = state_size_tuple.Flatten()
.Select(x => (int)x)
.Select(s =>
{
var c = rnn_cell_impl._concat(batch_size, s);
var size = array_ops.zeros(c, dtype: dtype);

var c_static = rnn_cell_impl._concat(batch_size, s, @static: true);
size.set_shape(c_static);

return size;
}).ToArray();

return new LSTMStateTuple(outputs[0], outputs[1]);
}

throw new NotImplementedException("_zero_state_tensors");
}

public override object state_size
{ {
get get
{ {
return _state_is_tuple ?
new LSTMStateTuple(_num_units, _num_units) :
(LSTMStateTuple)(2 * _num_units);
if (_state_is_tuple)
return new LSTMStateTuple(_num_units, _num_units);
else
return 2 * _num_units;
} }
} }
} }


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs View File

@@ -26,7 +26,7 @@ namespace Tensorflow
int _num_units; int _num_units;
Func<Tensor, string, Tensor> _activation; Func<Tensor, string, Tensor> _activation;


public override LSTMStateTuple state_size => _num_units;
public override object state_size => _num_units;
public override int output_size => _num_units; public override int output_size => _num_units;
public VariableV1 _kernel; public VariableV1 _kernel;
string _WEIGHTS_VARIABLE_NAME = "kernel"; string _WEIGHTS_VARIABLE_NAME = "kernel";


+ 8
- 14
src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs View File

@@ -12,15 +12,10 @@ namespace Tensorflow.Operations
/// ///
/// Only used when `state_is_tuple=True`. /// Only used when `state_is_tuple=True`.
/// </summary> /// </summary>
public class LSTMStateTuple
public class LSTMStateTuple : ICanBeFlattened
{ {
int c;
int h;

public LSTMStateTuple(int c)
{
this.c = c;
}
public object c;
public object h;


public LSTMStateTuple(int c, int h) public LSTMStateTuple(int c, int h)
{ {
@@ -28,14 +23,13 @@ namespace Tensorflow.Operations
this.h = h; this.h = h;
} }


public static implicit operator int(LSTMStateTuple tuple)
public LSTMStateTuple(Tensor c, Tensor h)
{ {
return tuple.c;
this.c = c;
this.h = h;
} }


public static implicit operator LSTMStateTuple(int c)
{
return new LSTMStateTuple(c);
}
public object[] Flatten()
=> new[] { c, h };
} }
} }

+ 18
- 14
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -49,7 +49,7 @@ namespace Tensorflow
/// difference between TF and Keras RNN cell. /// difference between TF and Keras RNN cell.
/// </summary> /// </summary>
protected bool _is_tf_rnn_cell = false; protected bool _is_tf_rnn_cell = false;
public virtual LSTMStateTuple state_size { get; }
public virtual object state_size { get; }


public virtual int output_size { get; } public virtual int output_size { get; }


@@ -64,7 +64,7 @@ namespace Tensorflow
_is_tf_rnn_cell = true; _is_tf_rnn_cell = true;
} }


public virtual Tensor get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
public virtual object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
{ {
if (inputs != null) if (inputs != null)
throw new NotImplementedException("get_initial_state input is not null"); throw new NotImplementedException("get_initial_state input is not null");
@@ -78,11 +78,10 @@ namespace Tensorflow
/// <param name="batch_size"></param> /// <param name="batch_size"></param>
/// <param name="dtype"></param> /// <param name="dtype"></param>
/// <returns></returns> /// <returns></returns>
public Tensor zero_state(Tensor batch_size, TF_DataType dtype)
private Tensor zero_state(Tensor batch_size, TF_DataType dtype)
{ {
Tensor output = null; Tensor output = null;
var state_size = this.state_size;
tf_with(ops.name_scope($"{this.GetType().Name}ZeroState", values: new { batch_size }), delegate
tf_with(ops.name_scope($"{GetType().Name}ZeroState", values: new { batch_size }), delegate
{ {
output = _zero_state_tensors(state_size, batch_size, dtype); output = _zero_state_tensors(state_size, batch_size, dtype);
}); });
@@ -90,20 +89,25 @@ namespace Tensorflow
return output; return output;
} }


private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype)
private Tensor _zero_state_tensors(object state_size, Tensor batch_size, TF_DataType dtype)
{ {
var output = nest.map_structure(s =>
if(state_size is int state_size_int)
{ {
var c = rnn_cell_impl._concat(batch_size, s);
var size = array_ops.zeros(c, dtype: dtype);
var output = nest.map_structure(s =>
{
var c = rnn_cell_impl._concat(batch_size, s);
var size = array_ops.zeros(c, dtype: dtype);


var c_static = rnn_cell_impl._concat(batch_size, s, @static: true);
size.set_shape(c_static);
var c_static = rnn_cell_impl._concat(batch_size, s, @static: true);
size.set_shape(c_static);


return size;
}, state_size);
return size;
}, state_size_int);


return output;
return output;
}

throw new NotImplementedException("_zero_state_tensors");
} }
} }
} }

+ 63
- 9
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -29,8 +29,8 @@ namespace Tensorflow.Operations
/// <summary> /// <summary>
/// Creates a bidirectional recurrent neural network. /// Creates a bidirectional recurrent neural network.
/// </summary> /// </summary>
public static void static_bidirectional_rnn(BasicLSTMCell cell_fw,
BasicLSTMCell cell_bw,
public static (Tensor[], LSTMStateTuple, LSTMStateTuple) static_bidirectional_rnn(BasicLstmCell cell_fw,
BasicLstmCell cell_bw,
Tensor[] inputs, Tensor[] inputs,
Tensor initial_state_fw = null, Tensor initial_state_fw = null,
Tensor initial_state_bw = null, Tensor initial_state_bw = null,
@@ -41,12 +41,17 @@ namespace Tensorflow.Operations
if (inputs == null || inputs.Length == 0) if (inputs == null || inputs.Length == 0)
throw new ValueError("inputs must not be empty"); throw new ValueError("inputs must not be empty");


Tensor[] output_fw = null;
Tensor[] output_bw = null;
LSTMStateTuple output_state_fw = null;
LSTMStateTuple output_state_bw = null;

tf_with(tf.variable_scope(scope ?? "bidirectional_rnn"), delegate tf_with(tf.variable_scope(scope ?? "bidirectional_rnn"), delegate
{ {
// Forward direction // Forward direction
tf_with(tf.variable_scope("fw"), fw_scope => tf_with(tf.variable_scope("fw"), fw_scope =>
{ {
static_rnn(
(output_fw, output_state_fw) = static_rnn(
cell_fw, cell_fw,
inputs, inputs,
initial_state_fw, initial_state_fw,
@@ -54,16 +59,48 @@ namespace Tensorflow.Operations
sequence_length, sequence_length,
scope: fw_scope); scope: fw_scope);
}); });

// backward direction
tf_with(tf.variable_scope("bw"), bw_scope =>
{
var reversed_inputs = _reverse_seq(inputs, sequence_length);
(output_bw, output_state_bw) = static_rnn(
cell_bw,
reversed_inputs,
initial_state_bw,
dtype,
sequence_length,
scope: bw_scope);
});
}); });

output_bw = _reverse_seq(output_bw, sequence_length);

var flat_outputs = zip(output_fw, output_bw)
.Select(x => array_ops.concat(new[] { x.Item1, x.Item2 }, 1))
.ToArray();

return (flat_outputs, output_state_fw, output_state_bw);
} }


public static void static_rnn(BasicLSTMCell cell,
private static Tensor[] _reverse_seq(Tensor[] input_seq, Tensor lengths)
{
if (lengths == null)
return input_seq.Reverse().ToArray();

throw new NotImplementedException("_reverse_seq");
}

public static (Tensor[], LSTMStateTuple) static_rnn(BasicLstmCell cell,
Tensor[] inputs, Tensor[] inputs,
Tensor initial_state, Tensor initial_state,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
Tensor sequence_length = null, Tensor sequence_length = null,
VariableScope scope = null) VariableScope scope = null)
{ {
List<Tensor> outputs = new List<Tensor>();
object state = null;

// Create a new scope in which the caching device is either // Create a new scope in which the caching device is either
// determined by the parent scope, or is set to place the cached // determined by the parent scope, or is set to place the cached
// Variable using the same placement as for the rest of the RNN. // Variable using the same placement as for the rest of the RNN.
@@ -73,12 +110,12 @@ namespace Tensorflow.Operations
throw new NotImplementedException("static_rnn"); throw new NotImplementedException("static_rnn");
}); });
else else
tf_with(tf.variable_scope(scope), varscope =>
tf_with(tf.variable_scope(scope), scope1 =>
{ {
Dimension fixed_batch_size = null; Dimension fixed_batch_size = null;
Dimension batch_size = null; Dimension batch_size = null;
Tensor batch_size_tensor = null; Tensor batch_size_tensor = null;
VariableScope varscope = scope1;
// Obtain the first sequence of the input // Obtain the first sequence of the input
var first_input = inputs[0]; var first_input = inputs[0];
if (first_input.TensorShape.rank != 1) if (first_input.TensorShape.rank != 1)
@@ -108,14 +145,31 @@ namespace Tensorflow.Operations
else else
batch_size_tensor = array_ops.shape(first_input)[0]; batch_size_tensor = array_ops.shape(first_input)[0];


Tensor state = null;
if (initial_state != null) if (initial_state != null)
state = initial_state; state = initial_state;
else else
{ {
cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype);
state = cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype);
}

Tensor output = null;
if (state is LSTMStateTuple state_tuple)
{
foreach (var (time, input_) in enumerate(inputs))
{
if (time > 0)
varscope.reuse_variables();
if (sequence_length != null)
throw new NotImplementedException("static_rnn");

var results = cell.__call__(input_, state_tuple);
(output, state_tuple) = (results[1], new LSTMStateTuple(results[0], results[1]));
outputs.Add(output);
}
} }
}); });

return (outputs.ToArray(), state as LSTMStateTuple);
} }


public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor, public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor,
@@ -145,7 +199,7 @@ namespace Tensorflow.Operations
if (initial_state != null) if (initial_state != null)
state = initial_state; state = initial_state;
else else
state = cell.get_initial_state(batch_size: batch_size, dtype: dtype);
state = cell.get_initial_state(batch_size: batch_size, dtype: dtype) as Tensor;


var inputs = nest.pack_sequence_as(structure: inputs_tensor, flat_sequence: flat_input); var inputs = nest.pack_sequence_as(structure: inputs_tensor, flat_sequence: flat_input);




+ 15
- 0
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -604,6 +604,11 @@ namespace Tensorflow
return gen_array_ops.concat_v2(values, axis, name: name); return gen_array_ops.concat_v2(values, axis, name: name);
} }


public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat")
{
return gen_array_ops.concat_v2(values, axis, name: name);
}

public static Tensor concat(object[] values, int axis, string name = "concat") public static Tensor concat(object[] values, int axis, string name = "concat")
{ {
return gen_array_ops.concat_v2(values, axis, name: name); return gen_array_ops.concat_v2(values, axis, name: name);
@@ -629,6 +634,16 @@ namespace Tensorflow
}); });
} }


public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis,
string name = "split")
{
var size_splits = ops.convert_to_tensor(num_or_size_splits);
return gen_array_ops.split(axis: axis,
num_split: num_or_size_splits,
value: value,
name: name);
}

public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
=> gen_array_ops.slice(input, begin, size, name: name); => gen_array_ops.slice(input, begin, size, name: name);




+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -47,7 +47,7 @@ namespace Tensorflow
/// <param name="axis"></param> /// <param name="axis"></param>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns></returns> /// <returns></returns>
public static Tensor concat_v2<T>(T[] values, int axis, string name = null)
public static Tensor concat_v2<T, Ta>(T[] values, Ta axis, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis }); var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis });




+ 7
- 5
src/TensorFlowNET.Core/TensorFlow.Binding.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName> <AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace> <RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>1.14.1</TargetTensorFlow> <TargetTensorFlow>1.14.1</TargetTensorFlow>
<Version>0.12.1</Version>
<Version>0.13.0</Version>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company> <Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> <GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -18,14 +18,16 @@
<Description>Google's TensorFlow full binding in .NET Standard. <Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models. Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description> https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.12.1.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.11.0:
<AssemblyVersion>0.13.0.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.12.0:
1: Add ICanBeFlattened for nest.flatten2. 1: Add ICanBeFlattened for nest.flatten2.
2: Complete the WhileContext. 2: Complete the WhileContext.
3: Add tf.nn.rnn_cell.BasicRNNCell and tf.nn.dynamic_rnn. 3: Add tf.nn.rnn_cell.BasicRNNCell and tf.nn.dynamic_rnn.
4: Add EstimatorSpec.</PackageReleaseNotes>
4: Add EstimatorSpec.
5: Add rnn.static_rnn.
6: Add array_grad._SplitGrad().</PackageReleaseNotes>
<LangVersion>7.3</LangVersion> <LangVersion>7.3</LangVersion>
<FileVersion>0.12.1.0</FileVersion>
<FileVersion>0.13.0.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>


+ 0
- 14
src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs View File

@@ -7,20 +7,6 @@ namespace Tensorflow
{ {
public partial class Tensor public partial class Tensor
{ {
/// <summary>
/// Issue unresolved, will cause name_scope problem.
/// </summary>
/// <param name="scalar"></param>
/*public static implicit operator Tensor(double scalar)
{
return constant_op.constant(scalar);
}*/

/*public static implicit operator Tensor(int scalar)
{
return constant_op.constant(scalar);
}*/

public static implicit operator IntPtr(Tensor tensor) public static implicit operator IntPtr(Tensor tensor)
{ {
if (tensor._handle == IntPtr.Zero) if (tensor._handle == IntPtr.Zero)


+ 0
- 8
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -526,14 +526,6 @@ namespace Tensorflow.Util
return pack_sequence_as(structure, mapped_flat_structure) as Tensor; return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
} }
public static Tensor map_structure2<T>(Func<T, Tensor> func, T structure)
{
var flat_structure = flatten(structure);
var mapped_flat_structure = flat_structure.Select(func).ToList();
return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
}
/// <summary> /// <summary>
/// Same as map_structure, but with only one structure (no combining of multiple structures) /// Same as map_structure, but with only one structure (no combining of multiple structures)
/// </summary> /// </summary>


+ 5
- 0
src/TensorFlowNET.Core/Variables/VariableScope.cs View File

@@ -74,5 +74,10 @@ namespace Tensorflow
aggregation: aggregation) as RefVariable; aggregation: aggregation) as RefVariable;
}); });
} }

public void reuse_variables()
{
_reuse = _ReuseMode.AUTO_REUSE;
}
} }
} }

+ 1
- 0
src/TensorFlowNET.Core/Variables/_ReuseMode.cs View File

@@ -5,6 +5,7 @@
/// </summary> /// </summary>
public enum _ReuseMode public enum _ReuseMode
{ {
NOT_REUSE = 0,
// Indicates that variables are to be fetched if they already exist or // Indicates that variables are to be fetched if they already exist or
// otherwise created. // otherwise created.
AUTO_REUSE = 1 AUTO_REUSE = 1


Loading…
Cancel
Save