@@ -196,13 +196,13 @@ namespace Tensorflow.Gradients | |||||
var grad = grads[0]; | var grad = grads[0]; | ||||
var x = op.inputs[0]; | var x = op.inputs[0]; | ||||
var a = op.inputs[1]; | var a = op.inputs[1]; | ||||
var pad_before = array_ops.slice(a, new[] { 0, 0 }, | |||||
new[] { array_ops.stack(new object[] { array_ops.rank(x), 1 }) }); | |||||
var size = array_ops.stack(new object[] { array_ops.rank(x), 1 }); | |||||
var pad_before = array_ops.slice(a, new[] { 0, 0 }, size); | |||||
// Make it a 1-D tensor. | // Make it a 1-D tensor. | ||||
var begin = array_ops.reshape(pad_before, new[] { -1 }); | var begin = array_ops.reshape(pad_before, new[] { -1 }); | ||||
var sizes = array_ops.shape(x); | var sizes = array_ops.shape(x); | ||||
var x_grad = array_ops.slice(grad, new[] { begin }, new[] { sizes }); | |||||
var x_grad = array_ops.slice(grad, begin, sizes); | |||||
if (len(op.inputs) == 3) | if (len(op.inputs) == 3) | ||||
return new Tensor[] { x_grad, null, null }; | return new Tensor[] { x_grad, null, null }; | ||||
@@ -108,7 +108,7 @@ namespace Tensorflow | |||||
{ | { | ||||
// generate gradient subgraph for op. | // generate gradient subgraph for op. | ||||
var op = queue.Dequeue(); | var op = queue.Dequeue(); | ||||
if(tf.get_default_graph()._nodes_by_name.Count >= 20611) | |||||
if(tf.get_default_graph()._nodes_by_name.Count >= 23868) | |||||
{ | { | ||||
} | } | ||||
@@ -216,7 +216,7 @@ namespace Tensorflow | |||||
in_grad.Tag == null && // maybe a IndexedSlice | in_grad.Tag == null && // maybe a IndexedSlice | ||||
t_in.dtype != TF_DataType.TF_RESOURCE) | t_in.dtype != TF_DataType.TF_RESOURCE) | ||||
{ | { | ||||
in_grad.shape = t_in.shape; | |||||
in_grad.set_shape(t_in.TensorShape); | |||||
} | } | ||||
_SetGrad(grads, t_in, in_grad); | _SetGrad(grads, t_in, in_grad); | ||||
@@ -611,7 +611,7 @@ namespace Tensorflow | |||||
}); | }); | ||||
} | } | ||||
public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null) | |||||
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | |||||
=> gen_array_ops.slice(input, begin, size, name: name); | => gen_array_ops.slice(input, begin, size, name: name); | ||||
public static Tensor stack(object values, int axis = 0, string name = "stack") | public static Tensor stack(object values, int axis = 0, string name = "stack") | ||||
@@ -475,7 +475,7 @@ namespace Tensorflow | |||||
return op.output; | return op.output; | ||||
} | } | ||||
public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null) | |||||
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); | var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -105,10 +105,13 @@ namespace Tensorflow | |||||
if (_handle == IntPtr.Zero) | if (_handle == IntPtr.Zero) | ||||
{ | { | ||||
var status = new Status(); | |||||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | |||||
status.Check(); | |||||
} else | |||||
using (var status = new Status()) | |||||
{ | |||||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | |||||
status.Check(); | |||||
} | |||||
} | |||||
else | |||||
{ | { | ||||
for (int i = 0; i < rank; i++) | for (int i = 0; i < rank; i++) | ||||
dims[i] = c_api.TF_Dim(_handle, i); | dims[i] = c_api.TF_Dim(_handle, i); | ||||
@@ -119,14 +122,15 @@ namespace Tensorflow | |||||
set | set | ||||
{ | { | ||||
var status = new Status(); | |||||
if (value == null) | |||||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | |||||
else | |||||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||||
using (var status = new Status()) | |||||
{ | |||||
if (value == null) | |||||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | |||||
else | |||||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||||
status.Check(true); | |||||
status.Check(true); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -142,16 +146,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public void set_shape(TensorShape shape) | public void set_shape(TensorShape shape) | ||||
{ | { | ||||
this.shape = (int[]) shape.dims.Clone(); | |||||
} | |||||
/// <summary> | |||||
/// Updates the shape of this tensor. | |||||
/// </summary> | |||||
[Obsolete("Please use set_shape(TensorShape shape) instead.", false)] | |||||
public void SetShape(TensorShape shape) | |||||
{ | |||||
this.shape = (int[]) shape.dims.Clone(); | |||||
this.shape = shape.rank > 0 ? shape.dims : null; | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||