@@ -11,10 +11,26 @@ namespace Tensorflow.Framework | |||||
{ | { | ||||
// Backing storage for the three tensors that make up the slices.
private Tensor _values;
private Tensor _indices;
private Tensor _dense_shape;

/// <summary>Tensor holding the slice values.</summary>
public Tensor values
{
    get { return _values; }
}

/// <summary>Tensor holding the slice indices.</summary>
public Tensor indices
{
    get { return _indices; }
}

/// <summary>Dense shape tensor; null when none was supplied to the constructor.</summary>
public Tensor dense_shape
{
    get { return _dense_shape; }
}

// The remaining members simply forward to the values tensor.

/// <summary>Name of the values tensor.</summary>
public string name
{
    get { return _values.name; }
}

/// <summary>Device reported by the values tensor.</summary>
public string device
{
    get { return _values.Device; }
}

/// <summary>Operation reported by the values tensor.</summary>
public Operation op
{
    get { return _values.op; }
}

/// <summary>Data type of the values tensor.</summary>
public TF_DataType dtype
{
    get { return _values.dtype; }
}

/// <summary>Graph reported by the values tensor.</summary>
public Graph graph
{
    get { return _values.graph; }
}
/// <summary>
/// Stores the supplied tensors; no validation or copying is performed.
/// </summary>
/// <param name="values">Tensor exposed through <c>values</c>.</param>
/// <param name="indices">Tensor exposed through <c>indices</c>.</param>
/// <param name="dense_shape">Optional tensor exposed through <c>dense_shape</c>.</param>
public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null)
{
    (_values, _indices, _dense_shape) = (values, indices, dense_shape);
}
public static implicit operator Tensor(IndexedSlices indexedSlices) | public static implicit operator Tensor(IndexedSlices indexedSlices) | ||||
@@ -83,13 +83,13 @@ namespace Tensorflow.Gradients | |||||
new Tensor[] { non_neg_concat_dim, tf.constant(0) }, | new Tensor[] { non_neg_concat_dim, tf.constant(0) }, | ||||
new Tensor[] { tf.constant(1), tf.constant(-1) }); | new Tensor[] { tf.constant(1), tf.constant(-1) }); | ||||
var squeeze_sizes = array_ops.squeeze(slice); | var squeeze_sizes = array_ops.squeeze(slice); | ||||
out_grads = gen_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList(); | |||||
out_grads = gen_array_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList(); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
var offset = gen_ops.concat_offset(non_neg_concat_dim, sizes); | |||||
var offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes); | |||||
foreach (var (begin, size) in zip(offset, sizes)) | foreach (var (begin, size) in zip(offset, sizes)) | ||||
out_grads.Add(gen_ops.slice(grad, begin, size)); | |||||
out_grads.Add(gen_array_ops.slice(grad, begin, size)); | |||||
} | } | ||||
return (end_value_index <= dim_index ? | return (end_value_index <= dim_index ? | ||||
@@ -129,7 +129,7 @@ namespace Tensorflow.Gradients | |||||
if (fully_known) | if (fully_known) | ||||
return sizes; | return sizes; | ||||
else | else | ||||
return gen_ops.shape_n(inputs); | |||||
return gen_array_ops.shape_n(inputs); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -93,10 +93,7 @@ namespace Tensorflow | |||||
{ | { | ||||
// generate gradient subgraph for op. | // generate gradient subgraph for op. | ||||
var op = queue.Dequeue(); | var op = queue.Dequeue(); | ||||
if(op.name == "embedding/ExpandDims") | |||||
{ | |||||
} | |||||
_maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); | _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); | ||||
//if (loop_state != null) | //if (loop_state != null) | ||||
//loop_state.EnterGradWhileContext(op, before: true); | //loop_state.EnterGradWhileContext(op, before: true); | ||||
@@ -311,9 +308,10 @@ namespace Tensorflow | |||||
// Aggregate multiple gradients, and convert [] to None. | // Aggregate multiple gradients, and convert [] to None. | ||||
if (out_grad.Count > 0) | if (out_grad.Count > 0) | ||||
{ | { | ||||
string used = ""; | |||||
if (out_grad.Count < 2) | if (out_grad.Count < 2) | ||||
{ | { | ||||
string used = "nop"; | |||||
used = "nop"; | |||||
if (out_grad.Count == 0) | if (out_grad.Count == 0) | ||||
{ | { | ||||
throw new ValueError("_AggregatedGrads out_grad.Length == 0"); | throw new ValueError("_AggregatedGrads out_grad.Length == 0"); | ||||
@@ -321,6 +319,11 @@ namespace Tensorflow | |||||
return_grads[i] = out_grad[0]; | return_grads[i] = out_grad[0]; | ||||
} | } | ||||
else | |||||
{ | |||||
used = "add_n"; | |||||
out_grads[i] = new List<Tensor> { _MultiDeviceAddN(out_grad.ToArray(), gradient_uid) }; | |||||
} | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -331,6 +334,38 @@ namespace Tensorflow | |||||
return return_grads; | return return_grads; | ||||
} | } | ||||
/// <summary>
/// Sums a list of tensors that may be placed on different devices.
/// Tensors sharing a device string are added together first, then the
/// per-device partial sums are added into a single result.
/// </summary>
/// <param name="tensor_list">Tensors to add together.</param>
/// <param name="gradient_uid">Gradient identifier forwarded to the colocation helper.</param>
/// <returns>The <c>math_ops.add_n</c> of all per-device partial sums.</returns>
private static Tensor _MultiDeviceAddN(Tensor[] tensor_list, string gradient_uid)
{
    // Bucket the inputs by the device string each tensor reports.
    var buckets = new Dictionary<string, List<Tensor>>();
    foreach (var t in tensor_list)
    {
        var device = t.Device;
        if (!buckets.TryGetValue(device, out var same_device))
        {
            same_device = new List<Tensor>();
            buckets[device] = same_device;
        }
        same_device.Add(t);
    }

    // Reduce each bucket on its own, then combine the partial results.
    var partial_sums = new List<Tensor>();
    foreach (var group in buckets.Values)
    {
        // NOTE(review): whatever this call returns is discarded, so it is
        // unclear from here whether the add_n below is actually colocated
        // with group[0].op - confirm against the helper's implementation.
        ops._colocate_with_for_gradient(group[0].op, gradient_uid, ignore_existing: true);
        partial_sums.Add(math_ops.add_n(group.ToArray()));
    }

    return math_ops.add_n(partial_sums.ToArray());
}
/// <summary> | /// <summary> | ||||
/// The set of ops that terminate the gradient computation. | /// The set of ops that terminate the gradient computation. | ||||
/// </summary> | /// </summary> | ||||
@@ -276,6 +276,9 @@ namespace Tensorflow | |||||
}); | }); | ||||
} | } | ||||
/// <summary>
/// Thin wrapper that forwards to <c>gen_array_ops.unique</c>.
/// </summary>
/// <param name="x">Input tensor.</param>
/// <param name="out_idx">Data type requested for the index output.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The tuple produced by <c>gen_array_ops.unique</c>.</returns>
public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null)
{
    return gen_array_ops.unique(x, out_idx: out_idx, name: name);
}
public static Tensor where(Tensor condition, object x = null, object y = null, string name = null) | public static Tensor where(Tensor condition, object x = null, object y = null, string name = null) | ||||
{ | { | ||||
if( x == null && y == null) | if( x == null && y == null) | ||||
@@ -26,6 +26,13 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
/// <summary>
/// Creates a <c>ConcatOffset</c> op and returns all of its outputs.
/// </summary>
/// <param name="concat_dim">Tensor passed as the op's <c>concat_dim</c> argument.</param>
/// <param name="shape">Tensors passed as the op's <c>shape</c> argument.</param>
/// <param name="name">Optional op name.</param>
/// <returns>Every output tensor of the created op.</returns>
public static Tensor[] concat_offset(Tensor concat_dim, Tensor[] shape, string name = null)
    => _op_def_lib._apply_op_helper("ConcatOffset", name: name, args: new { concat_dim, shape }).outputs;
/// <summary> | /// <summary> | ||||
/// Returns a diagonal tensor with a given diagonal values. | /// Returns a diagonal tensor with a given diagonal values. | ||||
/// </summary> | /// </summary> | ||||
@@ -205,6 +212,21 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
/// <summary>
/// Finds unique elements in a 1-D tensor by creating a <c>Unique</c> op.
/// </summary>
/// <param name="x">Input tensor.</param>
/// <param name="out_idx">Data type requested for the index output.</param>
/// <param name="name">Optional op name.</param>
/// <returns>
/// The op's first and second outputs, in that order. For the TensorFlow
/// <c>Unique</c> op these are the unique values (<c>y</c>) and, for every
/// element of <paramref name="x"/>, its position in <c>y</c> (<c>idx</c>).
/// </returns>
public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null)
{
    var _op = _op_def_lib._apply_op_helper("Unique", name, new { x, out_idx });
    // Previously this threw NotImplementedException after the op had already
    // been built; hand back the two outputs as a (y, idx) pair instead.
    var outputs = _op.outputs;
    return (outputs[0], outputs[1]);
}
public static Tensor where() | public static Tensor where() | ||||
{ | { | ||||
throw new NotImplementedException("where"); | throw new NotImplementedException("where"); | ||||
@@ -271,6 +293,26 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
/// <summary>
/// Creates a <c>Slice</c> op over <paramref name="input"/> and returns its first output.
/// </summary>
/// <param name="input">Tensor to slice.</param>
/// <param name="begin">Tensor passed as the op's <c>begin</c> argument.</param>
/// <param name="size">Tensor passed as the op's <c>size</c> argument.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The first output of the created op.</returns>
public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null)
    => _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }).outputs[0];
/// <summary>
/// Creates a <c>Split</c> op and returns all of its outputs.
/// </summary>
/// <param name="axis">Tensor passed to the op as <c>split_dim</c>.</param>
/// <param name="value">Tensor to split.</param>
/// <param name="num_split">Value passed as the op's <c>num_split</c> argument.</param>
/// <param name="name">Optional op name.</param>
/// <returns>Every output tensor of the created op.</returns>
public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null)
    => _op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }).outputs;
public static Tensor tile(Tensor input, Tensor multiples, string name = null) | public static Tensor tile(Tensor input, Tensor multiples, string name = null) | ||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples }); | var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples }); | ||||
@@ -16,6 +16,19 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
/// <summary>
/// Creates an <c>AddN</c> op over the given tensors and returns its first output.
/// </summary>
/// <param name="inputs">Tensors passed as the op's <c>inputs</c> argument.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The first output of the created op.</returns>
public static Tensor add_n(Tensor[] inputs, string name = null)
    => _op_def_lib._apply_op_helper("AddN", name, args: new { inputs }).outputs[0];
/// <summary> | /// <summary> | ||||
/// Returns the index with the largest value across dimensions of a tensor. | /// Returns the index with the largest value across dimensions of a tensor. | ||||
/// </summary> | /// </summary> | ||||
@@ -198,6 +211,20 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
/// <summary>
/// Creates an <c>UnsortedSegmentSum</c> op and returns its first output.
/// </summary>
/// <param name="data">Tensor passed as the op's <c>data</c> argument.</param>
/// <param name="segment_ids">Tensor passed as the op's <c>segment_ids</c> argument.</param>
/// <param name="num_segments">Tensor passed as the op's <c>num_segments</c> argument.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The first output of the created op.</returns>
public static Tensor unsorted_segment_sum(Tensor data, Tensor segment_ids, Tensor num_segments, string name = null)
    => _op_def_lib._apply_op_helper("UnsortedSegmentSum", name, new { data, segment_ids, num_segments }).outputs[0];
public static Tensor tan(Tensor x, string name = null) | public static Tensor tan(Tensor x, string name = null) | ||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Tan", name, args: new { x }); | var _op = _op_def_lib._apply_op_helper("Tan", name, args: new { x }); | ||||
@@ -44,8 +44,8 @@ namespace Tensorflow | |||||
return array_ops.identity(values, name: name); | return array_ops.identity(values, name: name); | ||||
return values; | return values; | ||||
} | } | ||||
throw new NotImplementedException("math_ops add_n n > 1"); | |||||
// return gen_math_ops.add_n(inputs, name: name); | |||||
return gen_math_ops.add_n(inputs, name: name); | |||||
} | } | ||||
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) | public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) | ||||
@@ -126,6 +126,9 @@ namespace Tensorflow | |||||
public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
=> gen_math_ops.equal(x, y, name: name); | => gen_math_ops.equal(x, y, name: name); | ||||
/// <summary>
/// Thin wrapper that forwards to <c>gen_math_ops.sqrt</c>.
/// </summary>
/// <param name="x">Input tensor.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The tensor returned by <c>gen_math_ops.sqrt</c>.</returns>
public static Tensor sqrt(Tensor x, string name = null)
{
    return gen_math_ops.sqrt(x, name: name);
}
public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
=> gen_math_ops.mul(x, y, name: name); | => gen_math_ops.mul(x, y, name: name); | ||||
@@ -319,6 +322,17 @@ namespace Tensorflow | |||||
return _may_reduce_to_scalar(keepdims, axis, min); | return _may_reduce_to_scalar(keepdims, axis, min); | ||||
} | } | ||||
/// <summary>
/// Thin wrapper that forwards to <c>gen_math_ops.unsorted_segment_sum</c>.
/// </summary>
/// <param name="data">Tensor whose segments are summed.</param>
/// <param name="segment_ids">Tensor of segment identifiers.</param>
/// <param name="num_segments">Tensor giving the number of segments.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The tensor returned by <c>gen_math_ops.unsorted_segment_sum</c>.</returns>
public static Tensor unsorted_segment_sum(Tensor data, Tensor segment_ids, Tensor num_segments, string name = null)
{
    return gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments, name: name);
}
/// <summary> | /// <summary> | ||||
/// Casts a tensor to type `int32`. | /// Casts a tensor to type `int32`. | ||||
/// </summary> | /// </summary> | ||||
@@ -5,10 +5,10 @@ | |||||
<AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
<RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
<TargetTensorFlow>1.14.0</TargetTensorFlow> | <TargetTensorFlow>1.14.0</TargetTensorFlow> | ||||
<Version>0.8.1</Version> | |||||
<Version>0.8.2</Version> | |||||
<Authors>Haiping Chen</Authors> | <Authors>Haiping Chen</Authors> | ||||
<Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||||
<GeneratePackageOnBuild>false</GeneratePackageOnBuild> | |||||
<Copyright>Apache 2.0</Copyright> | <Copyright>Apache 2.0</Copyright> | ||||
<RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | <RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | ||||
<RepositoryType>git</RepositoryType> | <RepositoryType>git</RepositoryType> | ||||
@@ -17,14 +17,15 @@ | |||||
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | ||||
<Description>Google's TensorFlow full binding in .NET Standard. | <Description>Google's TensorFlow full binding in .NET Standard. | ||||
Docs: https://tensorflownet.readthedocs.io</Description> | Docs: https://tensorflownet.readthedocs.io</Description> | ||||
<AssemblyVersion>0.8.1.0</AssemblyVersion> | |||||
<AssemblyVersion>0.8.2.0</AssemblyVersion> | |||||
<PackageReleaseNotes>Changes since v0.8: | <PackageReleaseNotes>Changes since v0.8: | ||||
1. Remove global static graph instance. | 1. Remove global static graph instance. | ||||
2. Provide custom gradient function. | 2. Provide custom gradient function. | ||||
3. Add gradient function for Conv2D.</PackageReleaseNotes> | |||||
3. Add gradient function for Conv2D. | |||||
4. Fix bug for Transfer Learning example.</PackageReleaseNotes> | |||||
<LangVersion>7.2</LangVersion> | <LangVersion>7.2</LangVersion> | ||||
<FileVersion>0.8.1.0</FileVersion> | |||||
<FileVersion>0.8.2.0</FileVersion> | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
@@ -42,6 +43,10 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
<None Remove="runtimes\**" /> | <None Remove="runtimes\**" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | |||||
<Compile Remove="Operations\gen_ops.cs" /> | |||||
</ItemGroup> | |||||
<ItemGroup> | <ItemGroup> | ||||
<None Remove="Protobuf\README.md" /> | <None Remove="Protobuf\README.md" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -1,6 +1,8 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Framework; | |||||
using static Tensorflow.Python; | |||||
namespace Tensorflow.Train | namespace Tensorflow.Train | ||||
{ | { | ||||
@@ -10,9 +12,10 @@ namespace Tensorflow.Train | |||||
/// </summary> | /// </summary> | ||||
public class AdamOptimizer : Optimizer | public class AdamOptimizer : Optimizer | ||||
{ | { | ||||
private float _beta1; | |||||
private float _beta2; | |||||
private float _epsilon; | |||||
float _beta1; | |||||
float _beta2; | |||||
float _epsilon; | |||||
Tensor _lr_t, _beta1_t, _beta2_t, _epsilon_t; | |||||
public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam") | public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam") | ||||
: base(learning_rate, use_locking, name) | : base(learning_rate, use_locking, name) | ||||
@@ -21,5 +24,51 @@ namespace Tensorflow.Train | |||||
_beta2 = beta2; | _beta2 = beta2; | ||||
_epsilon = epsilon; | _epsilon = epsilon; | ||||
} | } | ||||
/// <summary>
/// Applies a sparse gradient to <paramref name="var"/> by delegating to
/// <c>_apply_sparse_shared</c> with <c>state_ops.scatter_add</c> as the scatter function.
/// </summary>
/// <param name="grad">Sparse gradient; its values and indices are forwarded.</param>
/// <param name="var">Variable to update.</param>
/// <returns>The operation returned by <c>_apply_sparse_shared</c>.</returns>
public override Operation _apply_sparse(IndexedSlices grad, RefVariable var)
{
    Func<RefVariable, Tensor, Tensor, Tensor> scatter =
        (target, idx, updates) => state_ops.scatter_add(target, idx, updates, use_locking: _use_locking);

    return _apply_sparse_shared(grad.values, var, grad.indices, scatter);
}
/// <summary>
/// Builds the ops that update <paramref name="var"/> and its "m" / "v" slots
/// from a sparse gradient, using the supplied scatter function to add the
/// gradient contributions at <paramref name="indices"/>.
/// </summary>
/// <param name="grad">Gradient values.</param>
/// <param name="var">Variable to update.</param>
/// <param name="indices">Indices the gradient values apply to.</param>
/// <param name="scatter_add">Function used to add values into a slot at the given indices.</param>
/// <returns>A grouped operation covering the variable update and both slot updates.</returns>
private Operation _apply_sparse_shared(Tensor grad, RefVariable var, Tensor indices, Func<RefVariable, Tensor, Tensor, Tensor> scatter_add)
{
    var (beta1_acc, beta2_acc) = _get_beta_accumulators();

    // Every hyper-parameter is cast to the variable's base dtype.
    var base_dtype = var.dtype.as_base_dtype();
    Tensor beta1_power = math_ops.cast(beta1_acc, base_dtype);
    Tensor beta2_power = math_ops.cast(beta2_acc, base_dtype);
    var lr_t = math_ops.cast(_lr_t, base_dtype);
    var beta1_t = math_ops.cast(_beta1_t, base_dtype);
    var beta2_t = math_ops.cast(_beta2_t, base_dtype);
    var epsilon_t = math_ops.cast(_epsilon_t, base_dtype);

    // Step size: lr_t * sqrt(1 - beta2_power) / (1 - beta1_power).
    var beta2_correction = math_ops.sqrt(1 - beta2_power);
    var lr = (lr_t * beta2_correction / (1 - beta1_power));

    // "m" slot: scale by beta1_t, then scatter-add grad * (1 - beta1_t).
    var m = get_slot(var, "m");
    var m_increment = grad * (1 - beta1_t);
    var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking);
    // The scatter-add is created under a control dependency on the assign above.
    with(ops.control_dependencies(new[] { m_t }), delegate
    {
        m_t = scatter_add(m, indices, m_increment);
    });

    // "v" slot: scale by beta2_t, then scatter-add grad^2 * (1 - beta2_t).
    var v = get_slot(var, "v");
    var v_increment = (grad * grad) * (1 - beta2_t);
    var v_t = state_ops.assign(v, v * beta2_t, use_locking: _use_locking);
    with(ops.control_dependencies(new[] { v_t }), delegate
    {
        v_t = scatter_add(v, indices, v_increment);
    });

    // var -= lr * m_t / (sqrt(v_t) + epsilon_t)
    var v_sqrt = math_ops.sqrt(v_t);
    var var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking: _use_locking);

    return control_flow_ops.group(new[] { var_update, m_t, v_t });
}
/// <summary>
/// Looks up the "beta1_power" and "beta2_power" non-slot variables for the default graph.
/// </summary>
/// <returns>The two variables, beta1 first; an entry is null when no such variable is registered.</returns>
private (RefVariable, RefVariable) _get_beta_accumulators()
{
    ops.init_scope();
    var default_graph = ops.get_default_graph();

    var beta1_power = _get_non_slot_variable("beta1_power", graph: default_graph);
    var beta2_power = _get_non_slot_variable("beta2_power", graph: default_graph);
    return (beta1_power, beta2_power);
}
} | } | ||||
} | } |
@@ -2,6 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Framework; | |||||
using static Tensorflow.Python; | using static Tensorflow.Python; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -23,8 +24,8 @@ namespace Tensorflow | |||||
public float LearningRate { get; set; } | public float LearningRate { get; set; } | ||||
public Tensor LearningRateTensor { get; set; } | public Tensor LearningRateTensor { get; set; } | ||||
public bool _use_locking; | public bool _use_locking; | ||||
public Dictionary<string, object> _slots; | |||||
public Dictionary<string, object> _non_slot_dict; | |||||
public Dictionary<string, Dictionary<string, RefVariable>> _slots; | |||||
public Dictionary<string, RefVariable> _non_slot_dict; | |||||
public Dictionary<string, object> _deferred_slot_restorations; | public Dictionary<string, object> _deferred_slot_restorations; | ||||
public Optimizer(float learning_rate, bool use_locking, string name = null) | public Optimizer(float learning_rate, bool use_locking, string name = null) | ||||
@@ -36,8 +37,8 @@ namespace Tensorflow | |||||
_use_locking = use_locking; | _use_locking = use_locking; | ||||
LearningRate = learning_rate; | LearningRate = learning_rate; | ||||
// Dictionary of slots. | // Dictionary of slots. | ||||
_slots = new Dictionary<string, object>(); | |||||
_non_slot_dict = new Dictionary<string, object>(); | |||||
_slots = new Dictionary<string, Dictionary<string, RefVariable>>(); | |||||
_non_slot_dict = new Dictionary<string, RefVariable>(); | |||||
_deferred_slot_restorations = new Dictionary<string, object>(); | _deferred_slot_restorations = new Dictionary<string, object>(); | ||||
} | } | ||||
@@ -110,7 +111,7 @@ namespace Tensorflow | |||||
public Operation apply_gradients(Tuple<Tensor, RefVariable>[] grads_and_vars, RefVariable global_step = null, string name = null) | public Operation apply_gradients(Tuple<Tensor, RefVariable>[] grads_and_vars, RefVariable global_step = null, string name = null) | ||||
{ | { | ||||
// No DistributionStrategy case. | // No DistributionStrategy case. | ||||
var converted_grads_and_vars = new List<Tuple<Tensor, RefVariable, _OptimizableVariable>>(); | |||||
var converted_grads_and_vars = new List<(Tensor, RefVariable, _OptimizableVariable)>(); | |||||
foreach (var (g, v) in grads_and_vars) | foreach (var (g, v) in grads_and_vars) | ||||
{ | { | ||||
if(g != null) | if(g != null) | ||||
@@ -118,7 +119,7 @@ namespace Tensorflow | |||||
// Convert the grad to Tensor or IndexedSlices if necessary. | // Convert the grad to Tensor or IndexedSlices if necessary. | ||||
var gR = ops.convert_to_tensor_or_indexed_slices(g); | var gR = ops.convert_to_tensor_or_indexed_slices(g); | ||||
var p = _get_processor(v); | var p = _get_processor(v); | ||||
converted_grads_and_vars.Add(new Tuple<Tensor, RefVariable, _OptimizableVariable>(gR, v, p)); | |||||
converted_grads_and_vars.Add((gR, v, p)); | |||||
} | } | ||||
} | } | ||||
@@ -143,7 +144,8 @@ namespace Tensorflow | |||||
var scope_name = var.op.name; | var scope_name = var.op.name; | ||||
with(ops.name_scope("update_" + scope_name), scope2 => | with(ops.name_scope("update_" + scope_name), scope2 => | ||||
{ | { | ||||
update_ops.Add(processor.update_op(this, grad)); | |||||
var op = processor.update_op(this, grad); | |||||
update_ops.Add(op); | |||||
}); | }); | ||||
} | } | ||||
@@ -201,11 +203,69 @@ namespace Tensorflow | |||||
return gen_training_ops.apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op; | return gen_training_ops.apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op; | ||||
} | } | ||||
/// <summary>
/// Applies a sparse gradient whose indices may repeat: the gradient is first
/// de-duplicated via <c>_deduplicate_indexed_slices</c>, then passed to <c>_apply_sparse</c>.
/// </summary>
/// <param name="grad">Sparse gradient, possibly with repeated indices.</param>
/// <param name="var">Variable to update.</param>
/// <returns>The operation returned by <c>_apply_sparse</c>.</returns>
public virtual Operation _apply_sparse_duplicate_indices(IndexedSlices grad, RefVariable var)
{
    var deduplicated = _deduplicate_indexed_slices(values: grad.values, indices: grad.indices);
    Tensor summed_values = deduplicated.Item1;
    Tensor unique_indices = deduplicated.Item2;

    // Constructor order is (values, indices, dense_shape).
    var without_duplicates = new IndexedSlices(summed_values, unique_indices, grad.dense_shape);
    return _apply_sparse(without_duplicates, var);
}
/// <summary>
/// Base implementation of the sparse update; always throws. Subclasses that
/// support sparse gradients override this method.
/// </summary>
/// <param name="grad">Sparse gradient (unused here).</param>
/// <param name="var">Variable to update (unused here).</param>
/// <returns>Never returns.</returns>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var)
    => throw new NotImplementedException("_apply_sparse");
/// <summary>
/// Collapses repeated indices: computes the unique indices with
/// <c>array_ops.unique</c> and sums the values belonging to each one with
/// <c>math_ops.unsorted_segment_sum</c>.
/// </summary>
/// <param name="values">Values associated with <paramref name="indices"/>.</param>
/// <param name="indices">Indices that may contain repeats.</param>
/// <returns>A pair of (summed values, unique indices).</returns>
public virtual (Tensor, Tensor) _deduplicate_indexed_slices(Tensor values, Tensor indices)
{
    var unique_result = array_ops.unique(indices);
    Tensor unique_indices = unique_result.Item1;
    Tensor positions = unique_result.Item2;

    // The number of segments is the first entry of the unique indices' shape.
    var segment_count = array_ops.shape(unique_indices)[0];
    var summed_values = math_ops.unsorted_segment_sum(values, positions, segment_count);

    return (summed_values, unique_indices);
}
public virtual void _prepare() | public virtual void _prepare() | ||||
{ | { | ||||
} | } | ||||
/// <summary>
/// Returns the slot called <paramref name="name"/> that is registered for <paramref name="var"/>.
/// </summary>
/// <param name="var">Variable whose slot is requested.</param>
/// <param name="name">Slot name.</param>
/// <returns>The slot variable, or null when no slot of that name exists for the variable.</returns>
protected RefVariable get_slot(RefVariable var, string name)
{
    Dictionary<string, RefVariable> slots_for_name;
    if (!_slots.TryGetValue(name, out slots_for_name) || slots_for_name == null)
        return null;

    RefVariable slot;
    if (slots_for_name.TryGetValue(_var_key(var), out slot))
        return slot;

    return null;
}
/// <summary>
/// Builds the dictionary key for a variable: the graph key of the variable's
/// op, a dot, and the op name.
/// </summary>
/// <param name="var">Variable to build the key for.</param>
/// <returns>Key of the form "graph_key.op_name".</returns>
private string _var_key(RefVariable var)
{
    var owner = var.op;
    return $"{owner.graph.graph_key}.{owner.name}";
}
/// <summary>
/// Looks up a non-slot variable by name for a graph.
/// </summary>
/// <param name="name">Name the variable was registered under.</param>
/// <param name="graph">
/// Graph whose key scopes the lookup. When null (the default) the current
/// default graph is used; previously a null graph caused a
/// NullReferenceException.
/// </param>
/// <returns>The registered variable, or null when none is found.</returns>
protected RefVariable _get_non_slot_variable(string name, Graph graph = null)
{
    // The signature advertises null as a valid default, so resolve it
    // instead of dereferencing it.
    var lookup_graph = graph ?? ops.get_default_graph();
    var key = $"{lookup_graph.graph_key}.{name}";

    RefVariable non_slot;
    return _non_slot_dict.TryGetValue(key, out non_slot) ? non_slot : null;
}
private _OptimizableVariable _get_processor(RefVariable v) | private _OptimizableVariable _get_processor(RefVariable v) | ||||
{ | { | ||||
if(v is RefVariable) | if(v is RefVariable) | ||||
@@ -1,6 +1,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Framework; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -32,5 +33,12 @@ namespace Tensorflow | |||||
return update_op; | return update_op; | ||||
} | } | ||||
/// <summary>
/// Builds the update op for a sparse (<see cref="IndexedSlices"/>) gradient.
/// </summary>
/// <param name="optimizer">Optimizer that creates the update.</param>
/// <param name="g">Sparse gradient for the wrapped variable.</param>
/// <returns>The operation returned by the optimizer's sparse update path.</returns>
public Operation update_op(Optimizer optimizer, IndexedSlices g)
{
    // Route sparse gradients through the sparse path. Calling _apply_dense
    // here would compile only via the implicit IndexedSlices -> Tensor
    // conversion and would never reach an optimizer's _apply_sparse override.
    var update_op = optimizer._apply_sparse_duplicate_indices(g, _v);
    return update_op;
}
} | } | ||||
} | } |
@@ -97,6 +97,20 @@ namespace Tensorflow | |||||
var _op = _op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking }); | var _op = _op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
/// <summary>
/// Creates a <c>ScatterAdd</c> op that adds sparse updates to a variable
/// reference and returns the op's first output.
/// </summary>
/// <param name="ref">Variable passed as the op's <c>ref</c> argument.</param>
/// <param name="indices">Tensor passed as the op's <c>indices</c> argument.</param>
/// <param name="updates">Tensor passed as the op's <c>updates</c> argument.</param>
/// <param name="use_locking">Value passed as the op's <c>use_locking</c> argument.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The first output of the created op.</returns>
public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null)
    => _op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking }).outputs[0];
} | } | ||||
} | } |
@@ -72,5 +72,13 @@ namespace Tensorflow | |||||
Tensor value, | Tensor value, | ||||
bool use_locking = false, | bool use_locking = false, | ||||
string name = null) => gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); | string name = null) => gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); | ||||
/// <summary>
/// Adds sparse updates to a variable. Only variables whose dtype is a
/// reference dtype are supported; those are forwarded to
/// <c>gen_state_ops.scatter_add</c>.
/// </summary>
/// <param name="ref">Variable to update.</param>
/// <param name="indices">Indices to update.</param>
/// <param name="updates">Values to add at <paramref name="indices"/>.</param>
/// <param name="use_locking">Forwarded to the generated op.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The tensor returned by <c>gen_state_ops.scatter_add</c>.</returns>
/// <exception cref="NotImplementedException">The variable's dtype is not a reference dtype.</exception>
public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null)
{
    // Guard first: the non-reference path has no implementation yet.
    if (!@ref.dtype.is_ref_dtype())
        throw new NotImplementedException("scatter_add");

    return gen_state_ops.scatter_add(@ref, indices, updates, use_locking: use_locking, name: name);
}
} | } | ||||
} | } |
@@ -8,10 +8,7 @@ using System.Text; | |||||
using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
using NumSharp; | using NumSharp; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Sessions; | using Tensorflow.Sessions; | ||||
using TensorFlowNET.Examples.Text.cnn_models; | |||||
using TensorFlowNET.Examples.TextClassification; | |||||
using TensorFlowNET.Examples.Utility; | using TensorFlowNET.Examples.Utility; | ||||
using static Tensorflow.Python; | using static Tensorflow.Python; | ||||
@@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.nn_test | |||||
public void testZeroFraction() | public void testZeroFraction() | ||||
{ | { | ||||
var x_shape = new Shape(5, 17); | var x_shape = new Shape(5, 17); | ||||
var x_np = new NumPyRandom().randint(0, 2, x_shape); | |||||
var x_np = np.random.randint(0, 2, x_shape); | |||||
x_np.astype(np.float32); | x_np.astype(np.float32); | ||||
var y_np = this._ZeroFraction(x_np); | var y_np = this._ZeroFraction(x_np); | ||||