@@ -268,6 +268,16 @@ namespace Tensorflow | |||
public static Tensor rank(Tensor input, string name = null) | |||
{ | |||
if (tf.context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
"Rank", name, | |||
null, | |||
input); | |||
return results[0]; | |||
} | |||
var _op = tf._op_def_lib._apply_op_helper("Rank", name: name, args: new { input }); | |||
return _op.outputs[0]; | |||
@@ -567,7 +567,7 @@ namespace Tensorflow | |||
} | |||
else | |||
{ | |||
if(x is Tensor) | |||
if(x.rank > -1) | |||
return constant_op.constant(np.arange(x.rank)); | |||
var rank = array_ops.rank(x); | |||
@@ -109,7 +109,7 @@ namespace Tensorflow.Train | |||
return control_flow_ops.group(new[] { var_update, m_t, v_t }); | |||
} | |||
protected override void _create_slots(RefVariable[] var_list) | |||
protected override void _create_slots(ResourceVariable[] var_list) | |||
{ | |||
var first_var = var_list.OrderBy(x => x.Name).First(); | |||
_create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var); | |||
@@ -107,7 +107,7 @@ namespace Tensorflow | |||
/// </returns> | |||
public Operation minimize(Tensor loss, | |||
RefVariable global_step = null, | |||
List<RefVariable> var_list=null, | |||
List<ResourceVariable> var_list=null, | |||
GateGradientType gate_gradients = GateGradientType.GATE_OP, | |||
int? aggregation_method=null, | |||
bool colocate_gradients_with_ops = false, string name=null, Tensor grad_loss=null) | |||
@@ -142,17 +142,17 @@ namespace Tensorflow | |||
/// <returns> | |||
/// An `Operation` that applies the specified gradients. If `global_step` | |||
/// was not None, that operation also increments `global_step`.</returns> | |||
public Operation apply_gradients(Tuple<Tensor, RefVariable>[] grads_and_vars, RefVariable global_step = null, string name = null) | |||
public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_vars, RefVariable global_step = null, string name = null) | |||
{ | |||
// No DistributionStrategy case. | |||
var converted_grads_and_vars = new List<(Tensor, RefVariable, _OptimizableVariable)>(); | |||
var converted_grads_and_vars = new List<(Tensor, ResourceVariable, _OptimizableVariable)>(); | |||
foreach (var (g, v) in grads_and_vars) | |||
{ | |||
if(g != null) | |||
{ | |||
// Convert the grad to Tensor or IndexedSlices if necessary. | |||
var gR = ops.convert_to_tensor_or_indexed_slices(g); | |||
var p = _get_processor(v); | |||
var p = optimizer._get_processor(v); | |||
converted_grads_and_vars.Add((gR, v, p)); | |||
} | |||
} | |||
@@ -230,7 +230,7 @@ namespace Tensorflow | |||
/// silently ignored). | |||
/// </summary> | |||
/// <param name="var_list"></param> | |||
protected virtual void _create_slots(RefVariable[] var_list) | |||
protected virtual void _create_slots(ResourceVariable[] var_list) | |||
{ | |||
} | |||
@@ -276,6 +276,12 @@ namespace Tensorflow | |||
return control_flow_ops.group(update_ops, name_scope); | |||
} | |||
public virtual Operation _apply_dense(Tensor grad, ResourceVariable var) | |||
{ | |||
var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); | |||
return gen_training_ops.resource_apply_gradient_descent(var.Handle, alpha, grad, use_locking: _use_locking).op; | |||
} | |||
public virtual Operation _apply_dense(Tensor grad, RefVariable var) | |||
{ | |||
var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); | |||
@@ -298,6 +304,16 @@ namespace Tensorflow | |||
return _apply_sparse(gradient_no_duplicate_indices, var); | |||
} | |||
public virtual Operation _apply_sparse_duplicate_indices(IndexedSlices grad, ResourceVariable var) | |||
{ | |||
var (summed_values, unique_indices) = _deduplicate_indexed_slices(values: grad.values, indices: grad.indices); | |||
var gradient_no_duplicate_indices = new IndexedSlices( | |||
indices: unique_indices, | |||
values: summed_values, | |||
dense_shape: grad.dense_shape); | |||
return _apply_sparse(gradient_no_duplicate_indices, var); | |||
} | |||
public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var) | |||
{ | |||
throw new NotImplementedException("_apply_sparse"); | |||
@@ -344,18 +360,6 @@ namespace Tensorflow | |||
return non_slot; | |||
} | |||
private _OptimizableVariable _get_processor(RefVariable v) | |||
{ | |||
if(v is RefVariable) | |||
{ | |||
return new _RefVariableProcessor(v); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException("_get_processor"); | |||
} | |||
} | |||
/// <summary> | |||
/// Compute gradients of `loss` for the variables in `var_list`. | |||
/// </summary> | |||
@@ -365,8 +369,8 @@ namespace Tensorflow | |||
/// A list of (gradient, variable) pairs. Variable is always present, but | |||
/// gradient can be `None`. | |||
/// </returns> | |||
public Tuple<Tensor, RefVariable>[] compute_gradients(Tensor loss, | |||
List<RefVariable> var_list = null, | |||
public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss, | |||
List<ResourceVariable> var_list = null, | |||
int? aggregation_method = null, | |||
GateGradientType gate_gradients = GateGradientType.GATE_OP, | |||
bool colocate_gradients_with_ops = false, | |||
@@ -374,26 +378,28 @@ namespace Tensorflow | |||
{ | |||
// Scale loss if using a "mean" loss reduction and multiple replicas. | |||
loss = _scale_loss(loss); | |||
#pragma warning disable CS0219 // Variable is assigned but its value is never used | |||
int num_towers = 1; | |||
#pragma warning restore CS0219 // Variable is assigned but its value is never used | |||
if(var_list == null) | |||
{ | |||
var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||
var vars = ops.get_collection<ResourceVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||
var tmp = variables.trainable_variables(); | |||
switch (tmp) | |||
{ | |||
case List<RefVariable> values: | |||
case List<ResourceVariable> values: | |||
var_list = values.Concat(vars).ToList(); | |||
break; | |||
/*case List<RefVariable> values: | |||
var_list = values.Concat(vars).ToList(); | |||
break; | |||
case List<IVariableV1> values: | |||
var_list = values.Select(x => x as RefVariable).Concat(vars).ToList(); | |||
break; | |||
break;*/ | |||
default: | |||
throw new NotImplementedException(""); | |||
} | |||
} | |||
var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||
var_list = var_list.Concat(ops.get_collection<ResourceVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||
var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); | |||
var var_refs = processors.Select(x => x.target()).ToArray(); | |||
@@ -406,7 +412,7 @@ namespace Tensorflow | |||
grads = control_flow_ops.tuple(grads); | |||
var grads_and_vars = zip(grads, var_list) | |||
.Select(x => new Tuple<Tensor, RefVariable>(x.Item1, x.Item2)) | |||
.Select(x => new Tuple<Tensor, ResourceVariable>(x.Item1, x.Item2)) | |||
.ToArray(); | |||
return grads_and_vars; | |||
@@ -59,7 +59,7 @@ namespace Tensorflow | |||
return _op.outputs[0]; | |||
} | |||
public static Operation resource_apply_gradient_descent(EagerTensor var, EagerTensor alpha, EagerTensor delta, bool use_locking = false, string name = null) | |||
public static Operation resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) | |||
{ | |||
if (tf.context.executing_eagerly()) | |||
{ | |||
@@ -79,7 +79,7 @@ namespace Tensorflow | |||
use_locking | |||
}); | |||
return _op.outputs[0]; | |||
return _op; | |||
} | |||
} | |||
} |
@@ -24,6 +24,11 @@ namespace Tensorflow | |||
{ | |||
return new _RefVariableProcessor(v); | |||
} | |||
public static _OptimizableVariable _get_processor(ResourceVariable v) | |||
{ | |||
return new _DenseResourceVariableProcessor(v); | |||
} | |||
} | |||
public class _RefVariableProcessor : _OptimizableVariable | |||
@@ -56,4 +61,35 @@ namespace Tensorflow | |||
return update_op; | |||
} | |||
} | |||
public class _DenseResourceVariableProcessor : _OptimizableVariable | |||
{ | |||
private ResourceVariable _v; | |||
public _DenseResourceVariableProcessor(ResourceVariable v) | |||
{ | |||
_v = v; | |||
} | |||
public Tensor target() | |||
{ | |||
return _v.Handle; | |||
} | |||
public Operation update_op(Optimizer optimizer, Tensor g) | |||
{ | |||
Operation update_op = null; | |||
if (g.Tag == null) | |||
{ | |||
update_op = optimizer._apply_dense(g, _v); | |||
} | |||
else if (g.Tag is IndexedSlices) | |||
{ | |||
return optimizer._apply_sparse_duplicate_indices(g, _v); | |||
} | |||
return update_op; | |||
} | |||
} | |||
} |
@@ -1,438 +0,0 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Diagnostics.CodeAnalysis; | |||
namespace Tensorflow | |||
{ | |||
public class WeakKeyDictionary<TKey, TValue> : IDictionary<TKey, TValue> | |||
{ | |||
private Dictionary<WeakKey, TValue> _internalDictionary; | |||
private object _internalObject = new object(); | |||
private bool _finalized; | |||
public WeakKeyDictionary() | |||
{ | |||
_internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer()); | |||
} | |||
public WeakKeyDictionary(int capacity) | |||
{ | |||
_internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer()); | |||
} | |||
public WeakKeyDictionary(IEqualityComparer<TKey> comparer) | |||
{ | |||
_internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer(comparer)); | |||
} | |||
public WeakKeyDictionary(int capacity, IEqualityComparer<TKey> comparer) | |||
{ | |||
_internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer(comparer)); | |||
} | |||
// FXCop: this is not empty; we need to mark this so we know if a key | |||
// still has an active dictionary at its finalization. | |||
[SuppressMessage("Microsoft.Performance", "CA1821:RemoveEmptyFinalizers")] | |||
~WeakKeyDictionary() | |||
{ | |||
_finalized = true; | |||
} | |||
public ICollection<TKey> Keys | |||
{ | |||
get | |||
{ | |||
List<TKey> list = new List<TKey>(); | |||
lock (_internalObject) | |||
{ | |||
foreach (WeakKey key in _internalDictionary.Keys) | |||
{ | |||
object TKey = key.Target; | |||
if (TKey != null) | |||
{ | |||
list.Add((TKey)TKey); | |||
} | |||
} | |||
} | |||
return list; | |||
} | |||
} | |||
public ICollection<TValue> Values | |||
{ | |||
get { | |||
lock (_internalObject) { | |||
return _internalDictionary.Values; | |||
} | |||
} | |||
} | |||
public int Count | |||
{ | |||
get | |||
{ | |||
// Ensure a fairly accurate count. | |||
ScavangeLostKeys(); | |||
lock (_internalObject) | |||
{ | |||
return _internalDictionary.Count; | |||
} | |||
} | |||
} | |||
public bool IsReadOnly | |||
{ | |||
get { | |||
return false; | |||
} | |||
} | |||
[SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] | |||
public TValue this[TKey key] | |||
{ | |||
get { | |||
lock (_internalObject) { | |||
return _internalDictionary[new WeakKey(key)]; | |||
} | |||
} | |||
set | |||
{ | |||
WeakKey Tkey = new WeakKey(key); | |||
lock (_internalObject) | |||
{ | |||
//_internalDictionary[Tkey] = value; | |||
_internalDictionary.Add(Tkey, value); | |||
} | |||
// This looks a bit weird but the purpose of the lost key finder is to execute | |||
// code in some future garbage collection phase so we immediately create some garbage. | |||
new LostKeyFinder(this, Tkey); | |||
} | |||
} | |||
public bool TryGetValue(TKey key, out TValue value) | |||
{ | |||
WeakKey tkey = new WeakKey(key); | |||
lock (_internalObject) | |||
{ | |||
return _internalDictionary.TryGetValue(tkey, out value); | |||
} | |||
} | |||
[SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] | |||
public void Add(TKey key, TValue value) | |||
{ | |||
WeakKey tkey = new WeakKey(key); | |||
lock (_internalObject) | |||
{ | |||
_internalDictionary.Add(tkey, value); | |||
} | |||
// This looks a bit weird but the purpose of the lost key finder is to execute | |||
// code in some future garbage collection phase so we immediately create some garbage. | |||
new LostKeyFinder(this, tkey); | |||
} | |||
public bool ContainsKey(TKey key) | |||
{ | |||
return _internalDictionary.ContainsKey(new WeakKey(key)); | |||
} | |||
public bool Remove(TKey key) | |||
{ | |||
lock (_internalObject) | |||
{ | |||
return _internalDictionary.Remove(new WeakKey(key)); | |||
} | |||
} | |||
public void Add(KeyValuePair<TKey, TValue> item) | |||
{ | |||
Add(item.Key, item.Value); | |||
} | |||
public void Clear() | |||
{ | |||
lock (_internalObject) | |||
{ | |||
_internalDictionary.Clear(); | |||
} | |||
} | |||
public bool Contains(KeyValuePair<TKey, TValue> item) | |||
{ | |||
TValue value; | |||
bool result; | |||
lock (_internalObject) | |||
{ | |||
result = _internalDictionary.TryGetValue(new WeakKey(item.Key), out value); | |||
} | |||
if (result) | |||
{ | |||
return value.Equals(item.Value); | |||
} | |||
else | |||
{ | |||
return false; | |||
} | |||
} | |||
public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||
{ | |||
lock (_internalObject) | |||
{ | |||
foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary) | |||
{ | |||
KeyValuePair<TKey, TValue> kv = new KeyValuePair<TKey, TValue>((TKey)item.Key.Target, item.Value); | |||
array[arrayIndex] = kv; | |||
arrayIndex++; | |||
} | |||
} | |||
} | |||
public bool Remove(KeyValuePair<TKey, TValue> item) | |||
{ | |||
WeakKey key = new WeakKey(item.Key); | |||
lock (_internalObject) | |||
{ | |||
return _internalDictionary.Remove(key); | |||
} | |||
} | |||
public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||
{ | |||
List<WeakKey> lostKeys = null; | |||
lock (_internalObject) | |||
{ | |||
foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary) | |||
{ | |||
object TKey = item.Key.Target; | |||
if (TKey != null) | |||
{ | |||
yield return new KeyValuePair<TKey, TValue>((TKey)TKey, item.Value); | |||
} | |||
else | |||
{ | |||
if (lostKeys == null) | |||
{ | |||
lostKeys = new List<WeakKey>(); | |||
} | |||
lostKeys.Add(item.Key); | |||
} | |||
} | |||
} | |||
// Recover any lost keys. | |||
if (lostKeys != null) | |||
{ | |||
lock (_internalObject) | |||
{ | |||
foreach (WeakKey key in lostKeys) | |||
{ | |||
_internalDictionary.Remove(key); | |||
} | |||
} | |||
} | |||
} | |||
IEnumerator IEnumerable.GetEnumerator() | |||
{ | |||
return GetEnumerator(); | |||
} | |||
private void ScavangeLostKeys() | |||
{ | |||
List<WeakKey> lostKeys = null; | |||
lock (_internalObject) | |||
{ | |||
foreach (WeakKey key in _internalDictionary.Keys) | |||
{ | |||
if (!key.IsAlive) | |||
{ | |||
if (lostKeys == null) | |||
{ | |||
lostKeys = new List<WeakKey>(); | |||
} | |||
lostKeys.Add(key); | |||
} | |||
} | |||
} | |||
if (lostKeys != null) | |||
{ | |||
lock (_internalObject) | |||
{ | |||
foreach (WeakKey key in lostKeys) | |||
{ | |||
_internalDictionary.Remove(key); | |||
} | |||
} | |||
} | |||
} | |||
IEnumerator<KeyValuePair<TKey, TValue>> IEnumerable<KeyValuePair<TKey, TValue>>.GetEnumerator() | |||
{ | |||
return this.GetEnumerator(); | |||
} | |||
private class WeakKey : WeakReference | |||
{ | |||
private int _hashCode; | |||
// private GCHandle _gcHandle; | |||
public WeakKey(TKey key) | |||
: base(key, true) | |||
{ | |||
_hashCode = key.GetHashCode(); | |||
// Keep the key alive until it is explicitly collected | |||
// _gcHandle = GCHandle.Alloc(this); | |||
} | |||
internal void Release() | |||
{ | |||
// _gcHandle.Free(); | |||
} | |||
public override int GetHashCode() | |||
{ | |||
return _hashCode; | |||
} | |||
public override bool Equals(object obj) | |||
{ | |||
if (obj == null) | |||
{ | |||
return false; | |||
} | |||
if (obj.GetHashCode() != _hashCode) | |||
{ | |||
return false; | |||
} | |||
if (obj != this && (!IsAlive || !obj.Equals(Target))) | |||
{ | |||
return false; | |||
} | |||
return true; | |||
} | |||
} | |||
private class WeakComparer : IEqualityComparer<WeakKey> | |||
{ | |||
private IEqualityComparer<TKey> _comparer; | |||
public WeakComparer() | |||
{ | |||
} | |||
public WeakComparer(IEqualityComparer<TKey> comparer) | |||
{ | |||
_comparer = comparer; | |||
} | |||
public bool Equals(WeakKey x, WeakKey y) | |||
{ | |||
if (x.GetHashCode() != y.GetHashCode()) | |||
{ | |||
return false; | |||
} | |||
if (object.ReferenceEquals(x, y)) | |||
{ | |||
return true; | |||
} | |||
object ref1 = x.Target; | |||
if (ref1 == null) | |||
{ | |||
return false; | |||
} | |||
object ref2 = y.Target; | |||
if (ref2 == null) | |||
{ | |||
return false; | |||
} | |||
if (_comparer != null) | |||
{ | |||
return _comparer.Equals((TKey)ref1, (TKey)ref2); | |||
} | |||
else | |||
{ | |||
return ref1.Equals(ref2); | |||
} | |||
} | |||
public int GetHashCode(WeakKey obj) | |||
{ | |||
return obj.GetHashCode(); | |||
} | |||
} | |||
private class LostKeyFinder | |||
{ | |||
WeakKeyDictionary<TKey, TValue> _dictionary; | |||
WeakKey _key; | |||
public LostKeyFinder(WeakKeyDictionary<TKey, TValue> dictionary, WeakKey key) | |||
{ | |||
_dictionary = dictionary; | |||
_key = key; | |||
} | |||
~LostKeyFinder() | |||
{ | |||
if (_dictionary._finalized || _key == null) | |||
{ | |||
if (_key != null) | |||
{ | |||
_key.Release(); | |||
_key = null; | |||
} | |||
return; | |||
} | |||
// if (!_key.IsAlive) { | |||
if (_key.Target == null) | |||
{ | |||
lock (_dictionary._internalObject) | |||
{ | |||
_dictionary._internalDictionary.Remove(_key); | |||
} | |||
_key.Release(); | |||
_key = null; | |||
} | |||
else if (_dictionary._internalDictionary.ContainsKey(_key)) | |||
{ | |||
GC.ReRegisterForFinalize(this); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
@@ -2,13 +2,13 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using NumSharp; | |||
using System.Linq; | |||
using Tensorflow; | |||
using Tensorflow.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.Basics | |||
{ | |||
[TestClass] | |||
public class VariableTest | |||
public class VariableTest : EagerModeTestBase | |||
{ | |||
[TestMethod] | |||
public void NewVariable() | |||
@@ -0,0 +1,23 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using TensorFlowNET.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.UnitTest | |||
{ | |||
public class EagerModeTestBase : PythonTest | |||
{ | |||
[TestInitialize] | |||
public void TestInit() | |||
{ | |||
tf.enable_eager_execution(); | |||
} | |||
[TestCleanup] | |||
public void TestClean() | |||
{ | |||
} | |||
} | |||
} |