@@ -31,11 +31,18 @@ namespace Tensorflow.Framework | |||
_values = values; | |||
_indices = indices; | |||
_dense_shape = dense_shape; | |||
_values.Tag = this; | |||
} | |||
public static implicit operator Tensor(IndexedSlices indexedSlices) | |||
{ | |||
return indexedSlices.values; | |||
} | |||
public static implicit operator IndexedSlices(Tensor tensor) | |||
{ | |||
return tensor.Tag as IndexedSlices; | |||
} | |||
} | |||
} |
@@ -156,7 +156,7 @@ namespace Tensorflow.Gradients | |||
// For axis 0 gathers, build an appropriately shaped IndexedSlices. | |||
if((int)axis_static == 0) | |||
{ | |||
var params_tail_shape = params_shape[1]; | |||
var params_tail_shape = params_shape[new NumSharp.Slice(start:1)]; | |||
var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0); | |||
var values = array_ops.reshape(grad, values_shape); | |||
indices = array_ops.reshape(indices, indices_size); | |||
@@ -223,8 +223,8 @@ namespace Tensorflow | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("Unique", name, new { x, out_idx }); | |||
// TODO | |||
throw new NotImplementedException("_result = _UniqueOutput._make(_result)"); | |||
// return _op.outputs[0]; | |||
//var _result = _UniqueOutput._make(_op.outputs); | |||
return (_op.outputs[0], _op.outputs[1]); | |||
} | |||
public static Tensor where() | |||
@@ -58,6 +58,11 @@ namespace Tensorflow | |||
private TF_Output? _tf_output; | |||
/// <summary> | |||
/// used for keep other pointer when do implicit operating | |||
/// </summary> | |||
public object Tag { get; set; } | |||
public int[] shape | |||
{ | |||
get | |||
@@ -219,11 +224,11 @@ namespace Tensorflow | |||
} | |||
} | |||
public Tensor this[int start, int? stop, int? step] | |||
public Tensor this[Slice slice] | |||
{ | |||
get | |||
{ | |||
var slice_spec = new int[] { start }; | |||
var slice_spec = new int[] { slice.Start.Value }; | |||
var begin = new List<int>(); | |||
var end = new List<int>(); | |||
var strides = new List<int>(); | |||
@@ -236,14 +241,16 @@ namespace Tensorflow | |||
foreach (var s in slice_spec) | |||
{ | |||
begin.Add(s); | |||
if (stop == null) | |||
if(slice.Stop.HasValue) | |||
{ | |||
end.Add(slice.Stop.Value); | |||
} | |||
else | |||
{ | |||
end.Add(0); | |||
end_mask |= (1 << index); | |||
} | |||
else | |||
end.Add(s + 1); | |||
strides.Add(1); | |||
strides.Add(slice.Step); | |||
index += 1; | |||
} | |||
@@ -277,7 +284,57 @@ namespace Tensorflow | |||
} | |||
} | |||
public Tensor this[int slice_spec] => this[slice_spec, null, null]; | |||
public Tensor this[int start] | |||
{ | |||
get | |||
{ | |||
var slice_spec = new int[] { start }; | |||
var begin = new List<int>(); | |||
var end = new List<int>(); | |||
var strides = new List<int>(); | |||
var index = 0; | |||
var (new_axis_mask, shrink_axis_mask) = (0, 0); | |||
var (begin_mask, end_mask) = (0, 0); | |||
var ellipsis_mask = 0; | |||
foreach (var s in slice_spec) | |||
{ | |||
begin.Add(s); | |||
end.Add(s + 1); | |||
strides.Add(1); | |||
shrink_axis_mask |= (1 << index); | |||
index += 1; | |||
} | |||
return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||
{ | |||
string name = scope; | |||
if (begin != null) | |||
{ | |||
var (packed_begin, packed_end, packed_strides) = | |||
(array_ops.stack(begin.ToArray()), | |||
array_ops.stack(end.ToArray()), | |||
array_ops.stack(strides.ToArray())); | |||
return gen_array_ops.strided_slice( | |||
this, | |||
packed_begin, | |||
packed_end, | |||
packed_strides, | |||
begin_mask: begin_mask, | |||
end_mask: end_mask, | |||
shrink_axis_mask: shrink_axis_mask, | |||
new_axis_mask: new_axis_mask, | |||
ellipsis_mask: ellipsis_mask, | |||
name: name); | |||
} | |||
throw new NotImplementedException(""); | |||
}); | |||
} | |||
} | |||
public override string ToString() | |||
{ | |||
@@ -227,9 +227,8 @@ namespace Tensorflow | |||
public virtual (Tensor, Tensor) _deduplicate_indexed_slices(Tensor values, Tensor indices) | |||
{ | |||
var (unique_indices, new_index_positions) = array_ops.unique(indices); | |||
var summed_values = math_ops.unsorted_segment_sum( | |||
values, new_index_positions, | |||
array_ops.shape(unique_indices)[0]); | |||
var shape = array_ops.shape(unique_indices)[0]; | |||
var summed_values = math_ops.unsorted_segment_sum(values, new_index_positions, shape); | |||
return (summed_values, unique_indices); | |||
} | |||
@@ -1,6 +1,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Framework; | |||
namespace Tensorflow | |||
{ | |||
@@ -29,14 +29,16 @@ namespace Tensorflow | |||
public Operation update_op(Optimizer optimizer, Tensor g) | |||
{ | |||
var update_op = optimizer._apply_dense(g, _v); | |||
return update_op; | |||
} | |||
public Operation update_op(Optimizer optimizer, IndexedSlices g) | |||
{ | |||
var update_op = optimizer._apply_dense(g, _v); | |||
Operation update_op = null; | |||
if (g.Tag == null) | |||
{ | |||
update_op = optimizer._apply_dense(g, _v); | |||
} | |||
else if (g.Tag is IndexedSlices) | |||
{ | |||
return optimizer._apply_sparse_duplicate_indices(g, _v); | |||
} | |||
return update_op; | |||
} | |||