Browse Source

fix Tensor set_shape()

tags/v0.12
Oceania2018 6 years ago
parent
commit
9edada5abc
5 changed files with 23 additions and 28 deletions
  1. +3
    -3
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  5. +16
    -21
      src/TensorFlowNET.Core/Tensors/Tensor.cs

+ 3
- 3
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -196,13 +196,13 @@ namespace Tensorflow.Gradients
var grad = grads[0]; var grad = grads[0];
var x = op.inputs[0]; var x = op.inputs[0];
var a = op.inputs[1]; var a = op.inputs[1];
var pad_before = array_ops.slice(a, new[] { 0, 0 },
new[] { array_ops.stack(new object[] { array_ops.rank(x), 1 }) });
var size = array_ops.stack(new object[] { array_ops.rank(x), 1 });
var pad_before = array_ops.slice(a, new[] { 0, 0 }, size);


// Make it a 1-D tensor. // Make it a 1-D tensor.
var begin = array_ops.reshape(pad_before, new[] { -1 }); var begin = array_ops.reshape(pad_before, new[] { -1 });
var sizes = array_ops.shape(x); var sizes = array_ops.shape(x);
var x_grad = array_ops.slice(grad, new[] { begin }, new[] { sizes });
var x_grad = array_ops.slice(grad, begin, sizes);


if (len(op.inputs) == 3) if (len(op.inputs) == 3)
return new Tensor[] { x_grad, null, null }; return new Tensor[] { x_grad, null, null };


+ 2
- 2
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -108,7 +108,7 @@ namespace Tensorflow
{ {
// generate gradient subgraph for op. // generate gradient subgraph for op.
var op = queue.Dequeue(); var op = queue.Dequeue();
if(tf.get_default_graph()._nodes_by_name.Count >= 20611)
if(tf.get_default_graph()._nodes_by_name.Count >= 23868)
{ {


} }
@@ -216,7 +216,7 @@ namespace Tensorflow
in_grad.Tag == null && // maybe a IndexedSlice in_grad.Tag == null && // maybe a IndexedSlice
t_in.dtype != TF_DataType.TF_RESOURCE) t_in.dtype != TF_DataType.TF_RESOURCE)
{ {
in_grad.shape = t_in.shape;
in_grad.set_shape(t_in.TensorShape);
} }


_SetGrad(grads, t_in, in_grad); _SetGrad(grads, t_in, in_grad);


+ 1
- 1
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -611,7 +611,7 @@ namespace Tensorflow
}); });
} }
public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null)
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
=> gen_array_ops.slice(input, begin, size, name: name); => gen_array_ops.slice(input, begin, size, name: name);
public static Tensor stack(object values, int axis = 0, string name = "stack") public static Tensor stack(object values, int axis = 0, string name = "stack")


+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -475,7 +475,7 @@ namespace Tensorflow
return op.output; return op.output;
} }


public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null)
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size });
return _op.outputs[0]; return _op.outputs[0];


+ 16
- 21
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -105,10 +105,13 @@ namespace Tensorflow


if (_handle == IntPtr.Zero) if (_handle == IntPtr.Zero)
{ {
var status = new Status();
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status);
status.Check();
} else
using (var status = new Status())
{
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status);
status.Check();
}
}
else
{ {
for (int i = 0; i < rank; i++) for (int i = 0; i < rank; i++)
dims[i] = c_api.TF_Dim(_handle, i); dims[i] = c_api.TF_Dim(_handle, i);
@@ -119,14 +122,15 @@ namespace Tensorflow


set set
{ {
var status = new Status();
if (value == null)
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status);
else
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);
using (var status = new Status())
{
if (value == null)
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status);
else
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);


status.Check(true);
status.Check(true);
}
} }
} }


@@ -142,16 +146,7 @@ namespace Tensorflow
/// </summary> /// </summary>
public void set_shape(TensorShape shape) public void set_shape(TensorShape shape)
{ {
this.shape = (int[]) shape.dims.Clone();
}

/// <summary>
/// Updates the shape of this tensor.
/// </summary>
[Obsolete("Please use set_shape(TensorShape shape) instead.", false)]
public void SetShape(TensorShape shape)
{
this.shape = (int[]) shape.dims.Clone();
this.shape = shape.rank > 0 ? shape.dims : null;
} }


/// <summary> /// <summary>


Loading…
Cancel
Save