@@ -19,7 +19,7 @@ namespace Tensorflow | |||
public partial class tensorflow | |||
{ | |||
public Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid) | |||
=> ops.convert_to_tensor(value, dtype, name, preferred_dtype); | |||
=> ops.convert_to_tensor(value, dtype, name, preferred_dtype: preferred_dtype); | |||
public Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides = null, | |||
int begin_mask = 0, | |||
@@ -69,7 +69,7 @@ namespace Tensorflow.Eager | |||
return placeholder; | |||
} | |||
public Tensor AsContatnt(string name = null) | |||
public Tensor AsConstant(string name = null) | |||
{ | |||
Tensor constant = null; | |||
tf_with(ops.control_dependencies(null), delegate | |||
@@ -29,7 +29,7 @@ namespace Tensorflow.Framework | |||
{ | |||
indices = ops.convert_to_tensor( | |||
indices_, name: "indices", dtype: dtypes.int64); | |||
values = ops.internal_convert_to_tensor(values_, name: "values"); | |||
values = ops.convert_to_tensor(values_, name: "values"); | |||
dense_shape = ops.convert_to_tensor( | |||
dense_shape_, name: "dense_shape", dtype: dtypes.int64); | |||
}); | |||
@@ -13,9 +13,6 @@ namespace Tensorflow.Graphs | |||
/// </summary> | |||
public class FuncGraph : Graph | |||
{ | |||
Graph outer_graph; | |||
public Graph OuterGraph => outer_graph; | |||
// _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||
IntPtr func_handle; | |||
public string FuncName => _graph_key; | |||
@@ -42,8 +39,10 @@ namespace Tensorflow.Graphs | |||
public FuncGraph(string name) : base() | |||
{ | |||
outer_graph = ops.get_default_graph(); | |||
while (outer_graph.building_function) | |||
outer_graph = outer_graph.OuterGraph; | |||
_graph_key = name; | |||
building_function = true; | |||
tf.Context.graph_mode(); | |||
as_default(); | |||
} | |||
@@ -51,7 +50,10 @@ namespace Tensorflow.Graphs | |||
public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base() | |||
{ | |||
outer_graph = ops.get_default_graph(); | |||
while (outer_graph.building_function) | |||
outer_graph = outer_graph.OuterGraph; | |||
_graph_key = name; | |||
building_function = true; | |||
Attrs = attrs; | |||
// Will to test if FuncGraph has memory leak | |||
// c_api.TF_DeleteGraph(_handle); | |||
@@ -108,7 +110,7 @@ namespace Tensorflow.Graphs | |||
return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); | |||
} | |||
Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataType.DtInvalid) | |||
public Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataType.DtInvalid) | |||
{ | |||
if(tensor is EagerTensor) | |||
{ | |||
@@ -118,6 +118,9 @@ namespace Tensorflow | |||
} | |||
} | |||
protected Graph outer_graph; | |||
public Graph OuterGraph => outer_graph; | |||
public Graph() | |||
{ | |||
_handle = c_api.TF_NewGraph(); | |||
@@ -148,7 +148,7 @@ namespace Tensorflow | |||
else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) | |||
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; | |||
var value = ops.internal_convert_to_tensor(values, | |||
var value = ops.convert_to_tensor(values, | |||
name: input_name, | |||
dtype: dtype.as_tf_dtype(), | |||
as_ref: input_arg.IsRef, | |||
@@ -66,7 +66,7 @@ namespace Tensorflow | |||
else | |||
{ | |||
ops.init_scope(); | |||
var variable = ops.internal_convert_to_tensor(op, as_ref: true); | |||
var variable = ops.convert_to_tensor(op, as_ref: true); | |||
if (variable.dtype.is_ref_dtype()) | |||
yield return new ReferenceVariableSaveable(variable, "", name); | |||
else | |||
@@ -103,7 +103,7 @@ namespace Tensorflow | |||
if (!var.dtype.is_ref_dtype()) | |||
tensor = var.GraphElement; | |||
else | |||
tensor = ops.internal_convert_to_tensor(var, as_ref: true); | |||
tensor = ops.convert_to_tensor(var, as_ref: true); | |||
} | |||
if (tensor.op.type == "ReadVariableOp") | |||
@@ -24,6 +24,7 @@ using System.Runtime.InteropServices; | |||
using System.Threading; | |||
using Tensorflow.Contexts; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
@@ -101,14 +102,44 @@ namespace Tensorflow | |||
public static Tensor convert_to_tensor(object value, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
string name = null, | |||
bool as_ref = false, | |||
TF_DataType preferred_dtype = TF_DataType.DtInvalid, | |||
Context ctx = null) | |||
{ | |||
return internal_convert_to_tensor(value, | |||
dtype: dtype, | |||
name: name, | |||
preferred_dtype: preferred_dtype, | |||
as_ref: false); | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = preferred_dtype; | |||
if (value is EagerTensor eager_tensor) | |||
{ | |||
if (tf.executing_eagerly()) | |||
return eager_tensor; | |||
/*else | |||
{ | |||
var graph = get_default_graph(); | |||
if (!graph.building_function) | |||
throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); | |||
return (graph as FuncGraph).capture(eager_tensor, name: name); | |||
}*/ | |||
} | |||
Tensor ret = value switch | |||
{ | |||
NDArray nd => constant_op.constant(nd, dtype: dtype, name: name), | |||
EagerTensor tensor => tensor.dtype == TF_DataType.TF_RESOURCE | |||
? tensor.AsPlaceholder(name: name) | |||
: tensor.AsConstant(name: name), | |||
Tensor tensor => tensor, | |||
Tensor[] tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name), | |||
RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | |||
ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | |||
TensorShape ts => constant_op.constant(ts.dims, dtype: dtype, name: name), | |||
int[] dims => constant_op.constant(dims, dtype: dtype, name: name), | |||
string str => constant_op.constant(str, dtype: tf.@string, name: name), | |||
object[] objects => array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name), | |||
_ => constant_op.constant(value, dtype: dtype, name: name) | |||
}; | |||
return ret; | |||
} | |||
@@ -118,9 +149,7 @@ namespace Tensorflow | |||
} | |||
public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) | |||
{ | |||
return internal_convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref); | |||
} | |||
=> convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref); | |||
/// <summary> | |||
/// Wrapper for `Graph.control_dependencies()` using the default graph. | |||
@@ -460,52 +489,12 @@ namespace Tensorflow | |||
foreach ((int i, object value) in enumerate(values as object[])) | |||
{ | |||
string n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}"; | |||
ret.Add(internal_convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype)); | |||
ret.Add(convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype)); | |||
} | |||
return ret.ToArray(); | |||
} | |||
public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, | |||
string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid, | |||
bool as_ref = false, | |||
string scope = null) | |||
{ | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = preferred_dtype; | |||
switch (value) | |||
{ | |||
case NDArray nd: | |||
return constant_op.constant(nd, dtype: dtype, name: name); | |||
case EagerTensor tensor: | |||
if (tf.executing_eagerly()) | |||
return tensor; | |||
else | |||
return tensor.dtype == TF_DataType.TF_RESOURCE | |||
? tensor.AsPlaceholder(name: name) | |||
: tensor.AsContatnt(name: name); | |||
case Tensor tensor: | |||
return tensor; | |||
case Tensor[] tensors: | |||
return array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name); | |||
case RefVariable varVal: | |||
return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref); | |||
case ResourceVariable varVal: | |||
return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref); | |||
case TensorShape ts: | |||
return constant_op.constant(ts.dims, dtype: dtype, name: name); | |||
case string str: | |||
return constant_op.constant(value, dtype: tf.@string, name: name); | |||
case int[] dims: | |||
return constant_op.constant(dims, dtype: dtype, name: name); | |||
case object[] objects: | |||
return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name); | |||
default: | |||
return constant_op.constant(value, dtype: dtype, name: name); | |||
} | |||
} | |||
public static string strip_name_scope(string name, string export_scope = "") | |||
{ | |||
if (!string.IsNullOrEmpty(export_scope)) | |||
@@ -17,6 +17,7 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using Tensorflow.Graphs; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras | |||
@@ -78,6 +79,12 @@ namespace Tensorflow.Keras | |||
public Graph get_graph() | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
if (_GRAPH == null) | |||
_GRAPH = new FuncGraph("keras_graph"); | |||
return _GRAPH; | |||
} | |||
return ops.get_default_graph(); | |||
} | |||
@@ -1,6 +1,7 @@ | |||
using System; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -22,7 +23,7 @@ namespace Tensorflow.Keras.Engine | |||
Tensors outputs = null; | |||
using var ctxManager = CallContext.enter(); | |||
// using var graph = tf.keras.backend.get_graph().as_default(); | |||
// using var graph = keras.backend.get_graph(); | |||
if (!inputs.IsEagerTensor) | |||
tf.Context.graph_mode(isFunc: true); | |||