Browse Source

fix: some possible errors of RNN.

pull/1106/head
Yaohui Liu Wanglongzhi2001 2 years ago
parent
commit
c1d07bf9b8
2 changed files with 46 additions and 35 deletions
  1. +31
    -10
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  2. +15
    -25
      src/TensorFlowNET.Keras/BackendImpl.cs

+ 31
- 10
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -58,17 +58,12 @@ namespace Tensorflow
public Tensor this[params string[] slices]
=> this.First()[slices];

/// <summary>
/// Initializes a <see cref="Tensors"/> from a single tensor by forwarding it to the base-class constructor.
/// </summary>
/// <param name="tensor">The tensor to wrap.</param>
public Tensors(Tensor tensor) : base(tensor)
{

}

/// <summary>
/// Initializes a <see cref="Tensors"/> from an existing nest by forwarding it to the base-class constructor.
/// Private, so it can only be invoked from inside this class.
/// </summary>
/// <param name="nested">The nest passed through to the base constructor.</param>
private Tensors(Nest<Tensor> nested) : base(nested)
{

}

public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest<Tensor>(x)))
public Tensors(params Tensor[] tensors): base(DealWithConstructorArrayInput(tensors))
{
}
@@ -83,6 +78,22 @@ namespace Tensorflow
}

/// <summary>
/// Converts the tensor array received by the <c>params</c> constructor into the
/// nest that is handed to the base constructor.
/// </summary>
/// <param name="tensors">The tensors supplied to the constructor.</param>
/// <returns>
/// <c>Nest&lt;Tensor&gt;.Empty</c> when the array is empty; a nest built directly from the
/// only tensor when there is exactly one; otherwise a nest built from one
/// per-tensor nest for each element.
/// </returns>
private static Nest<Tensor> DealWithConstructorArrayInput(Tensor[] tensors)
{
    switch (tensors.Length)
    {
        case 0:
            // Nothing supplied: use the shared empty nest.
            return Nest<Tensor>.Empty;

        case 1:
            // A lone tensor is wrapped directly rather than as a one-element sequence.
            return new Nest<Tensor>(tensors[0]);

        default:
            // Several tensors: wrap each one, then wrap the resulting sequence.
            return new Nest<Tensor>(tensors.Select(t => new Nest<Tensor>(t)));
    }
}

public bool IsSingle()
{
return Length == 1;
@@ -107,9 +118,14 @@ namespace Tensorflow
ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) };
Value = null;
}
else
else if(NestType == NestType.List)
{
ListValue!.Add(new Nest<Tensor>(tensor));
}
else //Empty
{
ListValue.Add(new Nest<Tensor>(tensor));
NestType = NestType.Node;
Value = tensor;
}
}

@@ -128,9 +144,14 @@ namespace Tensorflow
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
Value = null;
}
else
else if(NestType == NestType.List)
{
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
ListValue!.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
}
else // empty
{
NestType = NestType.List;
ListValue = tensors.Select(x => new Nest<Tensor>(x)).ToList();
}
}



+ 15
- 25
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -651,13 +651,13 @@ namespace Tensorflow.Keras
states = Nest.PackSequenceAs(states, flat_final_states).ToTensors();
// Record this step's output and states: accumulate them across steps when
// return_all_outputs is set, otherwise keep only the latest step.
if (return_all_outputs)
{
successive_outputs.Add(output);
successive_states.Add(states);
successive_outputs = successive_outputs.MergeWith(output);
// The merged states must be stored in successive_states; assigning them to
// successive_outputs would overwrite the outputs merged on the line above.
successive_states = successive_states.MergeWith(states);
}
else
{
successive_outputs = new Tensors { output };
successive_states = new Tensors { states };
successive_outputs = new Tensors(output);
successive_states = new Tensors(states);
}

}
@@ -722,16 +722,11 @@ namespace Tensorflow.Keras
// Get the time(0) input and compute the output for it; the output will
// be used to determine the dtype of the output tensor array. Don't read from
// input_ta, because TensorArray's clear_after_read defaults to True.
var inps = new Tensors();
foreach (var inp in flatted_inptus)
{
inps.Add(inp[0]);
}
var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors();
var input_time_zero = Nest.PackSequenceAs(inputs, flatted_inptus.Select(x => x[0]).ToArray()).ToTensors();

// output_time_zero is used to determine the cell output shape and its
// dtype. The value is discarded.
(output_time_zero, _) = step_function((Tensor)input_time_zero,
(output_time_zero, _) = step_function(input_time_zero,
constants is null ? initial_states : initial_states.MergeWith(constants));

int output_ta_size = return_all_outputs ? time_steps_t : 1;
@@ -816,6 +811,7 @@ namespace Tensorflow.Keras

Func<Tensor, Tensor> cond = (time) => (time < time_steps_t);
int parallel_iterations = 32;
new_states = states;
if (masking_fn != null)
{
// Mask for the T output will be based on the output of T - 1. In the
@@ -846,7 +842,7 @@ namespace Tensorflow.Keras
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
var mask_t = masking_fn(time);
var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants));
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants));
// mask output
var flat_output = Nest.Flatten(output).ToList();

@@ -871,11 +867,12 @@ namespace Tensorflow.Keras
new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors();

var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
// TODO(Wanglongzhi2001),deal with zip output_ta_t
foreach (var (ta, Out) in zip(output_ta_t, flat_new_output))
output_ta_t = zip(output_ta_t, flat_new_output).Select(item =>
{
output_ta_t.Add(ta.write(ta_index_to_write, Out));
}
var (ta, out_) = item;
return ta.write(ta_index_to_write, out_);
}).ToList();


new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();

@@ -921,15 +918,8 @@ namespace Tensorflow.Keras
}
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
}
//Tensors outputs = new Tensors();
foreach (var o in output_ta)
{
outputs.Add(o.stack());
}
foreach (var o in outputs)
{
last_output.Add(o[-1]);
}
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToTensors());
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToTensors());
outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors();
last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors();



Loading…
Cancel
Save