Browse Source

fix Tensor set_shape()

tags/v0.12
Oceania2018 6 years ago
parent
commit
9edada5abc
5 changed files with 23 additions and 28 deletions
  1. +3
    -3
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  5. +16
    -21
      src/TensorFlowNET.Core/Tensors/Tensor.cs

+ 3
- 3
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -196,13 +196,13 @@ namespace Tensorflow.Gradients
var grad = grads[0];
var x = op.inputs[0];
var a = op.inputs[1];
var pad_before = array_ops.slice(a, new[] { 0, 0 },
new[] { array_ops.stack(new object[] { array_ops.rank(x), 1 }) });
var size = array_ops.stack(new object[] { array_ops.rank(x), 1 });
var pad_before = array_ops.slice(a, new[] { 0, 0 }, size);

// Make it a 1-D tensor.
var begin = array_ops.reshape(pad_before, new[] { -1 });
var sizes = array_ops.shape(x);
var x_grad = array_ops.slice(grad, new[] { begin }, new[] { sizes });
var x_grad = array_ops.slice(grad, begin, sizes);

if (len(op.inputs) == 3)
return new Tensor[] { x_grad, null, null };


+ 2
- 2
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -108,7 +108,7 @@ namespace Tensorflow
{
// generate gradient subgraph for op.
var op = queue.Dequeue();
if(tf.get_default_graph()._nodes_by_name.Count >= 20611)
if(tf.get_default_graph()._nodes_by_name.Count >= 23868)
{

}
@@ -216,7 +216,7 @@ namespace Tensorflow
in_grad.Tag == null && // maybe a IndexedSlice
t_in.dtype != TF_DataType.TF_RESOURCE)
{
in_grad.shape = t_in.shape;
in_grad.set_shape(t_in.TensorShape);
}

_SetGrad(grads, t_in, in_grad);


+ 1
- 1
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -611,7 +611,7 @@ namespace Tensorflow
});
}
public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null)
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
=> gen_array_ops.slice(input, begin, size, name: name);
public static Tensor stack(object values, int axis = 0, string name = "stack")


+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -475,7 +475,7 @@ namespace Tensorflow
return op.output;
}

public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null)
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size });
return _op.outputs[0];


+ 16
- 21
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -105,10 +105,13 @@ namespace Tensorflow

if (_handle == IntPtr.Zero)
{
var status = new Status();
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status);
status.Check();
} else
using (var status = new Status())
{
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status);
status.Check();
}
}
else
{
for (int i = 0; i < rank; i++)
dims[i] = c_api.TF_Dim(_handle, i);
@@ -119,14 +122,15 @@ namespace Tensorflow

set
{
var status = new Status();
if (value == null)
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status);
else
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);
using (var status = new Status())
{
if (value == null)
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status);
else
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);

status.Check(true);
status.Check(true);
}
}
}

@@ -142,16 +146,7 @@ namespace Tensorflow
/// </summary>
public void set_shape(TensorShape shape)
{
this.shape = (int[]) shape.dims.Clone();
}

/// <summary>
/// Updates the shape of this tensor.
/// </summary>
[Obsolete("Please use set_shape(TensorShape shape) instead.", false)]
public void SetShape(TensorShape shape)
{
this.shape = (int[]) shape.dims.Clone();
this.shape = shape.rank > 0 ? shape.dims : null;
}

/// <summary>


Loading…
Cancel
Save