@@ -105,6 +105,11 @@ namespace Tensorflow.Contexts | |||
context_switches.Pop(); | |||
} | |||
public void reset_context() | |||
{ | |||
c_api.TFE_ContextClearCaches(_handle); | |||
} | |||
public void Dispose() | |||
=> _handle.Dispose(); | |||
} | |||
@@ -34,6 +34,7 @@ namespace Tensorflow.Functions | |||
public ConcreteFunction(string name) | |||
{ | |||
func_graph = new FuncGraph(name); | |||
func_graph.as_default(); | |||
} | |||
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs) | |||
@@ -48,17 +49,16 @@ namespace Tensorflow.Functions | |||
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||
// IntPtr func_handle; | |||
using (var graph = new FuncGraph(func_name)) | |||
{ | |||
var input = tf.placeholder(dtype); | |||
var output = func(input); | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
_handle = graph.ToGraph(opers, | |||
new[] { input }, | |||
new[] { output }, | |||
null); | |||
} | |||
using var graph = new FuncGraph(func_name); | |||
graph.as_default(); | |||
var input = tf.placeholder(dtype); | |||
var output = func(input); | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
_handle = graph.ToGraph(opers, | |||
new[] { input }, | |||
new[] { output }, | |||
null); | |||
} | |||
public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | |||
@@ -66,19 +66,19 @@ namespace Tensorflow.Functions | |||
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||
// IntPtr func_handle; | |||
using (var graph = new FuncGraph(func_name)) | |||
{ | |||
var input = tf.placeholder(dtype); | |||
var output = func(input); | |||
using var graph = new FuncGraph(func_name); | |||
graph.as_default(); | |||
OutputStructure = output.structure; | |||
var input = tf.placeholder(dtype); | |||
var output = func(input); | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
_handle = graph.ToGraph(opers, | |||
new[] { input }, | |||
new[] { output.variant_tensor }, | |||
null); | |||
} | |||
OutputStructure = output.structure; | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
_handle = graph.ToGraph(opers, | |||
new[] { input }, | |||
new[] { output.variant_tensor }, | |||
null); | |||
} | |||
public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func, | |||
@@ -87,22 +87,22 @@ namespace Tensorflow.Functions | |||
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||
// IntPtr func_handle; | |||
using (var graph = new FuncGraph(func_name)) | |||
{ | |||
var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args"); | |||
var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args"); | |||
var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args"); | |||
var outputs = func(input1, (input2, input3)); | |||
Outputs = new[] { outputs.Item1, outputs.Item2 }; | |||
OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() }; | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
_handle = graph.ToGraph(opers, | |||
new[] { input1, input2, input3 }, | |||
new[] { outputs.Item1, outputs.Item2 }, | |||
null); | |||
} | |||
using var graph = new FuncGraph(func_name); | |||
graph.as_default(); | |||
var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args"); | |||
var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args"); | |||
var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args"); | |||
var outputs = func(input1, (input2, input3)); | |||
Outputs = new[] { outputs.Item1, outputs.Item2 }; | |||
OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() }; | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
_handle = graph.ToGraph(opers, | |||
new[] { input1, input2, input3 }, | |||
new[] { outputs.Item1, outputs.Item2 }, | |||
null); | |||
} | |||
public void ToGraph(Tensors inputs, Tensors outputs) | |||
@@ -26,6 +26,7 @@ namespace Tensorflow.Functions | |||
var output_names = new string[0]; | |||
_func_graph = new FuncGraph(graph, name, attrs); | |||
_func_graph.as_default(); | |||
_func_graph.ToGraph(operations, inputs, outputs, output_names); | |||
} | |||
@@ -85,6 +85,7 @@ namespace Tensorflow.Functions | |||
var gradients_wrt_outputs = new List<Tensor>(); | |||
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}"); | |||
backwards_graph.as_default(); | |||
foreach (var output in trainable_outputs) | |||
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | |||
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | |||
@@ -13,6 +13,7 @@ namespace Tensorflow.Graphs | |||
// IntPtr func_handle; | |||
using (var graph = new FuncGraph(func_name)) | |||
{ | |||
graph.as_default(); | |||
var input = tf.placeholder(tf.int32); | |||
var output = func(input); | |||
@@ -43,6 +44,7 @@ namespace Tensorflow.Graphs | |||
// IntPtr func_handle; | |||
using (var graph = new FuncGraph(func_name)) | |||
{ | |||
graph.as_default(); | |||
var input1 = tf.placeholder(tf.int32); | |||
var input2 = tf.placeholder(tf.int32); | |||
var output = func(input1, input2); | |||
@@ -30,9 +30,6 @@ namespace Tensorflow.Graphs | |||
public Tensor[] internal_captures() | |||
=> _captures.Select(x => x.Value.Item2).ToArray(); | |||
// new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>(); | |||
// public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray(); | |||
/// <summary> | |||
/// Construct a new FuncGraph. | |||
/// </summary> | |||
@@ -43,8 +40,6 @@ namespace Tensorflow.Graphs | |||
outer_graph = outer_graph.OuterGraph; | |||
_graph_key = name; | |||
building_function = true; | |||
tf.Context.graph_mode(); | |||
as_default(); | |||
} | |||
public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base() | |||
@@ -58,9 +53,6 @@ namespace Tensorflow.Graphs | |||
// Will to test if FuncGraph has memory leak | |||
// c_api.TF_DeleteGraph(_handle); | |||
_handle = handle; | |||
tf.Context.graph_mode(); | |||
as_default(); | |||
} | |||
public IntPtr ToGraph(Operation[] opers, | |||
@@ -110,11 +102,21 @@ namespace Tensorflow.Graphs | |||
return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); | |||
} | |||
public Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataType.DtInvalid) | |||
const int _EAGER_CONST_THRESHOLD = 128; | |||
public Tensor capture(Tensor tensor, string name = null, TensorShape shape = null) | |||
{ | |||
if(tensor is EagerTensor) | |||
{ | |||
throw new NotImplementedException(""); | |||
if (name == null) | |||
name = ops.uid().ToString(); | |||
// Small EagerTensors are captured with Const ops | |||
if (dtypes.is_value_dtype(tensor.dtype) | |||
&& (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD)) | |||
return capture_eager_tensor(tensor, name); | |||
// Large EagerTensors and resources are captured with Placeholder ops | |||
return _capture_helper(tensor, name, shape: shape); | |||
} | |||
if(tensor.graph != this) | |||
@@ -137,6 +139,9 @@ namespace Tensorflow.Graphs | |||
return tensor; | |||
} | |||
Tensor capture_eager_tensor(Tensor tensor, string name) | |||
=> throw new NotImplementedException(""); | |||
Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) | |||
{ | |||
Tensor placeholder = null; | |||
@@ -190,7 +195,8 @@ namespace Tensorflow.Graphs | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = value.dtype; | |||
var placeholder = tf_with(ops.control_dependencies(null), ctl => array_ops.placeholder(dtype, shape: shape, name: name)); | |||
var placeholder = tf_with(ops.control_dependencies(null), ctl | |||
=> array_ops.placeholder(dtype, shape: shape, name: name)); | |||
// custom_gradient.copy_handle_data(value, placeholder) | |||
return placeholder; | |||
} | |||
@@ -211,6 +217,13 @@ namespace Tensorflow.Graphs | |||
} | |||
} | |||
public override Graph as_default() | |||
{ | |||
tf.Context.graph_mode(isFunc: true); | |||
ops.set_default_graph(this); | |||
return this; | |||
} | |||
protected override void DisposeManagedResources() | |||
{ | |||
base.DisposeManagedResources(); | |||
@@ -148,7 +148,7 @@ namespace Tensorflow | |||
/// Returns a context manager that makes this `Graph` the default graph. | |||
/// </summary> | |||
/// <returns></returns> | |||
public Graph as_default() | |||
public virtual Graph as_default() | |||
{ | |||
return ops.set_default_graph(this); | |||
} | |||
@@ -5,7 +5,7 @@ | |||
<AssemblyName>TensorFlow.NET</AssemblyName> | |||
<RootNamespace>Tensorflow</RootNamespace> | |||
<TargetTensorFlow>2.2.0</TargetTensorFlow> | |||
<Version>0.31.2</Version> | |||
<Version>0.32.0</Version> | |||
<LangVersion>8.0</LangVersion> | |||
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | |||
<Company>SciSharp STACK</Company> | |||
@@ -15,7 +15,7 @@ | |||
<RepositoryType>git</RepositoryType> | |||
<PackageProjectUrl>http://scisharpstack.org</PackageProjectUrl> | |||
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | |||
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#, TF.NET</PackageTags> | |||
<PackageTags>TensorFlow, SciSharp, Machine Learning, TensorFlow.NET, TF.NET, AI</PackageTags> | |||
<Description>Google's TensorFlow full binding in .NET Standard. | |||
Building, training and infering deep learning models. | |||
https://tensorflownet.readthedocs.io</Description> | |||
@@ -293,5 +293,12 @@ namespace Tensorflow | |||
else | |||
return self; | |||
} | |||
public static bool is_value_dtype(this TF_DataType type) | |||
{ | |||
return ((int)type >= 1 && (int)type <= 19) | |||
|| type == TF_DataType.TF_UINT32 | |||
|| type == TF_DataType.TF_UINT64; | |||
} | |||
} | |||
} |
@@ -113,13 +113,13 @@ namespace Tensorflow | |||
{ | |||
if (tf.executing_eagerly()) | |||
return eager_tensor; | |||
/*else | |||
else | |||
{ | |||
var graph = get_default_graph(); | |||
if (!graph.building_function) | |||
throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); | |||
return (graph as FuncGraph).capture(eager_tensor, name: name); | |||
}*/ | |||
} | |||
} | |||
Tensor ret = value switch | |||
@@ -83,6 +83,7 @@ namespace Tensorflow.Keras | |||
{ | |||
if (_GRAPH == null) | |||
_GRAPH = new FuncGraph("keras_graph"); | |||
return _GRAPH; | |||
} | |||
return ops.get_default_graph(); | |||
@@ -1,10 +1,12 @@ | |||
namespace Tensorflow.Keras.Engine | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public class CallContext | |||
{ | |||
public CallContextManager enter() | |||
public CallContextManager enter(bool build_graph) | |||
{ | |||
return new CallContextManager(); | |||
return new CallContextManager(build_graph); | |||
} | |||
} | |||
} |
@@ -1,12 +1,20 @@ | |||
using System; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public class CallContextManager : IDisposable | |||
{ | |||
public void Dispose() | |||
bool _build_graph; | |||
public CallContextManager(bool build_graph) | |||
{ | |||
_build_graph = build_graph; | |||
} | |||
public void Dispose() | |||
{ | |||
} | |||
} | |||
} |
@@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine | |||
Tensors outputs = null; | |||
var eager = tf.executing_eagerly(); | |||
using var ctxManager = CallContext.enter(); | |||
using var ctxManager = CallContext.enter(build_graph: false); | |||
string nameScope = ""; | |||
if (eager) | |||
@@ -33,9 +33,6 @@ namespace Tensorflow.Keras.Engine | |||
else | |||
nameScope = _name_scope(); | |||
if (!inputs.IsEagerTensor) | |||
tf.Context.graph_mode(); | |||
tf_with(ops.name_scope(nameScope), scope => | |||
{ | |||
if (!built) | |||
@@ -48,9 +45,6 @@ namespace Tensorflow.Keras.Engine | |||
_set_mask_metadata(inputs, outputs, null); | |||
}); | |||
if (!inputs.IsEagerTensor) | |||
tf.Context.restore_mode(); | |||
return outputs; | |||
} | |||
} | |||
@@ -21,12 +21,10 @@ namespace Tensorflow.Keras.Engine | |||
base_layer_utils.create_keras_history(inputs); | |||
Tensors outputs = null; | |||
using var ctxManager = CallContext.enter(); | |||
using var ctxManager = CallContext.enter(build_graph: true); | |||
// using var graph = keras.backend.get_graph(); | |||
if (!inputs.IsEagerTensor) | |||
tf.Context.graph_mode(isFunc: true); | |||
var graph = keras.backend.get_graph(); | |||
graph.as_default(); | |||
tf_with(ops.name_scope(_name_scope()), scope => | |||
{ | |||
@@ -48,8 +46,7 @@ namespace Tensorflow.Keras.Engine | |||
_set_mask_metadata(inputs, outputs, null); | |||
}); | |||
if (!inputs.IsEagerTensor) | |||
tf.Context.restore_mode(); | |||
tf.Context.restore_mode(); | |||
return outputs; | |||
} | |||
@@ -180,7 +180,7 @@ namespace Tensorflow.Keras.Engine | |||
if (inputs.IsEagerTensor || tf.Context.is_build_function()) | |||
{ | |||
need_restore_mode = true; | |||
tf.Context.eager_mode(); | |||
tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); | |||
} | |||
build(inputs); | |||
@@ -4,6 +4,7 @@ using Tensorflow.Keras.Engine.DataAdapters; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Optimizers; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -51,6 +52,7 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
// Used to cache `trainable` attr of `Layer`s for `fit`. | |||
_compiled_trainable_state = _get_trainable_state(); | |||
keras.backend._GRAPH = null; | |||
} | |||
void _init_batch_counters() | |||
@@ -24,7 +24,7 @@ | |||
Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages.</Description> | |||
<Company>SciSharp STACK</Company> | |||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
<PackageTags>tensorflow, keras, deep learning, machine learning</PackageTags> | |||
<PackageTags>tensorflow, keras, deep learning, machine learning, scisharp</PackageTags> | |||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
<RepositoryType>Git</RepositoryType> | |||
<SignAssembly>true</SignAssembly> | |||