
Fix AdamOptimizer in Graph mode.

Oceania2018 · 5 years ago · tags/v0.20 · parent commit dd1cf255a7
9 changed files with 108 additions and 471 deletions

  1. src/TensorFlowNET.Core/Operations/gen_array_ops.cs (+10 -0)
  2. src/TensorFlowNET.Core/Operations/math_ops.cs (+1 -1)
  3. src/TensorFlowNET.Core/Training/AdamOptimizer.cs (+1 -1)
  4. src/TensorFlowNET.Core/Training/Optimizer.cs (+33 -27)
  5. src/TensorFlowNET.Core/Training/gen_training_ops.cs (+2 -2)
  6. src/TensorFlowNET.Core/Training/optimizer.py.cs (+36 -0)
  7. src/TensorFlowNET.Core/WeakKeyDicionary.cs (+0 -438)
  8. test/TensorFlowNET.UnitTest/Basics/VariableTest.cs (+2 -2)
  9. test/TensorFlowNET.UnitTest/EagerModeTestBase.cs (+23 -0)
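
For context, the failure this commit addresses shows up when AdamOptimizer.minimize() is used while building a graph (eager execution disabled). A minimal repro sketch, assuming the TF1-style API surface of v0.20; the variable, loss, and session names below are illustrative, not taken from the commit:

    // Hedged sketch: one Adam step in graph mode.
    var x = tf.Variable(2.0f, name: "x");                         // trainable variable
    var loss = tf.multiply(x, x);                                 // loss = x^2 (illustrative)
    var train_op = tf.train.AdamOptimizer(0.1f).minimize(loss);   // builds the update ops

    using (var sess = tf.Session())
    {
        sess.run(tf.global_variables_initializer());
        sess.run(train_op);                                       // the step that previously broke in graph mode
    }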

src/TensorFlowNET.Core/Operations/gen_array_ops.cs (+10 -0)

@@ -268,6 +268,16 @@ namespace Tensorflow
 
         public static Tensor rank(Tensor input, string name = null)
         {
+            if (tf.context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
+                    "Rank", name,
+                    null,
+                    input);
+
+                return results[0];
+            }
+
             var _op = tf._op_def_lib._apply_op_helper("Rank", name: name, args: new { input });
 
             return _op.outputs[0];
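
The added guard follows the eager fast-path pattern these generated wrappers use elsewhere: under eager execution, Rank dispatches through TFE_FastPathExecute and yields a concrete tensor; otherwise the existing graph path via _apply_op_helper is unchanged. A hedged usage sketch (input values illustrative):

    // Illustrative: same call, two execution paths.
    var t = tf.constant(new[,] { { 1, 2 }, { 3, 4 } });
    var r = gen_array_ops.rank(t);  // eager: concrete scalar 2; graph: symbolic "Rank" output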


src/TensorFlowNET.Core/Operations/math_ops.cs (+1 -1)

@@ -567,7 +567,7 @@ namespace Tensorflow
             }
             else
             {
-                if(x is Tensor)
+                if(x.rank > -1)
                     return constant_op.constant(np.arange(x.rank));
 
                 var rank = array_ops.rank(x);
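
A note on the changed predicate: at this call site x is always a Tensor, so `x is Tensor` also matched symbolic tensors whose rank is unknown until runtime (x.rank == -1), and np.arange(x.rank) then produced an empty axis range. Gating on `x.rank > -1` takes the constant shortcut only when the rank is statically known, and otherwise falls through to the dynamic array_ops.rank(x) path backed by the Rank op fixed above. A hedged sketch of the two cases (placeholder shapes illustrative):

    // Illustrative: static rank vs. rank unknown at graph-construction time.
    var known = tf.placeholder(tf.float32, new TensorShape(2, 3)); // x.rank == 2  -> constant axes
    var unknown = tf.placeholder(tf.float32);                      // x.rank == -1 -> array_ops.rank(x)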


src/TensorFlowNET.Core/Training/AdamOptimizer.cs (+1 -1)

@@ -109,7 +109,7 @@ namespace Tensorflow.Train
             return control_flow_ops.group(new[] { var_update, m_t, v_t });
         }
 
-        protected override void _create_slots(RefVariable[] var_list)
+        protected override void _create_slots(ResourceVariable[] var_list)
         {
             var first_var = var_list.OrderBy(x => x.Name).First();
             _create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var);


src/TensorFlowNET.Core/Training/Optimizer.cs (+33 -27)

@@ -107,7 +107,7 @@ namespace Tensorflow
         /// </returns>
         public Operation minimize(Tensor loss,
             RefVariable global_step = null,
-            List<RefVariable> var_list=null,
+            List<ResourceVariable> var_list=null,
             GateGradientType gate_gradients = GateGradientType.GATE_OP,
             int? aggregation_method=null,
             bool colocate_gradients_with_ops = false, string name=null, Tensor grad_loss=null)
@@ -142,17 +142,17 @@ namespace Tensorflow
         /// <returns>
         /// An `Operation` that applies the specified gradients. If `global_step`
         /// was not None, that operation also increments `global_step`.</returns>
-        public Operation apply_gradients(Tuple<Tensor, RefVariable>[] grads_and_vars, RefVariable global_step = null, string name = null)
+        public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_vars, RefVariable global_step = null, string name = null)
         {
             // No DistributionStrategy case.
-            var converted_grads_and_vars = new List<(Tensor, RefVariable, _OptimizableVariable)>();
+            var converted_grads_and_vars = new List<(Tensor, ResourceVariable, _OptimizableVariable)>();
             foreach (var (g, v) in grads_and_vars)
             {
                 if(g != null)
                 {
                     // Convert the grad to Tensor or IndexedSlices if necessary.
                     var gR = ops.convert_to_tensor_or_indexed_slices(g);
-                    var p = _get_processor(v);
+                    var p = optimizer._get_processor(v);
                     converted_grads_and_vars.Add((gR, v, p));
                 }
             }
@@ -230,7 +230,7 @@ namespace Tensorflow
         /// silently ignored).
         /// </summary>
         /// <param name="var_list"></param>
-        protected virtual void _create_slots(RefVariable[] var_list)
+        protected virtual void _create_slots(ResourceVariable[] var_list)
         {
         }
@@ -276,6 +276,12 @@ namespace Tensorflow
             return control_flow_ops.group(update_ops, name_scope);
         }
 
+        public virtual Operation _apply_dense(Tensor grad, ResourceVariable var)
+        {
+            var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype());
+            return gen_training_ops.resource_apply_gradient_descent(var.Handle, alpha, grad, use_locking: _use_locking).op;
+        }
+
         public virtual Operation _apply_dense(Tensor grad, RefVariable var)
         {
             var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype());
@@ -298,6 +304,16 @@ namespace Tensorflow
             return _apply_sparse(gradient_no_duplicate_indices, var);
         }
 
+        public virtual Operation _apply_sparse_duplicate_indices(IndexedSlices grad, ResourceVariable var)
+        {
+            var (summed_values, unique_indices) = _deduplicate_indexed_slices(values: grad.values, indices: grad.indices);
+            var gradient_no_duplicate_indices = new IndexedSlices(
+                indices: unique_indices,
+                values: summed_values,
+                dense_shape: grad.dense_shape);
+            return _apply_sparse(gradient_no_duplicate_indices, var);
+        }
+
         public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var)
         {
             throw new NotImplementedException("_apply_sparse");
@@ -344,18 +360,6 @@ namespace Tensorflow
             return non_slot;
         }
 
-        private _OptimizableVariable _get_processor(RefVariable v)
-        {
-            if(v is RefVariable)
-            {
-                return new _RefVariableProcessor(v);
-            }
-            else
-            {
-                throw new NotImplementedException("_get_processor");
-            }
-        }
-
         /// <summary>
         /// Compute gradients of `loss` for the variables in `var_list`.
         /// </summary>
@@ -365,8 +369,8 @@ namespace Tensorflow
         /// A list of (gradient, variable) pairs. Variable is always present, but
         /// gradient can be `None`.
        /// </returns>
-        public Tuple<Tensor, RefVariable>[] compute_gradients(Tensor loss,
-            List<RefVariable> var_list = null,
+        public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss,
+            List<ResourceVariable> var_list = null,
             int? aggregation_method = null,
             GateGradientType gate_gradients = GateGradientType.GATE_OP,
             bool colocate_gradients_with_ops = false,
@@ -374,26 +378,28 @@ namespace Tensorflow
         {
             // Scale loss if using a "mean" loss reduction and multiple replicas.
             loss = _scale_loss(loss);
-#pragma warning disable CS0219 // Variable is assigned but its value is never used
-            int num_towers = 1;
-#pragma warning restore CS0219 // Variable is assigned but its value is never used
 
             if(var_list == null)
             {
-                var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
+                var vars = ops.get_collection<ResourceVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
                 var tmp = variables.trainable_variables();
                 switch (tmp)
                 {
-                    case List<RefVariable> values:
+                    case List<ResourceVariable> values:
+                        var_list = values.Concat(vars).ToList();
+                        break;
+                    /*case List<RefVariable> values:
                         var_list = values.Concat(vars).ToList();
                         break;
                     case List<IVariableV1> values:
                         var_list = values.Select(x => x as RefVariable).Concat(vars).ToList();
-                        break;
+                        break;*/
+                    default:
+                        throw new NotImplementedException("");
                 }
             }
 
-            var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
+            var_list = var_list.Concat(ops.get_collection<ResourceVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
             var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
             var var_refs = processors.Select(x => x.target()).ToArray();
@@ -406,7 +412,7 @@ namespace Tensorflow
             grads = control_flow_ops.tuple(grads);
 
             var grads_and_vars = zip(grads, var_list)
-                .Select(x => new Tuple<Tensor, RefVariable>(x.Item1, x.Item2))
+                .Select(x => new Tuple<Tensor, ResourceVariable>(x.Item1, x.Item2))
                 .ToArray();
 
             return grads_and_vars;
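
Taken together, these hunks migrate the optimizer's public surface from RefVariable to ResourceVariable, add resource overloads of _apply_dense and _apply_sparse_duplicate_indices, and replace the private _get_processor with the shared optimizer._get_processor helper extended below. A hedged sketch of the two-step API the new signatures describe (optimizer construction and loss are illustrative):

    // Illustrative: compute_gradients/apply_gradients now trade in ResourceVariable.
    var opt = tf.train.AdamOptimizer(0.1f);
    Tuple<Tensor, ResourceVariable>[] gvs = opt.compute_gradients(loss);
    Operation apply_op = opt.apply_gradients(gvs);  // one optimizer step when run in a session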


src/TensorFlowNET.Core/Training/gen_training_ops.cs (+2 -2)

@@ -59,7 +59,7 @@ namespace Tensorflow
             return _op.outputs[0];
         }
 
-        public static Operation resource_apply_gradient_descent(EagerTensor var, EagerTensor alpha, EagerTensor delta, bool use_locking = false, string name = null)
+        public static Operation resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null)
         {
             if (tf.context.executing_eagerly())
             {
@@ -79,7 +79,7 @@ namespace Tensorflow
                     use_locking
                 });
 
-            return _op.outputs[0];
+            return _op;
         }
     }
 }
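
Two graph-mode details here: the parameters widen from EagerTensor to Tensor so symbolic tensors can flow through, and the graph branch now returns the Operation itself. ResourceApplyGradientDescent updates the variable through its resource handle and defines no output tensors, so the old `return _op.outputs[0]` indexed into an empty output list when a graph was being built. A hedged caller sketch (v, alpha, grad, and sess are assumed from surrounding training code):

    // Illustrative: the result is consumed as an Operation, not a Tensor.
    var op = gen_training_ops.resource_apply_gradient_descent(v.Handle, alpha, grad);
    sess.run(op);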

src/TensorFlowNET.Core/Training/optimizer.py.cs (+36 -0)

@@ -24,6 +24,11 @@ namespace Tensorflow
         {
             return new _RefVariableProcessor(v);
         }
+
+        public static _OptimizableVariable _get_processor(ResourceVariable v)
+        {
+            return new _DenseResourceVariableProcessor(v);
+        }
     }
 
     public class _RefVariableProcessor : _OptimizableVariable
@@ -56,4 +61,35 @@ namespace Tensorflow
             return update_op;
         }
     }
+
+    public class _DenseResourceVariableProcessor : _OptimizableVariable
+    {
+        private ResourceVariable _v;
+
+        public _DenseResourceVariableProcessor(ResourceVariable v)
+        {
+            _v = v;
+        }
+
+        public Tensor target()
+        {
+            return _v.Handle;
+        }
+
+        public Operation update_op(Optimizer optimizer, Tensor g)
+        {
+            Operation update_op = null;
+
+            if (g.Tag == null)
+            {
+                update_op = optimizer._apply_dense(g, _v);
+            }
+            else if (g.Tag is IndexedSlices)
+            {
+                return optimizer._apply_sparse_duplicate_indices(g, _v);
+            }
+
+            return update_op;
+        }
+    }
 }
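
The new processor mirrors _RefVariableProcessor for resource variables: target() exposes the variable's handle, and update_op routes dense gradients to _apply_dense while IndexedSlices gradients (detected via g.Tag) take the deduplicating sparse path. A hedged sketch of the dispatch (the variable v is assumed to be a ResourceVariable):

    // Illustrative: how apply_gradients reaches the new processor.
    var p = optimizer._get_processor(v);    // -> _DenseResourceVariableProcessor
    var update = p.update_op(opt, grad);    // dense grad -> opt._apply_dense(grad, v)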

src/TensorFlowNET.Core/WeakKeyDicionary.cs (+0 -438, file deleted)

@@ -1,438 +0,0 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;

namespace Tensorflow
{
public class WeakKeyDictionary<TKey, TValue> : IDictionary<TKey, TValue>
{
private Dictionary<WeakKey, TValue> _internalDictionary;
private object _internalObject = new object();
private bool _finalized;
public WeakKeyDictionary()
{
_internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer());
}
public WeakKeyDictionary(int capacity)
{
_internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer());
}
public WeakKeyDictionary(IEqualityComparer<TKey> comparer)
{
_internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer(comparer));
}
public WeakKeyDictionary(int capacity, IEqualityComparer<TKey> comparer)
{
_internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer(comparer));
}
// FXCop: this is not empty; we need to mark this so we know if a key
// still has an active dictionary at its finalization.
[SuppressMessage("Microsoft.Performance", "CA1821:RemoveEmptyFinalizers")]
~WeakKeyDictionary()
{
_finalized = true;
}
public ICollection<TKey> Keys
{
get
{
List<TKey> list = new List<TKey>();
lock (_internalObject)
{
foreach (WeakKey key in _internalDictionary.Keys)
{
object TKey = key.Target;
if (TKey != null)
{
list.Add((TKey)TKey);
}
}
}
return list;
}
}
public ICollection<TValue> Values
{
get {
lock (_internalObject) {
return _internalDictionary.Values;
}
}
}
public int Count
{
get
{
// Ensure a fairly accurate count.
ScavangeLostKeys();
lock (_internalObject)
{
return _internalDictionary.Count;
}
}
}
public bool IsReadOnly
{
get {
return false;
}
}
[SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")]
public TValue this[TKey key]
{
get {
lock (_internalObject) {
return _internalDictionary[new WeakKey(key)];
}
}
set
{
WeakKey Tkey = new WeakKey(key);
lock (_internalObject)
{
//_internalDictionary[Tkey] = value;
_internalDictionary.Add(Tkey, value);
}
// This looks a bit weird but the purpose of the lost key finder is to execute
// code in some future garbage collection phase so we immediately create some garbage.
new LostKeyFinder(this, Tkey);
}
}
public bool TryGetValue(TKey key, out TValue value)
{
WeakKey tkey = new WeakKey(key);
lock (_internalObject)
{
return _internalDictionary.TryGetValue(tkey, out value);
}
}
[SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")]
public void Add(TKey key, TValue value)
{
WeakKey tkey = new WeakKey(key);
lock (_internalObject)
{
_internalDictionary.Add(tkey, value);
}
// This looks a bit weird but the purpose of the lost key finder is to execute
// code in some future garbage collection phase so we immediately create some garbage.
new LostKeyFinder(this, tkey);
}
public bool ContainsKey(TKey key)
{
return _internalDictionary.ContainsKey(new WeakKey(key));
}
public bool Remove(TKey key)
{
lock (_internalObject)
{
return _internalDictionary.Remove(new WeakKey(key));
}
}
public void Add(KeyValuePair<TKey, TValue> item)
{
Add(item.Key, item.Value);
}
public void Clear()
{
lock (_internalObject)
{
_internalDictionary.Clear();
}
}
public bool Contains(KeyValuePair<TKey, TValue> item)
{
TValue value;
bool result;
lock (_internalObject)
{
result = _internalDictionary.TryGetValue(new WeakKey(item.Key), out value);
}
if (result)
{
return value.Equals(item.Value);
}
else
{
return false;
}
}
public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex)
{
lock (_internalObject)
{
foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary)
{
KeyValuePair<TKey, TValue> kv = new KeyValuePair<TKey, TValue>((TKey)item.Key.Target, item.Value);
array[arrayIndex] = kv;
arrayIndex++;
}
}
}
public bool Remove(KeyValuePair<TKey, TValue> item)
{
WeakKey key = new WeakKey(item.Key);
lock (_internalObject)
{
return _internalDictionary.Remove(key);
}
}
public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator()
{
List<WeakKey> lostKeys = null;
lock (_internalObject)
{
foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary)
{
object TKey = item.Key.Target;
if (TKey != null)
{
yield return new KeyValuePair<TKey, TValue>((TKey)TKey, item.Value);
}
else
{
if (lostKeys == null)
{
lostKeys = new List<WeakKey>();
}
lostKeys.Add(item.Key);
}
}
}
// Recover any lost keys.
if (lostKeys != null)
{
lock (_internalObject)
{
foreach (WeakKey key in lostKeys)
{
_internalDictionary.Remove(key);
}
}
}
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
private void ScavangeLostKeys()
{
List<WeakKey> lostKeys = null;
lock (_internalObject)
{
foreach (WeakKey key in _internalDictionary.Keys)
{
if (!key.IsAlive)
{
if (lostKeys == null)
{
lostKeys = new List<WeakKey>();
}
lostKeys.Add(key);
}
}
}
if (lostKeys != null)
{
lock (_internalObject)
{
foreach (WeakKey key in lostKeys)
{
_internalDictionary.Remove(key);
}
}
}
}

IEnumerator<KeyValuePair<TKey, TValue>> IEnumerable<KeyValuePair<TKey, TValue>>.GetEnumerator()
{
return this.GetEnumerator();
}

private class WeakKey : WeakReference
{
private int _hashCode;
// private GCHandle _gcHandle;
public WeakKey(TKey key)
: base(key, true)
{
_hashCode = key.GetHashCode();
// Keep the key alive until it is explicitly collected
// _gcHandle = GCHandle.Alloc(this);
}
internal void Release()
{
// _gcHandle.Free();
}
public override int GetHashCode()
{
return _hashCode;
}
public override bool Equals(object obj)
{
if (obj == null)
{
return false;
}
if (obj.GetHashCode() != _hashCode)
{
return false;
}
if (obj != this && (!IsAlive || !obj.Equals(Target)))
{
return false;
}
return true;
}
}
private class WeakComparer : IEqualityComparer<WeakKey>
{
private IEqualityComparer<TKey> _comparer;
public WeakComparer()
{
}
public WeakComparer(IEqualityComparer<TKey> comparer)
{
_comparer = comparer;
}
public bool Equals(WeakKey x, WeakKey y)
{
if (x.GetHashCode() != y.GetHashCode())
{
return false;
}
if (object.ReferenceEquals(x, y))
{
return true;
}
object ref1 = x.Target;
if (ref1 == null)
{
return false;
}
object ref2 = y.Target;
if (ref2 == null)
{
return false;
}
if (_comparer != null)
{
return _comparer.Equals((TKey)ref1, (TKey)ref2);
}
else
{
return ref1.Equals(ref2);
}
}
public int GetHashCode(WeakKey obj)
{
return obj.GetHashCode();
}
}
private class LostKeyFinder
{
WeakKeyDictionary<TKey, TValue> _dictionary;
WeakKey _key;
public LostKeyFinder(WeakKeyDictionary<TKey, TValue> dictionary, WeakKey key)
{
_dictionary = dictionary;
_key = key;
}
~LostKeyFinder()
{
if (_dictionary._finalized || _key == null)
{
if (_key != null)
{
_key.Release();
_key = null;
}
return;
}
// if (!_key.IsAlive) {
if (_key.Target == null)
{
lock (_dictionary._internalObject)
{
_dictionary._internalDictionary.Remove(_key);
}
_key.Release();
_key = null;
}
else if (_dictionary._internalDictionary.ContainsKey(_key))
{
GC.ReRegisterForFinalize(this);
}
}
}
}
}

test/TensorFlowNET.UnitTest/Basics/VariableTest.cs (+2 -2)

@@ -2,13 +2,13 @@
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using NumSharp;
 using System.Linq;
-using Tensorflow;
+using Tensorflow.UnitTest;
 using static Tensorflow.Binding;
 
 namespace TensorFlowNET.UnitTest.Basics
 {
     [TestClass]
-    public class VariableTest
+    public class VariableTest : EagerModeTestBase
     {
         [TestMethod]
         public void NewVariable()


test/TensorFlowNET.UnitTest/EagerModeTestBase.cs (+23 -0, new file)

@@ -0,0 +1,23 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using TensorFlowNET.UnitTest;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.UnitTest
+{
+    public class EagerModeTestBase : PythonTest
+    {
+        [TestInitialize]
+        public void TestInit()
+        {
+            tf.enable_eager_execution();
+        }
+
+        [TestCleanup]
+        public void TestClean()
+        {
+        }
+    }
+}
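
Test fixtures opt into eager mode by inheriting this base: TestInit flips the runtime before each test, which is what VariableTest now relies on above. A minimal sketch of a derived test (class name and body illustrative):

    // Illustrative: any fixture deriving from EagerModeTestBase runs eagerly.
    [TestClass]
    public class EagerSmokeTest : EagerModeTestBase
    {
        [TestMethod]
        public void RunsEagerly()
        {
            Assert.IsTrue(tf.context.executing_eagerly());     // set by TestInit
            var sum = tf.add(tf.constant(1), tf.constant(2));  // executes immediately
        }
    }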
