Browse Source

tf.while_loop #348

tags/v0.12
Oceania2018 6 years ago
parent
commit
5ee46e494a
19 changed files with 150 additions and 80 deletions
  1. +32
    -0
      src/TensorFlowNET.Core/Device/c_api.device.cs
  2. +3
    -1
      src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Conv.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Dense.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
  7. +11
    -3
      src/TensorFlowNET.Core/Keras/Layers/Layer.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
  9. +10
    -1
      src/TensorFlowNET.Core/Layers/Layer.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
  11. +4
    -8
      src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
  12. +1
    -0
      src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
  13. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  14. +0
    -2
      src/TensorFlowNET.Core/Operations/Operation.Control.cs
  15. +5
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  16. +1
    -1
      src/TensorFlowNET.Core/Operations/random_ops.py.cs
  17. +5
    -0
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  18. +51
    -48
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  19. +20
    -9
      src/TensorFlowNET.Core/ops.cs

+ 32
- 0
src/TensorFlowNET.Core/Device/c_api.device.cs View File

@@ -0,0 +1,32 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Runtime.InteropServices;

namespace Tensorflow
{
public partial class c_api
{
/// <summary>
/// Specify the device for `desc`. Defaults to empty, meaning unconstrained.
/// </summary>
/// <param name="desc"></param>
/// <param name="device"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_SetDevice(IntPtr desc, string device);
}
}

+ 3
- 1
src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs View File

@@ -69,7 +69,9 @@ namespace Tensorflow
_new_stack = false;
}

_seen_nodes = new List<ITensorOrOperation>();
_seen_nodes = new List<ITensorOrOperation>();
_old_stack = null;
_old_control_flow_context = null;
}

public void add_op(ITensorOrOperation op)


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -139,7 +139,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
Tensor outputs = null;



+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Conv.cs View File

@@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
var outputs = _convolution_op.__call__(inputs, kernel);
if (use_bias)


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Dense.cs View File

@@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
Tensor outputs = null;
var rank = inputs.rank;


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Embedding.cs View File

@@ -50,7 +50,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64)


+ 11
- 3
src/TensorFlowNET.Core/Keras/Layers/Layer.cs View File

@@ -52,6 +52,7 @@ namespace Tensorflow.Keras.Layers
protected InputSpec input_spec;
protected bool supports_masking;
protected List<VariableV1> _trainable_weights;
protected List<VariableV1> _non_trainable_weights;
private string _name;
public string name => _name;
protected string _base_name;
@@ -84,6 +85,7 @@ namespace Tensorflow.Keras.Layers

_init_set_name(name);
_trainable_weights = new List<VariableV1>();
_non_trainable_weights = new List<VariableV1>();
_compute_previous_mask = false;
_updates = new List<Operation>();

@@ -103,6 +105,7 @@ namespace Tensorflow.Keras.Layers

public (Tensor, Tensor) __call__(Tensor[] inputs,
Tensor training = null,
Tensor state = null,
VariableScope scope = null)
{
var input_list = inputs;
@@ -139,7 +142,9 @@ namespace Tensorflow.Keras.Layers
// overridden).
_maybe_build(inputs[0]);

(input, outputs) = call(inputs[0], training: training);
(input, outputs) = call(inputs[0],
training: training,
state: state);
(input, outputs) = _set_connectivity_metadata_(input, outputs);
_handle_activity_regularization(inputs[0], outputs);
_set_mask_metadata(inputs[0], outputs, null);
@@ -173,7 +178,7 @@ namespace Tensorflow.Keras.Layers
return null;
}

protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
return (inputs, inputs);
}
@@ -233,7 +238,10 @@ namespace Tensorflow.Keras.Layers
initializer: initializer,
trainable: trainable.Value);
//backend.track_variable(variable);
_trainable_weights.Add(variable);
if (trainable == true)
_trainable_weights.Add(variable);
else
_non_trainable_weights.Add(variable);

return variable;
}


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow.Keras.Layers
this.input_spec = new InputSpec(ndim: 4);
}

protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null)
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
int[] pool_shape;
if (data_format == "channels_last")


+ 10
- 1
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -43,6 +43,7 @@ namespace Tensorflow.Layers

// Avoid an incorrect lint error
_trainable_weights = new List<VariableV1>();
_non_trainable_weights = new List<VariableV1>();
this.built = false;
_keras_style = false;
}
@@ -54,6 +55,7 @@ namespace Tensorflow.Layers

public (Tensor, Tensor) __call__(Tensor inputs,
Tensor training = null,
Tensor state = null,
VariableScope scope = null)
{
_set_scope(scope);
@@ -76,7 +78,9 @@ namespace Tensorflow.Layers
{
_current_scope = scope2;
// Actually call layer
outputs = base.__call__(new Tensor[] { inputs }, training: training);
outputs = base.__call__(new Tensor[] { inputs },
training: training,
state: state);
});


@@ -121,6 +125,11 @@ namespace Tensorflow.Layers
Graph init_graph = null;
VariableV1[] existing_variables = null;

if (synchronization == VariableSynchronization.OnRead)
trainable = false;
else if (!trainable.HasValue)
trainable = true;

if (default_graph.building_function)
{
throw new NotImplementedException("add_weight");


+ 1
- 1
src/TensorFlowNET.Core/Operations/BasicRNNCell.cs View File

@@ -66,7 +66,7 @@ namespace Tensorflow
built = true;
}

protected override (Tensor, Tensor) call(Tensor inputs, Tensor state = null)
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
{
// Most basic RNN: output = new_state = act(W * input + U * state + B).
var concat = array_ops.concat(new[] { inputs, state }, 1);


+ 4
- 8
src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs View File

@@ -307,12 +307,6 @@ namespace Tensorflow.Operations

protected override void _AddOpInternal(Operation op)
{
if(op.name == "rnn/while/basic_rnn_cell/MatMul" ||
op.name == "rnn/while/TensorArrayReadV3")
{

}

Operation[] external_inputs = new Operation[0];
if (op.inputs.Length == 0)
{
@@ -412,10 +406,12 @@ namespace Tensorflow.Operations
}

if (_outer_context != null)
{
result = _outer_context.AddValue(val);
}

if (tf.get_default_graph()._nodes_by_name.Count >= 83)
{

}
// Create an Enter to make `result` known to this loop context.
Tensor enter = null;
tf_with(ops.control_dependencies(null), delegate


+ 1
- 0
src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs View File

@@ -16,6 +16,7 @@

using System;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow.Operations.Initializers
{


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -214,7 +214,7 @@ namespace Tensorflow.Operations
if (sequence_length != null)
throw new NotImplementedException("sequence_length != null");
else
a = cell.__call__(input_t_t, state1);
a = cell.__call__(input_t_t, state: state1);

return item;
};


+ 0
- 2
src/TensorFlowNET.Core/Operations/Operation.Control.cs View File

@@ -32,9 +32,7 @@ namespace Tensorflow
public void _control_flow_post_processing()
{
foreach(Tensor input_tensor in inputs)
{
control_flow_util.CheckInputFromValidContext(this, input_tensor.op);
}
if (_control_flow_context != null)
_control_flow_context.AddOp(this);


+ 5
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -78,6 +78,7 @@ namespace Tensorflow
#if SERIALIZABLE
[JsonIgnore]
#endif
bool _is_stateful;
public NodeDef node_def
{
get
@@ -173,6 +174,8 @@ namespace Tensorflow
}
}

_id_value = _graph._next_id();
// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context();
@@ -184,6 +187,8 @@ namespace Tensorflow
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

_is_stateful = op_def.IsStateful;

// Initialize self._outputs.
output_types = new TF_DataType[NumOutputs];
for (int i = 0; i < NumOutputs; i++)


+ 1
- 1
src/TensorFlowNET.Core/Operations/random_ops.py.cs View File

@@ -71,7 +71,7 @@ namespace Tensorflow
return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope =>
{
name = scope;
var tensorShape = _ShapeTensor(shape);
var tensorShape = tensor_util.shape_tensor(shape);
var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min");
var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max");
var rnd = gen_random_ops.random_uniform(tensorShape, dtype);


+ 5
- 0
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -335,5 +335,10 @@ namespace Tensorflow

return shape;
}

public static Tensor shape_tensor(int[] shape)
{
return ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32, name: "shape");
}
}
}

+ 51
- 48
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -133,66 +133,69 @@ namespace Tensorflow
if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);

ops.init_scope();
var values = init_from_fn ? new object[0] : new object[] { initial_value };
tf_with(ops.name_scope(name, "Variable", values), scope =>
tf_with(ops.init_scope2(), delegate
{
name = scope;
if (init_from_fn)
var values = init_from_fn ? new object[0] : new object[] { initial_value };
tf_with(ops.name_scope(name, "Variable", values), scope =>
{
// Use attr_scope and device(None) to simulate the behavior of
// colocate_with when the variable we want to colocate with doesn't
// yet exist.
string true_name = ops.name_from_scope_name(name);
var attr = new AttrValue
name = scope;

if (init_from_fn)
{
List = new AttrValue.Types.ListValue()
};
attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}"));
tf_with(ops.name_scope("Initializer"), scope2 =>
// Use attr_scope and device(None) to simulate the behavior of
// colocate_with when the variable we want to colocate with doesn't
// yet exist.
string true_name = ops.name_from_scope_name(name);
var attr = new AttrValue
{
List = new AttrValue.Types.ListValue()
};
attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}"));
tf_with(ops.name_scope("Initializer"), scope2 =>
{
_initial_value = (initial_value as Func<Tensor>)();
_initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
});
_variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
}
// Or get the initial value from a Tensor or Python object.
else
{
_initial_value = (initial_value as Func<Tensor>)();
_initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype);
});
_variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
}
// Or get the initial value from a Tensor or Python object.
else
{
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype);
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype);

var shape = _initial_value.shape;
dtype = _initial_value.dtype;
_variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope);
}
var shape = _initial_value.shape;
dtype = _initial_value.dtype;
_variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope);
}

// Manually overrides the variable's shape with the initial value's.
if (validate_shape)
{
var initial_value_shape = _initial_value.TensorShape;
if (!initial_value_shape.is_fully_defined())
throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
}
// Manually overrides the variable's shape with the initial value's.
if (validate_shape)
{
var initial_value_shape = _initial_value.TensorShape;
if (!initial_value_shape.is_fully_defined())
throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
}

// If 'initial_value' makes use of other variables, make sure we don't
// have an issue if these other variables aren't initialized first by
// using their initialized_value() method.
var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value);
// If 'initial_value' makes use of other variables, make sure we don't
// have an issue if these other variables aren't initialized first by
// using their initialized_value() method.
var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value);

_initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;
_initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;

if (!String.IsNullOrEmpty(caching_device))
{
if (!String.IsNullOrEmpty(caching_device))
{

}
else
{
ops.colocate_with(_initializer_op);
}
else
{
ops.colocate_with(_initializer_op);

_snapshot = gen_array_ops.identity(_variable, name = "read");
}
_snapshot = gen_array_ops.identity(_variable, name = "read");
}

ops.add_to_collections(collections, this as VariableV1);
ops.add_to_collections(collections, this as VariableV1);
});
});
}



+ 20
- 9
src/TensorFlowNET.Core/ops.cs View File

@@ -186,12 +186,7 @@ namespace Tensorflow
/// operations constructed within the context.
/// </returns>
public static _ControlDependenciesController control_dependencies(object[] control_inputs)
{
return get_default_graph().control_dependencies(control_inputs);
}

public static _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
=> control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray());
=> get_default_graph().control_dependencies(control_inputs);

/// <summary>
/// Creates a TF_Operation.
@@ -212,9 +207,9 @@ namespace Tensorflow
{
var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

//TODO: Implement TF_SetDevice
//if node_def.device:
// c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device))
if (!string.IsNullOrEmpty(node_def.Device))
c_api.TF_SetDevice(op_desc, node_def.Device);
// Add inputs
foreach (var op_input in inputs)
{
@@ -310,6 +305,22 @@ namespace Tensorflow
});
}

public static IObjectLife init_scope2()
{
// Retrieve the active name scope: entering an `init_scope` preserves
// the name scope of the current context.
var default_graph = get_default_graph();
var scope = default_graph.get_name_scope();
if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/"))
// Names that end with trailing slashes are treated by `name_scope` as
// absolute.
scope += "/";
// inner_device_stack = default_graph._device_function_stack
// var outer_context = default_graph.as_default;

return ops.control_dependencies(null);
}

private static int uid_number = 0;

/// <summary>


Loading…
Cancel
Save