@@ -310,17 +310,25 @@ namespace Tensorflow | |||||
private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) | private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) | ||||
{ | { | ||||
TF_DataType dtype = TF_DataType.DtInvalid; | TF_DataType dtype = TF_DataType.DtInvalid; | ||||
bool switchToGraphModeTemp = !tf.executing_eagerly(); | |||||
if (x is Tensor tl) | if (x is Tensor tl) | ||||
{ | |||||
dtype = tl.dtype.as_base_dtype(); | dtype = tl.dtype.as_base_dtype(); | ||||
switchToGraphModeTemp = switchToGraphModeTemp || !tl.IsEagerTensor; | |||||
} | |||||
if (y is Tensor tr) | if (y is Tensor tr) | ||||
{ | |||||
dtype = tr.dtype.as_base_dtype(); | dtype = tr.dtype.as_base_dtype(); | ||||
if (name == "div") | |||||
name = div_or_truediv(name, x, y); | |||||
switchToGraphModeTemp = switchToGraphModeTemp || !tr.IsEagerTensor; | |||||
} | |||||
return tf_with(ops.name_scope(null, name, new { x, y }), scope => | return tf_with(ops.name_scope(null, name, new { x, y }), scope => | ||||
{ | { | ||||
if (switchToGraphModeTemp) | |||||
tf.Context.graph_mode(); | |||||
Tensor result; | Tensor result; | ||||
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); | var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); | ||||
var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); | var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); | ||||
@@ -352,6 +360,9 @@ namespace Tensorflow | |||||
throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}"); | throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}"); | ||||
} | } | ||||
if (switchToGraphModeTemp) | |||||
tf.Context.restore_mode(); | |||||
return result; | return result; | ||||
}); | }); | ||||
} | } | ||||
@@ -253,6 +253,13 @@ namespace Tensorflow | |||||
return (int[]) dims.Clone(); | return (int[]) dims.Clone(); | ||||
} | } | ||||
public long[] as_list_long() | |||||
{ | |||||
if (shape.IsEmpty) | |||||
throw new ValueError("as_list() is not defined on an unknown TensorShape."); | |||||
return dims.Select(x => Convert.ToInt64(x)).ToArray(); | |||||
} | |||||
public int num_elements() | public int num_elements() | ||||
{ | { | ||||
if(is_fully_defined()) | if(is_fully_defined()) | ||||
@@ -56,6 +56,9 @@ namespace Tensorflow | |||||
public static implicit operator Tensors(Tensor[] tensors) | public static implicit operator Tensors(Tensor[] tensors) | ||||
=> new Tensors(tensors); | => new Tensors(tensors); | ||||
public static implicit operator Tensors(List<Tensor> tensors) | |||||
=> new Tensors(tensors.ToArray()); | |||||
public static implicit operator Tensor(Tensors tensors) | public static implicit operator Tensor(Tensors tensors) | ||||
=> tensors.FirstOrDefault(); | => tensors.FirstOrDefault(); | ||||
@@ -100,6 +100,18 @@ namespace Tensorflow | |||||
variable_accessed(this); | variable_accessed(this); | ||||
var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); | var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); | ||||
// _maybe_set_handle_data(_dtype, _handle, result); | // _maybe_set_handle_data(_dtype, _handle, result); | ||||
// have to set shape when converting to substituent placeholder | |||||
if (result.TensorShape.ndim == -1) | |||||
{ | |||||
c_api.TF_GraphSetTensorShape(result.graph, | |||||
result._as_tf_output(), | |||||
shape.as_list_long(), | |||||
shape.ndim, | |||||
tf.Status.Handle); | |||||
tf.Status.Check(true); | |||||
} | |||||
return result; | return result; | ||||
} | } | ||||
@@ -160,15 +172,12 @@ namespace Tensorflow | |||||
{ | { | ||||
} | } | ||||
public Tensor AsTensor(bool as_ref = true) | |||||
public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) | |||||
{ | { | ||||
if (!as_ref && GraphElement != null) | |||||
return GraphElement; | |||||
if (as_ref) | if (as_ref) | ||||
return tf.executing_eagerly() ? read_value() : GraphElement; | |||||
return read_value().op.inputs[0]; | |||||
else | else | ||||
return _read_variable_op(); | |||||
return value(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -49,6 +49,6 @@ namespace Tensorflow | |||||
public TensorShape shape { get; } | public TensorShape shape { get; } | ||||
Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | ||||
Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true); | Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true); | ||||
Tensor AsTensor(bool as_ref = true); | |||||
Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false); | |||||
} | } | ||||
} | } |
@@ -222,7 +222,7 @@ namespace Tensorflow | |||||
public Tensor value() => _snapshot; | public Tensor value() => _snapshot; | ||||
public Tensor AsTensor(bool as_ref = true) => _snapshot; | |||||
public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) => _snapshot; | |||||
public Tensor _as_graph_element() => _variable; | public Tensor _as_graph_element() => _variable; | ||||
@@ -37,7 +37,7 @@ namespace Tensorflow | |||||
if (as_ref) | if (as_ref) | ||||
return handle; | return handle; | ||||
else | else | ||||
return tf.executing_eagerly() ? AsTensor() : value(); | |||||
return AsTensor(); | |||||
} | } | ||||
} | } | ||||
} | } |