diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs
index 86e979c4..c651bba9 100644
--- a/src/TensorFlowNET.Core/APIs/tf.ops.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs
@@ -30,10 +30,7 @@ namespace Tensorflow
public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);
- public Tensor assign(RefVariable @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
- => state_ops.assign(@ref, value, validate_shape, use_locking, name);
-
- public Tensor assign(ResourceVariable @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
+ public Tensor assign(IVariableV1 @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);
public void device(string device_name)
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
index aec7471b..530c4b27 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
@@ -121,7 +121,7 @@ namespace Tensorflow.Keras.Engine
///
///
///
- public Tensor Apply(Tensor inputs, bool is_training = false)
+ public Tensor Apply(Tensor inputs, bool is_training = false, Tensor state = null)
{
Tensor outputs = null;
@@ -135,9 +135,9 @@ namespace Tensorflow.Keras.Engine
string nameScope = "";
if (eager)
- {
nameScope = name;
- }
+ else
+ nameScope = _name_scope();
// using var graph = tf.keras.backend.get_graph().as_default();
if (!inputs.IsEagerTensor)
@@ -148,7 +148,7 @@ namespace Tensorflow.Keras.Engine
if (!built)
MaybeBuild(inputs);
- outputs = call(inputs, is_training: is_training);
+ outputs = call(inputs, is_training: is_training, state: state);
outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);
diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs
index b9d73cf8..43fd90bc 100644
--- a/src/TensorFlowNET.Core/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Layers/Layer.cs
@@ -88,7 +88,9 @@ namespace Tensorflow.Layers
{
_current_scope = scope2;
// Actually call layer
- outputs = base.Apply(inputs);
+ outputs = base.Apply(inputs,
+ is_training: training == null ? false : false,
+ state: state);
});
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
index 55589e64..592be625 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
@@ -71,8 +71,8 @@ namespace Tensorflow
{
// Most basic RNN: output = new_state = act(W * input + U * state + B).
var concat = array_ops.concat(new[] { inputs, state }, 1);
- var gate_inputs = math_ops.matmul(concat, _kernel as RefVariable);
- gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable);
+ var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor());
+ gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor());
var output = _activation(gate_inputs, null);
return output;
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index db528e70..5a99deff 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -326,7 +326,7 @@ namespace Tensorflow
// the updated inputs are reloaded from the c_api
lock (Locks.ProcessWide)
{
- c_api.UpdateEdge(_graph, output, input, tf.Status.Handle);
+ // c_api.UpdateEdge(_graph, output, input, tf.Status.Handle);
//var updated_inputs = inputs;
tf.Status.Check();
}
diff --git a/src/TensorFlowNET.Core/Operations/embedding_ops.cs b/src/TensorFlowNET.Core/Operations/embedding_ops.cs
index fa94244b..f9ba150a 100644
--- a/src/TensorFlowNET.Core/Operations/embedding_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/embedding_ops.cs
@@ -74,7 +74,7 @@ namespace Tensorflow
ids = ops.convert_to_tensor(ids, name: "ids");
if (np == 1)
{
- var gather = array_ops.gather(@params, ids, name: name);
+ var gather = array_ops.gather(@params.AsTensor(), ids, name: name);
var result = _clip(gather, ids, max_norm);
return array_ops.identity(result);
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index 2567ecd9..71faef7d 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -706,11 +706,12 @@ namespace Tensorflow
=> tf_with(ops.name_scope(name, "Pow", new { x, y }), scope =>
{
name = scope;
- var x_tensor = ops.convert_to_tensor(x, name: "x");
- var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype());
if (tf.executing_eagerly())
{
+ var x_tensor = ops.convert_to_tensor(x, name: "x");
+ var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype());
+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Pow", name,
null,
@@ -719,7 +720,7 @@ namespace Tensorflow
return results[0];
}
- var _op = tf.OpDefLib._apply_op_helper("Pow", name, args: new { x_tensor, y_tensor });
+ var _op = tf.OpDefLib._apply_op_helper("Pow", name, args: new { x, y });
return _op.output;
});
diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
index b90996d1..30651968 100644
--- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
+++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
@@ -10,7 +10,7 @@
Haiping Chen, Meinrad Recheis, Eli Belash
SciSharp STACK
true
- Apache 2.0
+ Apache 2.0, Haiping Chen $([System.DateTime]::UtcNow.ToString(yyyy))
https://github.com/SciSharp/TensorFlow.NET
git
http://scisharpstack.org
diff --git a/src/TensorFlowNET.Core/Training/AdamOptimizer.cs b/src/TensorFlowNET.Core/Training/AdamOptimizer.cs
index 47d4331c..4151843b 100644
--- a/src/TensorFlowNET.Core/Training/AdamOptimizer.cs
+++ b/src/TensorFlowNET.Core/Training/AdamOptimizer.cs
@@ -52,6 +52,14 @@ namespace Tensorflow.Train
_dtype = dtype;
}
+ public override Operation _apply_sparse(IndexedSlices grad, ResourceVariable var)
+ {
+ return _apply_sparse_shared(grad.values, var, grad.indices, (x, i, v) =>
+ {
+ return state_ops.scatter_add(x, i, v, use_locking: _use_locking);
+ });
+ }
+
public override Operation _apply_sparse(IndexedSlices grad, RefVariable var)
{
return _apply_sparse_shared(grad.values, var, grad.indices, (x, i, v) =>
@@ -91,7 +99,7 @@ namespace Tensorflow.Train
var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power));
var m = get_slot(var, "m");
var m_scaled_g_values = grad * (1 - beta1_t);
- var m_t = state_ops.assign(m.AsTensor(), m.AsTensor() * beta1_t, use_locking: _use_locking);
+ var m_t = state_ops.assign(m, m.AsTensor() * beta1_t, use_locking: _use_locking);
tf_with(ops.control_dependencies(new[] { m_t }), delegate
{
m_t = scatter_add(m, indices, m_scaled_g_values);
@@ -99,7 +107,7 @@ namespace Tensorflow.Train
var v = get_slot(var, "v");
var v_scaled_g_values = (grad * grad) * (1 - beta2_t);
- var v_t = state_ops.assign(v.AsTensor(), v.AsTensor() * beta2_t, use_locking: _use_locking);
+ var v_t = state_ops.assign(v, v.AsTensor() * beta2_t, use_locking: _use_locking);
tf_with(ops.control_dependencies(new[] { v_t }), delegate
{
v_t = scatter_add(v, indices, v_scaled_g_values);
diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
index 20822af0..86c8a33f 100644
--- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
+++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
@@ -56,7 +56,7 @@ namespace Tensorflow
///
///
///
- public static Tensor assign(Tensor @ref, object value,
+ public static Tensor assign(T @ref, object value,
bool validate_shape = true,
bool use_locking = true,
string name = null)
@@ -74,40 +74,10 @@ namespace Tensorflow
return _result[0];
}
- public static Tensor assign(RefVariable @ref, object value,
- bool validate_shape = true,
- bool use_locking = true,
- string name = null)
- {
- var _op = tf.OpDefLib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking });
-
- var _result = _op.outputs;
- var _inputs_flat = _op.inputs;
-
- var _attrs = new Dictionary();
- _attrs["T"] = _op.get_attr("T");
- _attrs["validate_shape"] = _op.get_attr("validate_shape");
- _attrs["use_locking"] = _op.get_attr("use_locking");
-
- return _result[0];
- }
-
- public static Tensor assign(ResourceVariable @ref, object value,
- bool validate_shape = true,
- bool use_locking = true,
- string name = null)
+ public static Tensor assign_add(IVariableV1 @ref, T value, bool use_locking = false, string name = null)
{
- var _op = tf.OpDefLib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking });
-
- var _result = _op.outputs;
- var _inputs_flat = _op.inputs;
-
- var _attrs = new Dictionary();
- _attrs["T"] = _op.get_attr("T");
- _attrs["validate_shape"] = _op.get_attr("validate_shape");
- _attrs["use_locking"] = _op.get_attr("use_locking");
-
- return _result[0];
+ var _op = tf.OpDefLib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking });
+ return _op.outputs[0];
}
public static Tensor assign_sub(IVariableV1 @ref,
diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs
index ad621915..e7962ac1 100644
--- a/src/TensorFlowNET.Core/Variables/state_ops.cs
+++ b/src/TensorFlowNET.Core/Variables/state_ops.cs
@@ -15,6 +15,7 @@
******************************************************************************/
using System;
+using static Tensorflow.Binding;
namespace Tensorflow
{
@@ -54,19 +55,7 @@ namespace Tensorflow
return @ref.assign((Tensor)value, name: name);
}
- public static Tensor assign(RefVariable @ref, object value,
- bool validate_shape = true,
- bool use_locking = true,
- string name = null)
- {
- return gen_state_ops.assign(@ref,
- value,
- validate_shape: validate_shape,
- use_locking: use_locking,
- name: name);
- }
-
- public static Tensor assign(ResourceVariable @ref, object value,
+ public static Tensor assign(T @ref, object value,
bool validate_shape = true,
bool use_locking = true,
string name = null)
@@ -110,7 +99,12 @@ namespace Tensorflow
T value,
bool use_locking = false,
string name = null)
- => @ref.assign_add(value, use_locking: use_locking, name: name);
+ {
+ if(tf.executing_eagerly())
+ return @ref.assign_add(value, use_locking: use_locking, name: name);
+ else
+ return gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name);
+ }
public static Tensor scatter_add(IVariableV1 @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null)
{