@@ -0,0 +1,32 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Runtime.InteropServices; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class c_api | |||||
{ | |||||
/// <summary> | |||||
/// Specify the device for `desc`. Defaults to empty, meaning unconstrained. | |||||
/// </summary> | |||||
/// <param name="desc"></param> | |||||
/// <param name="device"></param> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TF_SetDevice(IntPtr desc, string device); | |||||
} | |||||
} |
@@ -69,7 +69,9 @@ namespace Tensorflow | |||||
_new_stack = false; | _new_stack = false; | ||||
} | } | ||||
_seen_nodes = new List<ITensorOrOperation>(); | |||||
_seen_nodes = new List<ITensorOrOperation>(); | |||||
_old_stack = null; | |||||
_old_control_flow_context = null; | |||||
} | } | ||||
public void add_op(ITensorOrOperation op) | public void add_op(ITensorOrOperation op) | ||||
@@ -139,7 +139,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) | |||||
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) | |||||
{ | { | ||||
Tensor outputs = null; | Tensor outputs = null; | ||||
@@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) | |||||
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) | |||||
{ | { | ||||
var outputs = _convolution_op.__call__(inputs, kernel); | var outputs = _convolution_op.__call__(inputs, kernel); | ||||
if (use_bias) | if (use_bias) | ||||
@@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) | |||||
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) | |||||
{ | { | ||||
Tensor outputs = null; | Tensor outputs = null; | ||||
var rank = inputs.rank; | var rank = inputs.rank; | ||||
@@ -50,7 +50,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) | |||||
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) | |||||
{ | { | ||||
var dtype = inputs.dtype; | var dtype = inputs.dtype; | ||||
if (dtype != tf.int32 && dtype != tf.int64) | if (dtype != tf.int32 && dtype != tf.int64) | ||||
@@ -52,6 +52,7 @@ namespace Tensorflow.Keras.Layers | |||||
protected InputSpec input_spec; | protected InputSpec input_spec; | ||||
protected bool supports_masking; | protected bool supports_masking; | ||||
protected List<VariableV1> _trainable_weights; | protected List<VariableV1> _trainable_weights; | ||||
protected List<VariableV1> _non_trainable_weights; | |||||
private string _name; | private string _name; | ||||
public string name => _name; | public string name => _name; | ||||
protected string _base_name; | protected string _base_name; | ||||
@@ -84,6 +85,7 @@ namespace Tensorflow.Keras.Layers | |||||
_init_set_name(name); | _init_set_name(name); | ||||
_trainable_weights = new List<VariableV1>(); | _trainable_weights = new List<VariableV1>(); | ||||
_non_trainable_weights = new List<VariableV1>(); | |||||
_compute_previous_mask = false; | _compute_previous_mask = false; | ||||
_updates = new List<Operation>(); | _updates = new List<Operation>(); | ||||
@@ -103,6 +105,7 @@ namespace Tensorflow.Keras.Layers | |||||
public (Tensor, Tensor) __call__(Tensor[] inputs, | public (Tensor, Tensor) __call__(Tensor[] inputs, | ||||
Tensor training = null, | Tensor training = null, | ||||
Tensor state = null, | |||||
VariableScope scope = null) | VariableScope scope = null) | ||||
{ | { | ||||
var input_list = inputs; | var input_list = inputs; | ||||
@@ -139,7 +142,9 @@ namespace Tensorflow.Keras.Layers | |||||
// overridden). | // overridden). | ||||
_maybe_build(inputs[0]); | _maybe_build(inputs[0]); | ||||
(input, outputs) = call(inputs[0], training: training); | |||||
(input, outputs) = call(inputs[0], | |||||
training: training, | |||||
state: state); | |||||
(input, outputs) = _set_connectivity_metadata_(input, outputs); | (input, outputs) = _set_connectivity_metadata_(input, outputs); | ||||
_handle_activity_regularization(inputs[0], outputs); | _handle_activity_regularization(inputs[0], outputs); | ||||
_set_mask_metadata(inputs[0], outputs, null); | _set_mask_metadata(inputs[0], outputs, null); | ||||
@@ -173,7 +178,7 @@ namespace Tensorflow.Keras.Layers | |||||
return null; | return null; | ||||
} | } | ||||
protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null) | |||||
protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) | |||||
{ | { | ||||
return (inputs, inputs); | return (inputs, inputs); | ||||
} | } | ||||
@@ -233,7 +238,10 @@ namespace Tensorflow.Keras.Layers | |||||
initializer: initializer, | initializer: initializer, | ||||
trainable: trainable.Value); | trainable: trainable.Value); | ||||
//backend.track_variable(variable); | //backend.track_variable(variable); | ||||
_trainable_weights.Add(variable); | |||||
if (trainable == true) | |||||
_trainable_weights.Add(variable); | |||||
else | |||||
_non_trainable_weights.Add(variable); | |||||
return variable; | return variable; | ||||
} | } | ||||
@@ -43,7 +43,7 @@ namespace Tensorflow.Keras.Layers | |||||
this.input_spec = new InputSpec(ndim: 4); | this.input_spec = new InputSpec(ndim: 4); | ||||
} | } | ||||
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null) | |||||
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) | |||||
{ | { | ||||
int[] pool_shape; | int[] pool_shape; | ||||
if (data_format == "channels_last") | if (data_format == "channels_last") | ||||
@@ -43,6 +43,7 @@ namespace Tensorflow.Layers | |||||
// Avoid an incorrect lint error | // Avoid an incorrect lint error | ||||
_trainable_weights = new List<VariableV1>(); | _trainable_weights = new List<VariableV1>(); | ||||
_non_trainable_weights = new List<VariableV1>(); | |||||
this.built = false; | this.built = false; | ||||
_keras_style = false; | _keras_style = false; | ||||
} | } | ||||
@@ -54,6 +55,7 @@ namespace Tensorflow.Layers | |||||
public (Tensor, Tensor) __call__(Tensor inputs, | public (Tensor, Tensor) __call__(Tensor inputs, | ||||
Tensor training = null, | Tensor training = null, | ||||
Tensor state = null, | |||||
VariableScope scope = null) | VariableScope scope = null) | ||||
{ | { | ||||
_set_scope(scope); | _set_scope(scope); | ||||
@@ -76,7 +78,9 @@ namespace Tensorflow.Layers | |||||
{ | { | ||||
_current_scope = scope2; | _current_scope = scope2; | ||||
// Actually call layer | // Actually call layer | ||||
outputs = base.__call__(new Tensor[] { inputs }, training: training); | |||||
outputs = base.__call__(new Tensor[] { inputs }, | |||||
training: training, | |||||
state: state); | |||||
}); | }); | ||||
@@ -121,6 +125,11 @@ namespace Tensorflow.Layers | |||||
Graph init_graph = null; | Graph init_graph = null; | ||||
VariableV1[] existing_variables = null; | VariableV1[] existing_variables = null; | ||||
if (synchronization == VariableSynchronization.OnRead) | |||||
trainable = false; | |||||
else if (!trainable.HasValue) | |||||
trainable = true; | |||||
if (default_graph.building_function) | if (default_graph.building_function) | ||||
{ | { | ||||
throw new NotImplementedException("add_weight"); | throw new NotImplementedException("add_weight"); | ||||
@@ -66,7 +66,7 @@ namespace Tensorflow | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override (Tensor, Tensor) call(Tensor inputs, Tensor state = null) | |||||
protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) | |||||
{ | { | ||||
// Most basic RNN: output = new_state = act(W * input + U * state + B). | // Most basic RNN: output = new_state = act(W * input + U * state + B). | ||||
var concat = array_ops.concat(new[] { inputs, state }, 1); | var concat = array_ops.concat(new[] { inputs, state }, 1); | ||||
@@ -307,12 +307,6 @@ namespace Tensorflow.Operations | |||||
protected override void _AddOpInternal(Operation op) | protected override void _AddOpInternal(Operation op) | ||||
{ | { | ||||
if(op.name == "rnn/while/basic_rnn_cell/MatMul" || | |||||
op.name == "rnn/while/TensorArrayReadV3") | |||||
{ | |||||
} | |||||
Operation[] external_inputs = new Operation[0]; | Operation[] external_inputs = new Operation[0]; | ||||
if (op.inputs.Length == 0) | if (op.inputs.Length == 0) | ||||
{ | { | ||||
@@ -412,10 +406,12 @@ namespace Tensorflow.Operations | |||||
} | } | ||||
if (_outer_context != null) | if (_outer_context != null) | ||||
{ | |||||
result = _outer_context.AddValue(val); | result = _outer_context.AddValue(val); | ||||
} | |||||
if (tf.get_default_graph()._nodes_by_name.Count >= 83) | |||||
{ | |||||
} | |||||
// Create an Enter to make `result` known to this loop context. | // Create an Enter to make `result` known to this loop context. | ||||
Tensor enter = null; | Tensor enter = null; | ||||
tf_with(ops.control_dependencies(null), delegate | tf_with(ops.control_dependencies(null), delegate | ||||
@@ -16,6 +16,7 @@ | |||||
using System; | using System; | ||||
using System.Linq; | using System.Linq; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
{ | { | ||||
@@ -214,7 +214,7 @@ namespace Tensorflow.Operations | |||||
if (sequence_length != null) | if (sequence_length != null) | ||||
throw new NotImplementedException("sequence_length != null"); | throw new NotImplementedException("sequence_length != null"); | ||||
else | else | ||||
a = cell.__call__(input_t_t, state1); | |||||
a = cell.__call__(input_t_t, state: state1); | |||||
return item; | return item; | ||||
}; | }; | ||||
@@ -32,9 +32,7 @@ namespace Tensorflow | |||||
public void _control_flow_post_processing() | public void _control_flow_post_processing() | ||||
{ | { | ||||
foreach(Tensor input_tensor in inputs) | foreach(Tensor input_tensor in inputs) | ||||
{ | |||||
control_flow_util.CheckInputFromValidContext(this, input_tensor.op); | control_flow_util.CheckInputFromValidContext(this, input_tensor.op); | ||||
} | |||||
if (_control_flow_context != null) | if (_control_flow_context != null) | ||||
_control_flow_context.AddOp(this); | _control_flow_context.AddOp(this); | ||||
@@ -78,6 +78,7 @@ namespace Tensorflow | |||||
#if SERIALIZABLE | #if SERIALIZABLE | ||||
[JsonIgnore] | [JsonIgnore] | ||||
#endif | #endif | ||||
bool _is_stateful; | |||||
public NodeDef node_def | public NodeDef node_def | ||||
{ | { | ||||
get | get | ||||
@@ -173,6 +174,8 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
_id_value = _graph._next_id(); | |||||
// Dict mapping op name to file and line information for op colocation | // Dict mapping op name to file and line information for op colocation | ||||
// context managers. | // context managers. | ||||
_control_flow_context = graph._get_control_flow_context(); | _control_flow_context = graph._get_control_flow_context(); | ||||
@@ -184,6 +187,8 @@ namespace Tensorflow | |||||
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | ||||
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | ||||
_is_stateful = op_def.IsStateful; | |||||
// Initialize self._outputs. | // Initialize self._outputs. | ||||
output_types = new TF_DataType[NumOutputs]; | output_types = new TF_DataType[NumOutputs]; | ||||
for (int i = 0; i < NumOutputs; i++) | for (int i = 0; i < NumOutputs; i++) | ||||
@@ -71,7 +71,7 @@ namespace Tensorflow | |||||
return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope => | return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope => | ||||
{ | { | ||||
name = scope; | name = scope; | ||||
var tensorShape = _ShapeTensor(shape); | |||||
var tensorShape = tensor_util.shape_tensor(shape); | |||||
var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min"); | var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min"); | ||||
var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max"); | var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max"); | ||||
var rnd = gen_random_ops.random_uniform(tensorShape, dtype); | var rnd = gen_random_ops.random_uniform(tensorShape, dtype); | ||||
@@ -335,5 +335,10 @@ namespace Tensorflow | |||||
return shape; | return shape; | ||||
} | } | ||||
public static Tensor shape_tensor(int[] shape) | |||||
{ | |||||
return ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32, name: "shape"); | |||||
} | |||||
} | } | ||||
} | } |
@@ -133,66 +133,69 @@ namespace Tensorflow | |||||
if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) | if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) | ||||
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | ||||
ops.init_scope(); | |||||
var values = init_from_fn ? new object[0] : new object[] { initial_value }; | |||||
tf_with(ops.name_scope(name, "Variable", values), scope => | |||||
tf_with(ops.init_scope2(), delegate | |||||
{ | { | ||||
name = scope; | |||||
if (init_from_fn) | |||||
var values = init_from_fn ? new object[0] : new object[] { initial_value }; | |||||
tf_with(ops.name_scope(name, "Variable", values), scope => | |||||
{ | { | ||||
// Use attr_scope and device(None) to simulate the behavior of | |||||
// colocate_with when the variable we want to colocate with doesn't | |||||
// yet exist. | |||||
string true_name = ops.name_from_scope_name(name); | |||||
var attr = new AttrValue | |||||
name = scope; | |||||
if (init_from_fn) | |||||
{ | { | ||||
List = new AttrValue.Types.ListValue() | |||||
}; | |||||
attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}")); | |||||
tf_with(ops.name_scope("Initializer"), scope2 => | |||||
// Use attr_scope and device(None) to simulate the behavior of | |||||
// colocate_with when the variable we want to colocate with doesn't | |||||
// yet exist. | |||||
string true_name = ops.name_from_scope_name(name); | |||||
var attr = new AttrValue | |||||
{ | |||||
List = new AttrValue.Types.ListValue() | |||||
}; | |||||
attr.List.S.Add(ByteString.CopyFromUtf8($"loc:{true_name}")); | |||||
tf_with(ops.name_scope("Initializer"), scope2 => | |||||
{ | |||||
_initial_value = (initial_value as Func<Tensor>)(); | |||||
_initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype); | |||||
}); | |||||
_variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); | |||||
} | |||||
// Or get the initial value from a Tensor or Python object. | |||||
else | |||||
{ | { | ||||
_initial_value = (initial_value as Func<Tensor>)(); | |||||
_initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype); | |||||
}); | |||||
_variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); | |||||
} | |||||
// Or get the initial value from a Tensor or Python object. | |||||
else | |||||
{ | |||||
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype); | |||||
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value", dtype: dtype); | |||||
var shape = _initial_value.shape; | |||||
dtype = _initial_value.dtype; | |||||
_variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope); | |||||
} | |||||
var shape = _initial_value.shape; | |||||
dtype = _initial_value.dtype; | |||||
_variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope); | |||||
} | |||||
// Manually overrides the variable's shape with the initial value's. | |||||
if (validate_shape) | |||||
{ | |||||
var initial_value_shape = _initial_value.TensorShape; | |||||
if (!initial_value_shape.is_fully_defined()) | |||||
throw new ValueError($"initial_value must have a shape specified: {_initial_value}"); | |||||
} | |||||
// Manually overrides the variable's shape with the initial value's. | |||||
if (validate_shape) | |||||
{ | |||||
var initial_value_shape = _initial_value.TensorShape; | |||||
if (!initial_value_shape.is_fully_defined()) | |||||
throw new ValueError($"initial_value must have a shape specified: {_initial_value}"); | |||||
} | |||||
// If 'initial_value' makes use of other variables, make sure we don't | |||||
// have an issue if these other variables aren't initialized first by | |||||
// using their initialized_value() method. | |||||
var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value); | |||||
// If 'initial_value' makes use of other variables, make sure we don't | |||||
// have an issue if these other variables aren't initialized first by | |||||
// using their initialized_value() method. | |||||
var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value); | |||||
_initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; | |||||
_initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; | |||||
if (!String.IsNullOrEmpty(caching_device)) | |||||
{ | |||||
if (!String.IsNullOrEmpty(caching_device)) | |||||
{ | |||||
} | |||||
else | |||||
{ | |||||
ops.colocate_with(_initializer_op); | |||||
} | |||||
else | |||||
{ | |||||
ops.colocate_with(_initializer_op); | |||||
_snapshot = gen_array_ops.identity(_variable, name = "read"); | |||||
} | |||||
_snapshot = gen_array_ops.identity(_variable, name = "read"); | |||||
} | |||||
ops.add_to_collections(collections, this as VariableV1); | |||||
ops.add_to_collections(collections, this as VariableV1); | |||||
}); | |||||
}); | }); | ||||
} | } | ||||
@@ -186,12 +186,7 @@ namespace Tensorflow | |||||
/// operations constructed within the context. | /// operations constructed within the context. | ||||
/// </returns> | /// </returns> | ||||
public static _ControlDependenciesController control_dependencies(object[] control_inputs) | public static _ControlDependenciesController control_dependencies(object[] control_inputs) | ||||
{ | |||||
return get_default_graph().control_dependencies(control_inputs); | |||||
} | |||||
public static _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | |||||
=> control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray()); | |||||
=> get_default_graph().control_dependencies(control_inputs); | |||||
/// <summary> | /// <summary> | ||||
/// Creates a TF_Operation. | /// Creates a TF_Operation. | ||||
@@ -212,9 +207,9 @@ namespace Tensorflow | |||||
{ | { | ||||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | ||||
//TODO: Implement TF_SetDevice | |||||
//if node_def.device: | |||||
// c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) | |||||
if (!string.IsNullOrEmpty(node_def.Device)) | |||||
c_api.TF_SetDevice(op_desc, node_def.Device); | |||||
// Add inputs | // Add inputs | ||||
foreach (var op_input in inputs) | foreach (var op_input in inputs) | ||||
{ | { | ||||
@@ -310,6 +305,22 @@ namespace Tensorflow | |||||
}); | }); | ||||
} | } | ||||
public static IObjectLife init_scope2() | |||||
{ | |||||
// Retrieve the active name scope: entering an `init_scope` preserves | |||||
// the name scope of the current context. | |||||
var default_graph = get_default_graph(); | |||||
var scope = default_graph.get_name_scope(); | |||||
if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/")) | |||||
// Names that end with trailing slashes are treated by `name_scope` as | |||||
// absolute. | |||||
scope += "/"; | |||||
// inner_device_stack = default_graph._device_function_stack | |||||
// var outer_context = default_graph.as_default; | |||||
return ops.control_dependencies(null); | |||||
} | |||||
private static int uid_number = 0; | private static int uid_number = 0; | ||||
/// <summary> | /// <summary> | ||||