diff --git a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
index ef422968..0c4f0c8b 100644
--- a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
+++ b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
@@ -31,11 +31,18 @@ namespace Tensorflow.Framework
_values = values;
_indices = indices;
_dense_shape = dense_shape;
+
+ _values.Tag = this;
}
public static implicit operator Tensor(IndexedSlices indexedSlices)
{
return indexedSlices.values;
}
+
+ public static implicit operator IndexedSlices(Tensor tensor)
+ {
+ return tensor.Tag as IndexedSlices;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs
index 58dd7e4a..4896d4dd 100644
--- a/src/TensorFlowNET.Core/Gradients/array_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs
@@ -156,7 +156,7 @@ namespace Tensorflow.Gradients
// For axis 0 gathers, build an appropriately shaped IndexedSlices.
if((int)axis_static == 0)
{
- var params_tail_shape = params_shape[1];
+ var params_tail_shape = params_shape[new NumSharp.Slice(start:1)];
var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0);
var values = array_ops.reshape(grad, values_shape);
indices = array_ops.reshape(indices, indices_size);
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 087a2430..8308d48d 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -223,8 +223,8 @@ namespace Tensorflow
{
var _op = _op_def_lib._apply_op_helper("Unique", name, new { x, out_idx });
// TODO
- throw new NotImplementedException("_result = _UniqueOutput._make(_result)");
- // return _op.outputs[0];
+ //var _result = _UniqueOutput._make(_op.outputs);
+ return (_op.outputs[0], _op.outputs[1]);
}
public static Tensor where()
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 92d45681..f782451f 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -58,6 +58,11 @@ namespace Tensorflow
private TF_Output? _tf_output;
+ ///
+ /// used for keep other pointer when do implicit operating
+ ///
+ public object Tag { get; set; }
+
public int[] shape
{
get
@@ -219,11 +224,11 @@ namespace Tensorflow
}
}
- public Tensor this[int start, int? stop, int? step]
+ public Tensor this[Slice slice]
{
get
{
- var slice_spec = new int[] { start };
+ var slice_spec = new int[] { slice.Start.Value };
var begin = new List();
var end = new List();
var strides = new List();
@@ -236,14 +241,16 @@ namespace Tensorflow
foreach (var s in slice_spec)
{
begin.Add(s);
- if (stop == null)
+ if(slice.Stop.HasValue)
+ {
+ end.Add(slice.Stop.Value);
+ }
+ else
{
end.Add(0);
end_mask |= (1 << index);
}
- else
- end.Add(s + 1);
- strides.Add(1);
+ strides.Add(slice.Step);
index += 1;
}
@@ -277,7 +284,57 @@ namespace Tensorflow
}
}
- public Tensor this[int slice_spec] => this[slice_spec, null, null];
+ public Tensor this[int start]
+ {
+ get
+ {
+ var slice_spec = new int[] { start };
+ var begin = new List();
+ var end = new List();
+ var strides = new List();
+
+ var index = 0;
+ var (new_axis_mask, shrink_axis_mask) = (0, 0);
+ var (begin_mask, end_mask) = (0, 0);
+ var ellipsis_mask = 0;
+
+ foreach (var s in slice_spec)
+ {
+ begin.Add(s);
+ end.Add(s + 1);
+ strides.Add(1);
+ shrink_axis_mask |= (1 << index);
+ index += 1;
+ }
+
+ return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
+ {
+ string name = scope;
+ if (begin != null)
+ {
+ var (packed_begin, packed_end, packed_strides) =
+ (array_ops.stack(begin.ToArray()),
+ array_ops.stack(end.ToArray()),
+ array_ops.stack(strides.ToArray()));
+
+ return gen_array_ops.strided_slice(
+ this,
+ packed_begin,
+ packed_end,
+ packed_strides,
+ begin_mask: begin_mask,
+ end_mask: end_mask,
+ shrink_axis_mask: shrink_axis_mask,
+ new_axis_mask: new_axis_mask,
+ ellipsis_mask: ellipsis_mask,
+
+ name: name);
+ }
+
+ throw new NotImplementedException("");
+ });
+ }
+ }
public override string ToString()
{
diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs
index f5474c23..9284d5c6 100644
--- a/src/TensorFlowNET.Core/Train/Optimizer.cs
+++ b/src/TensorFlowNET.Core/Train/Optimizer.cs
@@ -227,9 +227,8 @@ namespace Tensorflow
public virtual (Tensor, Tensor) _deduplicate_indexed_slices(Tensor values, Tensor indices)
{
var (unique_indices, new_index_positions) = array_ops.unique(indices);
- var summed_values = math_ops.unsorted_segment_sum(
- values, new_index_positions,
- array_ops.shape(unique_indices)[0]);
+ var shape = array_ops.shape(unique_indices)[0];
+ var summed_values = math_ops.unsorted_segment_sum(values, new_index_positions, shape);
return (summed_values, unique_indices);
}
diff --git a/src/TensorFlowNET.Core/Train/_OptimizableVariable.cs b/src/TensorFlowNET.Core/Train/_OptimizableVariable.cs
index e363e580..2d61781a 100644
--- a/src/TensorFlowNET.Core/Train/_OptimizableVariable.cs
+++ b/src/TensorFlowNET.Core/Train/_OptimizableVariable.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Framework;
namespace Tensorflow
{
diff --git a/src/TensorFlowNET.Core/Train/optimizer.py.cs b/src/TensorFlowNET.Core/Train/optimizer.py.cs
index fbf32876..15c302b4 100644
--- a/src/TensorFlowNET.Core/Train/optimizer.py.cs
+++ b/src/TensorFlowNET.Core/Train/optimizer.py.cs
@@ -29,14 +29,16 @@ namespace Tensorflow
public Operation update_op(Optimizer optimizer, Tensor g)
{
- var update_op = optimizer._apply_dense(g, _v);
-
- return update_op;
- }
-
- public Operation update_op(Optimizer optimizer, IndexedSlices g)
- {
- var update_op = optimizer._apply_dense(g, _v);
+ Operation update_op = null;
+
+ if (g.Tag == null)
+ {
+ update_op = optimizer._apply_dense(g, _v);
+ }
+ else if (g.Tag is IndexedSlices)
+ {
+ return optimizer._apply_sparse_duplicate_indices(g, _v);
+ }
return update_op;
}