Browse Source

Fix AsTensor() for ref and copy.

tags/v0.30
Oceania2018 5 years ago
parent
commit
64276a3ce8
7 changed files with 42 additions and 12 deletions
  1. +14
    -3
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  2. +7
    -0
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  3. +3
    -0
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  4. +15
    -6
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Variables/IVariableV1.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs

+ 14
- 3
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -310,17 +310,25 @@ namespace Tensorflow
private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y)
{
TF_DataType dtype = TF_DataType.DtInvalid;
bool switchToGraphModeTemp = !tf.executing_eagerly();

if (x is Tensor tl)
{
dtype = tl.dtype.as_base_dtype();
switchToGraphModeTemp = switchToGraphModeTemp || !tl.IsEagerTensor;
}
if (y is Tensor tr)
{
dtype = tr.dtype.as_base_dtype();

if (name == "div")
name = div_or_truediv(name, x, y);
switchToGraphModeTemp = switchToGraphModeTemp || !tr.IsEagerTensor;
}

return tf_with(ops.name_scope(null, name, new { x, y }), scope =>
{
if (switchToGraphModeTemp)
tf.Context.graph_mode();

Tensor result;
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x");
var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y");
@@ -352,6 +360,9 @@ namespace Tensorflow
throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}");
}

if (switchToGraphModeTemp)
tf.Context.restore_mode();

return result;
});
}


+ 7
- 0
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -253,6 +253,13 @@ namespace Tensorflow
return (int[]) dims.Clone();
}

public long[] as_list_long()
{
if (shape.IsEmpty)
throw new ValueError("as_list() is not defined on an unknown TensorShape.");
return dims.Select(x => Convert.ToInt64(x)).ToArray();
}

public int num_elements()
{
if(is_fully_defined())


+ 3
- 0
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -56,6 +56,9 @@ namespace Tensorflow
public static implicit operator Tensors(Tensor[] tensors)
=> new Tensors(tensors);

public static implicit operator Tensors(List<Tensor> tensors)
=> new Tensors(tensors.ToArray());

public static implicit operator Tensor(Tensors tensors)
=> tensors.FirstOrDefault();



+ 15
- 6
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -100,6 +100,18 @@ namespace Tensorflow
variable_accessed(this);
var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
// _maybe_set_handle_data(_dtype, _handle, result);

// have to set shape when converting to substituent placeholder
if (result.TensorShape.ndim == -1)
{
c_api.TF_GraphSetTensorShape(result.graph,
result._as_tf_output(),
shape.as_list_long(),
shape.ndim,
tf.Status.Handle);
tf.Status.Check(true);
}

return result;
}

@@ -160,15 +172,12 @@ namespace Tensorflow
{
}

public Tensor AsTensor(bool as_ref = true)
public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
{
if (!as_ref && GraphElement != null)
return GraphElement;

if (as_ref)
return tf.executing_eagerly() ? read_value() : GraphElement;
return read_value().op.inputs[0];
else
return _read_variable_op();
return value();
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Variables/IVariableV1.cs View File

@@ -49,6 +49,6 @@ namespace Tensorflow
public TensorShape shape { get; }
Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true);
Tensor AsTensor(bool as_ref = true);
Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false);
}
}

+ 1
- 1
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -222,7 +222,7 @@ namespace Tensorflow

public Tensor value() => _snapshot;

public Tensor AsTensor(bool as_ref = true) => _snapshot;
public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) => _snapshot;

public Tensor _as_graph_element() => _variable;



+ 1
- 1
src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs View File

@@ -37,7 +37,7 @@ namespace Tensorflow
if (as_ref)
return handle;
else
return tf.executing_eagerly() ? AsTensor() : value();
return AsTensor();
}
}
}

Loading…
Cancel
Save