@@ -268,6 +268,16 @@ namespace Tensorflow | |||||
public static Tensor rank(Tensor input, string name = null) | public static Tensor rank(Tensor input, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Rank", name, | |||||
null, | |||||
input); | |||||
return results[0]; | |||||
} | |||||
var _op = tf._op_def_lib._apply_op_helper("Rank", name: name, args: new { input }); | var _op = tf._op_def_lib._apply_op_helper("Rank", name: name, args: new { input }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -567,7 +567,7 @@ namespace Tensorflow | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
if(x is Tensor) | |||||
if(x.rank > -1) | |||||
return constant_op.constant(np.arange(x.rank)); | return constant_op.constant(np.arange(x.rank)); | ||||
var rank = array_ops.rank(x); | var rank = array_ops.rank(x); | ||||
@@ -109,7 +109,7 @@ namespace Tensorflow.Train | |||||
return control_flow_ops.group(new[] { var_update, m_t, v_t }); | return control_flow_ops.group(new[] { var_update, m_t, v_t }); | ||||
} | } | ||||
protected override void _create_slots(RefVariable[] var_list) | |||||
protected override void _create_slots(ResourceVariable[] var_list) | |||||
{ | { | ||||
var first_var = var_list.OrderBy(x => x.Name).First(); | var first_var = var_list.OrderBy(x => x.Name).First(); | ||||
_create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var); | _create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var); | ||||
@@ -107,7 +107,7 @@ namespace Tensorflow | |||||
/// </returns> | /// </returns> | ||||
public Operation minimize(Tensor loss, | public Operation minimize(Tensor loss, | ||||
RefVariable global_step = null, | RefVariable global_step = null, | ||||
List<RefVariable> var_list=null, | |||||
List<ResourceVariable> var_list=null, | |||||
GateGradientType gate_gradients = GateGradientType.GATE_OP, | GateGradientType gate_gradients = GateGradientType.GATE_OP, | ||||
int? aggregation_method=null, | int? aggregation_method=null, | ||||
bool colocate_gradients_with_ops = false, string name=null, Tensor grad_loss=null) | bool colocate_gradients_with_ops = false, string name=null, Tensor grad_loss=null) | ||||
@@ -142,17 +142,17 @@ namespace Tensorflow | |||||
/// <returns> | /// <returns> | ||||
/// An `Operation` that applies the specified gradients. If `global_step` | /// An `Operation` that applies the specified gradients. If `global_step` | ||||
/// was not None, that operation also increments `global_step`.</returns> | /// was not None, that operation also increments `global_step`.</returns> | ||||
public Operation apply_gradients(Tuple<Tensor, RefVariable>[] grads_and_vars, RefVariable global_step = null, string name = null) | |||||
public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_vars, RefVariable global_step = null, string name = null) | |||||
{ | { | ||||
// No DistributionStrategy case. | // No DistributionStrategy case. | ||||
var converted_grads_and_vars = new List<(Tensor, RefVariable, _OptimizableVariable)>(); | |||||
var converted_grads_and_vars = new List<(Tensor, ResourceVariable, _OptimizableVariable)>(); | |||||
foreach (var (g, v) in grads_and_vars) | foreach (var (g, v) in grads_and_vars) | ||||
{ | { | ||||
if(g != null) | if(g != null) | ||||
{ | { | ||||
// Convert the grad to Tensor or IndexedSlices if necessary. | // Convert the grad to Tensor or IndexedSlices if necessary. | ||||
var gR = ops.convert_to_tensor_or_indexed_slices(g); | var gR = ops.convert_to_tensor_or_indexed_slices(g); | ||||
var p = _get_processor(v); | |||||
var p = optimizer._get_processor(v); | |||||
converted_grads_and_vars.Add((gR, v, p)); | converted_grads_and_vars.Add((gR, v, p)); | ||||
} | } | ||||
} | } | ||||
@@ -230,7 +230,7 @@ namespace Tensorflow | |||||
/// silently ignored). | /// silently ignored). | ||||
/// </summary> | /// </summary> | ||||
/// <param name="var_list"></param> | /// <param name="var_list"></param> | ||||
protected virtual void _create_slots(RefVariable[] var_list) | |||||
protected virtual void _create_slots(ResourceVariable[] var_list) | |||||
{ | { | ||||
} | } | ||||
@@ -276,6 +276,12 @@ namespace Tensorflow | |||||
return control_flow_ops.group(update_ops, name_scope); | return control_flow_ops.group(update_ops, name_scope); | ||||
} | } | ||||
public virtual Operation _apply_dense(Tensor grad, ResourceVariable var) | |||||
{ | |||||
var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); | |||||
return gen_training_ops.resource_apply_gradient_descent(var.Handle, alpha, grad, use_locking: _use_locking).op; | |||||
} | |||||
public virtual Operation _apply_dense(Tensor grad, RefVariable var) | public virtual Operation _apply_dense(Tensor grad, RefVariable var) | ||||
{ | { | ||||
var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); | var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); | ||||
@@ -298,6 +304,16 @@ namespace Tensorflow | |||||
return _apply_sparse(gradient_no_duplicate_indices, var); | return _apply_sparse(gradient_no_duplicate_indices, var); | ||||
} | } | ||||
public virtual Operation _apply_sparse_duplicate_indices(IndexedSlices grad, ResourceVariable var) | |||||
{ | |||||
var (summed_values, unique_indices) = _deduplicate_indexed_slices(values: grad.values, indices: grad.indices); | |||||
var gradient_no_duplicate_indices = new IndexedSlices( | |||||
indices: unique_indices, | |||||
values: summed_values, | |||||
dense_shape: grad.dense_shape); | |||||
return _apply_sparse(gradient_no_duplicate_indices, var); | |||||
} | |||||
public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var) | public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var) | ||||
{ | { | ||||
throw new NotImplementedException("_apply_sparse"); | throw new NotImplementedException("_apply_sparse"); | ||||
@@ -344,18 +360,6 @@ namespace Tensorflow | |||||
return non_slot; | return non_slot; | ||||
} | } | ||||
private _OptimizableVariable _get_processor(RefVariable v) | |||||
{ | |||||
if(v is RefVariable) | |||||
{ | |||||
return new _RefVariableProcessor(v); | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException("_get_processor"); | |||||
} | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Compute gradients of `loss` for the variables in `var_list`. | /// Compute gradients of `loss` for the variables in `var_list`. | ||||
/// </summary> | /// </summary> | ||||
@@ -365,8 +369,8 @@ namespace Tensorflow | |||||
/// A list of (gradient, variable) pairs. Variable is always present, but | /// A list of (gradient, variable) pairs. Variable is always present, but | ||||
/// gradient can be `None`. | /// gradient can be `None`. | ||||
/// </returns> | /// </returns> | ||||
public Tuple<Tensor, RefVariable>[] compute_gradients(Tensor loss, | |||||
List<RefVariable> var_list = null, | |||||
public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss, | |||||
List<ResourceVariable> var_list = null, | |||||
int? aggregation_method = null, | int? aggregation_method = null, | ||||
GateGradientType gate_gradients = GateGradientType.GATE_OP, | GateGradientType gate_gradients = GateGradientType.GATE_OP, | ||||
bool colocate_gradients_with_ops = false, | bool colocate_gradients_with_ops = false, | ||||
@@ -374,26 +378,28 @@ namespace Tensorflow | |||||
{ | { | ||||
// Scale loss if using a "mean" loss reduction and multiple replicas. | // Scale loss if using a "mean" loss reduction and multiple replicas. | ||||
loss = _scale_loss(loss); | loss = _scale_loss(loss); | ||||
#pragma warning disable CS0219 // Variable is assigned but its value is never used | |||||
int num_towers = 1; | |||||
#pragma warning restore CS0219 // Variable is assigned but its value is never used | |||||
if(var_list == null) | if(var_list == null) | ||||
{ | { | ||||
var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||||
var vars = ops.get_collection<ResourceVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||||
var tmp = variables.trainable_variables(); | var tmp = variables.trainable_variables(); | ||||
switch (tmp) | switch (tmp) | ||||
{ | { | ||||
case List<RefVariable> values: | |||||
case List<ResourceVariable> values: | |||||
var_list = values.Concat(vars).ToList(); | |||||
break; | |||||
/*case List<RefVariable> values: | |||||
var_list = values.Concat(vars).ToList(); | var_list = values.Concat(vars).ToList(); | ||||
break; | break; | ||||
case List<IVariableV1> values: | case List<IVariableV1> values: | ||||
var_list = values.Select(x => x as RefVariable).Concat(vars).ToList(); | var_list = values.Select(x => x as RefVariable).Concat(vars).ToList(); | ||||
break; | |||||
break;*/ | |||||
default: | |||||
throw new NotImplementedException(""); | |||||
} | } | ||||
} | } | ||||
var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||||
var_list = var_list.Concat(ops.get_collection<ResourceVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||||
var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); | var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); | ||||
var var_refs = processors.Select(x => x.target()).ToArray(); | var var_refs = processors.Select(x => x.target()).ToArray(); | ||||
@@ -406,7 +412,7 @@ namespace Tensorflow | |||||
grads = control_flow_ops.tuple(grads); | grads = control_flow_ops.tuple(grads); | ||||
var grads_and_vars = zip(grads, var_list) | var grads_and_vars = zip(grads, var_list) | ||||
.Select(x => new Tuple<Tensor, RefVariable>(x.Item1, x.Item2)) | |||||
.Select(x => new Tuple<Tensor, ResourceVariable>(x.Item1, x.Item2)) | |||||
.ToArray(); | .ToArray(); | ||||
return grads_and_vars; | return grads_and_vars; | ||||
@@ -59,7 +59,7 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
public static Operation resource_apply_gradient_descent(EagerTensor var, EagerTensor alpha, EagerTensor delta, bool use_locking = false, string name = null) | |||||
public static Operation resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
@@ -79,7 +79,7 @@ namespace Tensorflow | |||||
use_locking | use_locking | ||||
}); | }); | ||||
return _op.outputs[0]; | |||||
return _op; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -24,6 +24,11 @@ namespace Tensorflow | |||||
{ | { | ||||
return new _RefVariableProcessor(v); | return new _RefVariableProcessor(v); | ||||
} | } | ||||
public static _OptimizableVariable _get_processor(ResourceVariable v) | |||||
{ | |||||
return new _DenseResourceVariableProcessor(v); | |||||
} | |||||
} | } | ||||
public class _RefVariableProcessor : _OptimizableVariable | public class _RefVariableProcessor : _OptimizableVariable | ||||
@@ -56,4 +61,35 @@ namespace Tensorflow | |||||
return update_op; | return update_op; | ||||
} | } | ||||
} | } | ||||
public class _DenseResourceVariableProcessor : _OptimizableVariable | |||||
{ | |||||
private ResourceVariable _v; | |||||
public _DenseResourceVariableProcessor(ResourceVariable v) | |||||
{ | |||||
_v = v; | |||||
} | |||||
public Tensor target() | |||||
{ | |||||
return _v.Handle; | |||||
} | |||||
public Operation update_op(Optimizer optimizer, Tensor g) | |||||
{ | |||||
Operation update_op = null; | |||||
if (g.Tag == null) | |||||
{ | |||||
update_op = optimizer._apply_dense(g, _v); | |||||
} | |||||
else if (g.Tag is IndexedSlices) | |||||
{ | |||||
return optimizer._apply_sparse_duplicate_indices(g, _v); | |||||
} | |||||
return update_op; | |||||
} | |||||
} | |||||
} | } |
@@ -1,438 +0,0 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Collections; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics.CodeAnalysis; | |||||
namespace Tensorflow | |||||
{ | |||||
public class WeakKeyDictionary<TKey, TValue> : IDictionary<TKey, TValue> | |||||
{ | |||||
private Dictionary<WeakKey, TValue> _internalDictionary; | |||||
private object _internalObject = new object(); | |||||
private bool _finalized; | |||||
public WeakKeyDictionary() | |||||
{ | |||||
_internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer()); | |||||
} | |||||
public WeakKeyDictionary(int capacity) | |||||
{ | |||||
_internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer()); | |||||
} | |||||
public WeakKeyDictionary(IEqualityComparer<TKey> comparer) | |||||
{ | |||||
_internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer(comparer)); | |||||
} | |||||
public WeakKeyDictionary(int capacity, IEqualityComparer<TKey> comparer) | |||||
{ | |||||
_internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer(comparer)); | |||||
} | |||||
// FXCop: this is not empty; we need to mark this so we know if a key | |||||
// still has an active dictionary at its finalization. | |||||
[SuppressMessage("Microsoft.Performance", "CA1821:RemoveEmptyFinalizers")] | |||||
~WeakKeyDictionary() | |||||
{ | |||||
_finalized = true; | |||||
} | |||||
public ICollection<TKey> Keys | |||||
{ | |||||
get | |||||
{ | |||||
List<TKey> list = new List<TKey>(); | |||||
lock (_internalObject) | |||||
{ | |||||
foreach (WeakKey key in _internalDictionary.Keys) | |||||
{ | |||||
object TKey = key.Target; | |||||
if (TKey != null) | |||||
{ | |||||
list.Add((TKey)TKey); | |||||
} | |||||
} | |||||
} | |||||
return list; | |||||
} | |||||
} | |||||
public ICollection<TValue> Values | |||||
{ | |||||
get { | |||||
lock (_internalObject) { | |||||
return _internalDictionary.Values; | |||||
} | |||||
} | |||||
} | |||||
public int Count | |||||
{ | |||||
get | |||||
{ | |||||
// Ensure a fairly accurate count. | |||||
ScavangeLostKeys(); | |||||
lock (_internalObject) | |||||
{ | |||||
return _internalDictionary.Count; | |||||
} | |||||
} | |||||
} | |||||
public bool IsReadOnly | |||||
{ | |||||
get { | |||||
return false; | |||||
} | |||||
} | |||||
[SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] | |||||
public TValue this[TKey key] | |||||
{ | |||||
get { | |||||
lock (_internalObject) { | |||||
return _internalDictionary[new WeakKey(key)]; | |||||
} | |||||
} | |||||
set | |||||
{ | |||||
WeakKey Tkey = new WeakKey(key); | |||||
lock (_internalObject) | |||||
{ | |||||
//_internalDictionary[Tkey] = value; | |||||
_internalDictionary.Add(Tkey, value); | |||||
} | |||||
// This looks a bit weird but the purpose of the lost key finder is to execute | |||||
// code in some future garbage collection phase so we immediately create some garbage. | |||||
new LostKeyFinder(this, Tkey); | |||||
} | |||||
} | |||||
public bool TryGetValue(TKey key, out TValue value) | |||||
{ | |||||
WeakKey tkey = new WeakKey(key); | |||||
lock (_internalObject) | |||||
{ | |||||
return _internalDictionary.TryGetValue(tkey, out value); | |||||
} | |||||
} | |||||
[SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] | |||||
public void Add(TKey key, TValue value) | |||||
{ | |||||
WeakKey tkey = new WeakKey(key); | |||||
lock (_internalObject) | |||||
{ | |||||
_internalDictionary.Add(tkey, value); | |||||
} | |||||
// This looks a bit weird but the purpose of the lost key finder is to execute | |||||
// code in some future garbage collection phase so we immediately create some garbage. | |||||
new LostKeyFinder(this, tkey); | |||||
} | |||||
public bool ContainsKey(TKey key) | |||||
{ | |||||
return _internalDictionary.ContainsKey(new WeakKey(key)); | |||||
} | |||||
public bool Remove(TKey key) | |||||
{ | |||||
lock (_internalObject) | |||||
{ | |||||
return _internalDictionary.Remove(new WeakKey(key)); | |||||
} | |||||
} | |||||
public void Add(KeyValuePair<TKey, TValue> item) | |||||
{ | |||||
Add(item.Key, item.Value); | |||||
} | |||||
public void Clear() | |||||
{ | |||||
lock (_internalObject) | |||||
{ | |||||
_internalDictionary.Clear(); | |||||
} | |||||
} | |||||
public bool Contains(KeyValuePair<TKey, TValue> item) | |||||
{ | |||||
TValue value; | |||||
bool result; | |||||
lock (_internalObject) | |||||
{ | |||||
result = _internalDictionary.TryGetValue(new WeakKey(item.Key), out value); | |||||
} | |||||
if (result) | |||||
{ | |||||
return value.Equals(item.Value); | |||||
} | |||||
else | |||||
{ | |||||
return false; | |||||
} | |||||
} | |||||
public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||||
{ | |||||
lock (_internalObject) | |||||
{ | |||||
foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary) | |||||
{ | |||||
KeyValuePair<TKey, TValue> kv = new KeyValuePair<TKey, TValue>((TKey)item.Key.Target, item.Value); | |||||
array[arrayIndex] = kv; | |||||
arrayIndex++; | |||||
} | |||||
} | |||||
} | |||||
public bool Remove(KeyValuePair<TKey, TValue> item) | |||||
{ | |||||
WeakKey key = new WeakKey(item.Key); | |||||
lock (_internalObject) | |||||
{ | |||||
return _internalDictionary.Remove(key); | |||||
} | |||||
} | |||||
public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||||
{ | |||||
List<WeakKey> lostKeys = null; | |||||
lock (_internalObject) | |||||
{ | |||||
foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary) | |||||
{ | |||||
object TKey = item.Key.Target; | |||||
if (TKey != null) | |||||
{ | |||||
yield return new KeyValuePair<TKey, TValue>((TKey)TKey, item.Value); | |||||
} | |||||
else | |||||
{ | |||||
if (lostKeys == null) | |||||
{ | |||||
lostKeys = new List<WeakKey>(); | |||||
} | |||||
lostKeys.Add(item.Key); | |||||
} | |||||
} | |||||
} | |||||
// Recover any lost keys. | |||||
if (lostKeys != null) | |||||
{ | |||||
lock (_internalObject) | |||||
{ | |||||
foreach (WeakKey key in lostKeys) | |||||
{ | |||||
_internalDictionary.Remove(key); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
IEnumerator IEnumerable.GetEnumerator() | |||||
{ | |||||
return GetEnumerator(); | |||||
} | |||||
private void ScavangeLostKeys() | |||||
{ | |||||
List<WeakKey> lostKeys = null; | |||||
lock (_internalObject) | |||||
{ | |||||
foreach (WeakKey key in _internalDictionary.Keys) | |||||
{ | |||||
if (!key.IsAlive) | |||||
{ | |||||
if (lostKeys == null) | |||||
{ | |||||
lostKeys = new List<WeakKey>(); | |||||
} | |||||
lostKeys.Add(key); | |||||
} | |||||
} | |||||
} | |||||
if (lostKeys != null) | |||||
{ | |||||
lock (_internalObject) | |||||
{ | |||||
foreach (WeakKey key in lostKeys) | |||||
{ | |||||
_internalDictionary.Remove(key); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
IEnumerator<KeyValuePair<TKey, TValue>> IEnumerable<KeyValuePair<TKey, TValue>>.GetEnumerator() | |||||
{ | |||||
return this.GetEnumerator(); | |||||
} | |||||
private class WeakKey : WeakReference | |||||
{ | |||||
private int _hashCode; | |||||
// private GCHandle _gcHandle; | |||||
public WeakKey(TKey key) | |||||
: base(key, true) | |||||
{ | |||||
_hashCode = key.GetHashCode(); | |||||
// Keep the key alive until it is explicitly collected | |||||
// _gcHandle = GCHandle.Alloc(this); | |||||
} | |||||
internal void Release() | |||||
{ | |||||
// _gcHandle.Free(); | |||||
} | |||||
public override int GetHashCode() | |||||
{ | |||||
return _hashCode; | |||||
} | |||||
public override bool Equals(object obj) | |||||
{ | |||||
if (obj == null) | |||||
{ | |||||
return false; | |||||
} | |||||
if (obj.GetHashCode() != _hashCode) | |||||
{ | |||||
return false; | |||||
} | |||||
if (obj != this && (!IsAlive || !obj.Equals(Target))) | |||||
{ | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
} | |||||
private class WeakComparer : IEqualityComparer<WeakKey> | |||||
{ | |||||
private IEqualityComparer<TKey> _comparer; | |||||
public WeakComparer() | |||||
{ | |||||
} | |||||
public WeakComparer(IEqualityComparer<TKey> comparer) | |||||
{ | |||||
_comparer = comparer; | |||||
} | |||||
public bool Equals(WeakKey x, WeakKey y) | |||||
{ | |||||
if (x.GetHashCode() != y.GetHashCode()) | |||||
{ | |||||
return false; | |||||
} | |||||
if (object.ReferenceEquals(x, y)) | |||||
{ | |||||
return true; | |||||
} | |||||
object ref1 = x.Target; | |||||
if (ref1 == null) | |||||
{ | |||||
return false; | |||||
} | |||||
object ref2 = y.Target; | |||||
if (ref2 == null) | |||||
{ | |||||
return false; | |||||
} | |||||
if (_comparer != null) | |||||
{ | |||||
return _comparer.Equals((TKey)ref1, (TKey)ref2); | |||||
} | |||||
else | |||||
{ | |||||
return ref1.Equals(ref2); | |||||
} | |||||
} | |||||
public int GetHashCode(WeakKey obj) | |||||
{ | |||||
return obj.GetHashCode(); | |||||
} | |||||
} | |||||
private class LostKeyFinder | |||||
{ | |||||
WeakKeyDictionary<TKey, TValue> _dictionary; | |||||
WeakKey _key; | |||||
public LostKeyFinder(WeakKeyDictionary<TKey, TValue> dictionary, WeakKey key) | |||||
{ | |||||
_dictionary = dictionary; | |||||
_key = key; | |||||
} | |||||
~LostKeyFinder() | |||||
{ | |||||
if (_dictionary._finalized || _key == null) | |||||
{ | |||||
if (_key != null) | |||||
{ | |||||
_key.Release(); | |||||
_key = null; | |||||
} | |||||
return; | |||||
} | |||||
// if (!_key.IsAlive) { | |||||
if (_key.Target == null) | |||||
{ | |||||
lock (_dictionary._internalObject) | |||||
{ | |||||
_dictionary._internalDictionary.Remove(_key); | |||||
} | |||||
_key.Release(); | |||||
_key = null; | |||||
} | |||||
else if (_dictionary._internalDictionary.ContainsKey(_key)) | |||||
{ | |||||
GC.ReRegisterForFinalize(this); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
@@ -2,13 +2,13 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using NumSharp; | using NumSharp; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow; | |||||
using Tensorflow.UnitTest; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.Basics | namespace TensorFlowNET.UnitTest.Basics | ||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class VariableTest | |||||
public class VariableTest : EagerModeTestBase | |||||
{ | { | ||||
[TestMethod] | [TestMethod] | ||||
public void NewVariable() | public void NewVariable() | ||||
@@ -0,0 +1,23 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using TensorFlowNET.UnitTest; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.UnitTest | |||||
{ | |||||
public class EagerModeTestBase : PythonTest | |||||
{ | |||||
[TestInitialize] | |||||
public void TestInit() | |||||
{ | |||||
tf.enable_eager_execution(); | |||||
} | |||||
[TestCleanup] | |||||
public void TestClean() | |||||
{ | |||||
} | |||||
} | |||||
} |