|
|
@@ -310,17 +310,25 @@ namespace Tensorflow |
|
|
|
private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) |
|
|
|
{ |
|
|
|
TF_DataType dtype = TF_DataType.DtInvalid; |
|
|
|
bool switchToGraphModeTemp = !tf.executing_eagerly(); |
|
|
|
|
|
|
|
if (x is Tensor tl) |
|
|
|
{ |
|
|
|
dtype = tl.dtype.as_base_dtype(); |
|
|
|
switchToGraphModeTemp = switchToGraphModeTemp || !tl.IsEagerTensor; |
|
|
|
} |
|
|
|
|
|
|
|
if (y is Tensor tr) |
|
|
|
{ |
|
|
|
dtype = tr.dtype.as_base_dtype(); |
|
|
|
|
|
|
|
if (name == "div") |
|
|
|
name = div_or_truediv(name, x, y); |
|
|
|
switchToGraphModeTemp = switchToGraphModeTemp || !tr.IsEagerTensor; |
|
|
|
} |
|
|
|
|
|
|
|
return tf_with(ops.name_scope(null, name, new { x, y }), scope => |
|
|
|
{ |
|
|
|
if (switchToGraphModeTemp) |
|
|
|
tf.Context.graph_mode(); |
|
|
|
|
|
|
|
Tensor result; |
|
|
|
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); |
|
|
|
var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); |
|
|
@@ -352,6 +360,9 @@ namespace Tensorflow |
|
|
|
throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}"); |
|
|
|
} |
|
|
|
|
|
|
|
if (switchToGraphModeTemp) |
|
|
|
tf.Context.restore_mode(); |
|
|
|
|
|
|
|
return result; |
|
|
|
}); |
|
|
|
} |
|
|
|