@@ -105,6 +105,11 @@ namespace Tensorflow.Contexts | |||||
context_switches.Pop(); | context_switches.Pop(); | ||||
} | } | ||||
public void reset_context() | |||||
{ | |||||
c_api.TFE_ContextClearCaches(_handle); | |||||
} | |||||
public void Dispose() | public void Dispose() | ||||
=> _handle.Dispose(); | => _handle.Dispose(); | ||||
} | } | ||||
@@ -34,6 +34,7 @@ namespace Tensorflow.Functions | |||||
public ConcreteFunction(string name) | public ConcreteFunction(string name) | ||||
{ | { | ||||
func_graph = new FuncGraph(name); | func_graph = new FuncGraph(name); | ||||
func_graph.as_default(); | |||||
} | } | ||||
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs) | public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs) | ||||
@@ -48,17 +49,16 @@ namespace Tensorflow.Functions | |||||
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | ||||
// IntPtr func_handle; | // IntPtr func_handle; | ||||
using (var graph = new FuncGraph(func_name)) | |||||
{ | |||||
var input = tf.placeholder(dtype); | |||||
var output = func(input); | |||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||||
_handle = graph.ToGraph(opers, | |||||
new[] { input }, | |||||
new[] { output }, | |||||
null); | |||||
} | |||||
using var graph = new FuncGraph(func_name); | |||||
graph.as_default(); | |||||
var input = tf.placeholder(dtype); | |||||
var output = func(input); | |||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||||
_handle = graph.ToGraph(opers, | |||||
new[] { input }, | |||||
new[] { output }, | |||||
null); | |||||
} | } | ||||
public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | ||||
@@ -66,19 +66,19 @@ namespace Tensorflow.Functions | |||||
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | ||||
// IntPtr func_handle; | // IntPtr func_handle; | ||||
using (var graph = new FuncGraph(func_name)) | |||||
{ | |||||
var input = tf.placeholder(dtype); | |||||
var output = func(input); | |||||
using var graph = new FuncGraph(func_name); | |||||
graph.as_default(); | |||||
OutputStructure = output.structure; | |||||
var input = tf.placeholder(dtype); | |||||
var output = func(input); | |||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||||
_handle = graph.ToGraph(opers, | |||||
new[] { input }, | |||||
new[] { output.variant_tensor }, | |||||
null); | |||||
} | |||||
OutputStructure = output.structure; | |||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||||
_handle = graph.ToGraph(opers, | |||||
new[] { input }, | |||||
new[] { output.variant_tensor }, | |||||
null); | |||||
} | } | ||||
public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func, | public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func, | ||||
@@ -87,22 +87,22 @@ namespace Tensorflow.Functions | |||||
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | ||||
// IntPtr func_handle; | // IntPtr func_handle; | ||||
using (var graph = new FuncGraph(func_name)) | |||||
{ | |||||
var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args"); | |||||
var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args"); | |||||
var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args"); | |||||
var outputs = func(input1, (input2, input3)); | |||||
Outputs = new[] { outputs.Item1, outputs.Item2 }; | |||||
OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() }; | |||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||||
_handle = graph.ToGraph(opers, | |||||
new[] { input1, input2, input3 }, | |||||
new[] { outputs.Item1, outputs.Item2 }, | |||||
null); | |||||
} | |||||
using var graph = new FuncGraph(func_name); | |||||
graph.as_default(); | |||||
var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args"); | |||||
var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args"); | |||||
var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args"); | |||||
var outputs = func(input1, (input2, input3)); | |||||
Outputs = new[] { outputs.Item1, outputs.Item2 }; | |||||
OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() }; | |||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||||
_handle = graph.ToGraph(opers, | |||||
new[] { input1, input2, input3 }, | |||||
new[] { outputs.Item1, outputs.Item2 }, | |||||
null); | |||||
} | } | ||||
public void ToGraph(Tensors inputs, Tensors outputs) | public void ToGraph(Tensors inputs, Tensors outputs) | ||||
@@ -26,6 +26,7 @@ namespace Tensorflow.Functions | |||||
var output_names = new string[0]; | var output_names = new string[0]; | ||||
_func_graph = new FuncGraph(graph, name, attrs); | _func_graph = new FuncGraph(graph, name, attrs); | ||||
_func_graph.as_default(); | |||||
_func_graph.ToGraph(operations, inputs, outputs, output_names); | _func_graph.ToGraph(operations, inputs, outputs, output_names); | ||||
} | } | ||||
@@ -85,6 +85,7 @@ namespace Tensorflow.Functions | |||||
var gradients_wrt_outputs = new List<Tensor>(); | var gradients_wrt_outputs = new List<Tensor>(); | ||||
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}"); | var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}"); | ||||
backwards_graph.as_default(); | |||||
foreach (var output in trainable_outputs) | foreach (var output in trainable_outputs) | ||||
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | ||||
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | ||||
@@ -13,6 +13,7 @@ namespace Tensorflow.Graphs | |||||
// IntPtr func_handle; | // IntPtr func_handle; | ||||
using (var graph = new FuncGraph(func_name)) | using (var graph = new FuncGraph(func_name)) | ||||
{ | { | ||||
graph.as_default(); | |||||
var input = tf.placeholder(tf.int32); | var input = tf.placeholder(tf.int32); | ||||
var output = func(input); | var output = func(input); | ||||
@@ -43,6 +44,7 @@ namespace Tensorflow.Graphs | |||||
// IntPtr func_handle; | // IntPtr func_handle; | ||||
using (var graph = new FuncGraph(func_name)) | using (var graph = new FuncGraph(func_name)) | ||||
{ | { | ||||
graph.as_default(); | |||||
var input1 = tf.placeholder(tf.int32); | var input1 = tf.placeholder(tf.int32); | ||||
var input2 = tf.placeholder(tf.int32); | var input2 = tf.placeholder(tf.int32); | ||||
var output = func(input1, input2); | var output = func(input1, input2); | ||||
@@ -30,9 +30,6 @@ namespace Tensorflow.Graphs | |||||
public Tensor[] internal_captures() | public Tensor[] internal_captures() | ||||
=> _captures.Select(x => x.Value.Item2).ToArray(); | => _captures.Select(x => x.Value.Item2).ToArray(); | ||||
// new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>(); | |||||
// public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray(); | |||||
/// <summary> | /// <summary> | ||||
/// Construct a new FuncGraph. | /// Construct a new FuncGraph. | ||||
/// </summary> | /// </summary> | ||||
@@ -43,8 +40,6 @@ namespace Tensorflow.Graphs | |||||
outer_graph = outer_graph.OuterGraph; | outer_graph = outer_graph.OuterGraph; | ||||
_graph_key = name; | _graph_key = name; | ||||
building_function = true; | building_function = true; | ||||
tf.Context.graph_mode(); | |||||
as_default(); | |||||
} | } | ||||
public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base() | public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base() | ||||
@@ -58,9 +53,6 @@ namespace Tensorflow.Graphs | |||||
// Will to test if FuncGraph has memory leak | // Will to test if FuncGraph has memory leak | ||||
// c_api.TF_DeleteGraph(_handle); | // c_api.TF_DeleteGraph(_handle); | ||||
_handle = handle; | _handle = handle; | ||||
tf.Context.graph_mode(); | |||||
as_default(); | |||||
} | } | ||||
public IntPtr ToGraph(Operation[] opers, | public IntPtr ToGraph(Operation[] opers, | ||||
@@ -110,11 +102,21 @@ namespace Tensorflow.Graphs | |||||
return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); | return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); | ||||
} | } | ||||
public Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataType.DtInvalid) | |||||
const int _EAGER_CONST_THRESHOLD = 128; | |||||
public Tensor capture(Tensor tensor, string name = null, TensorShape shape = null) | |||||
{ | { | ||||
if(tensor is EagerTensor) | if(tensor is EagerTensor) | ||||
{ | { | ||||
throw new NotImplementedException(""); | |||||
if (name == null) | |||||
name = ops.uid().ToString(); | |||||
// Small EagerTensors are captured with Const ops | |||||
if (dtypes.is_value_dtype(tensor.dtype) | |||||
&& (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD)) | |||||
return capture_eager_tensor(tensor, name); | |||||
// Large EagerTensors and resources are captured with Placeholder ops | |||||
return _capture_helper(tensor, name, shape: shape); | |||||
} | } | ||||
if(tensor.graph != this) | if(tensor.graph != this) | ||||
@@ -137,6 +139,9 @@ namespace Tensorflow.Graphs | |||||
return tensor; | return tensor; | ||||
} | } | ||||
Tensor capture_eager_tensor(Tensor tensor, string name) | |||||
=> throw new NotImplementedException(""); | |||||
Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) | Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) | ||||
{ | { | ||||
Tensor placeholder = null; | Tensor placeholder = null; | ||||
@@ -190,7 +195,8 @@ namespace Tensorflow.Graphs | |||||
if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
dtype = value.dtype; | dtype = value.dtype; | ||||
var placeholder = tf_with(ops.control_dependencies(null), ctl => array_ops.placeholder(dtype, shape: shape, name: name)); | |||||
var placeholder = tf_with(ops.control_dependencies(null), ctl | |||||
=> array_ops.placeholder(dtype, shape: shape, name: name)); | |||||
// custom_gradient.copy_handle_data(value, placeholder) | // custom_gradient.copy_handle_data(value, placeholder) | ||||
return placeholder; | return placeholder; | ||||
} | } | ||||
@@ -211,6 +217,13 @@ namespace Tensorflow.Graphs | |||||
} | } | ||||
} | } | ||||
public override Graph as_default() | |||||
{ | |||||
tf.Context.graph_mode(isFunc: true); | |||||
ops.set_default_graph(this); | |||||
return this; | |||||
} | |||||
protected override void DisposeManagedResources() | protected override void DisposeManagedResources() | ||||
{ | { | ||||
base.DisposeManagedResources(); | base.DisposeManagedResources(); | ||||
@@ -148,7 +148,7 @@ namespace Tensorflow | |||||
/// Returns a context manager that makes this `Graph` the default graph. | /// Returns a context manager that makes this `Graph` the default graph. | ||||
/// </summary> | /// </summary> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public Graph as_default() | |||||
public virtual Graph as_default() | |||||
{ | { | ||||
return ops.set_default_graph(this); | return ops.set_default_graph(this); | ||||
} | } | ||||
@@ -5,7 +5,7 @@ | |||||
<AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
<RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
<TargetTensorFlow>2.2.0</TargetTensorFlow> | <TargetTensorFlow>2.2.0</TargetTensorFlow> | ||||
<Version>0.31.2</Version> | |||||
<Version>0.32.0</Version> | |||||
<LangVersion>8.0</LangVersion> | <LangVersion>8.0</LangVersion> | ||||
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
<Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
@@ -15,7 +15,7 @@ | |||||
<RepositoryType>git</RepositoryType> | <RepositoryType>git</RepositoryType> | ||||
<PackageProjectUrl>http://scisharpstack.org</PackageProjectUrl> | <PackageProjectUrl>http://scisharpstack.org</PackageProjectUrl> | ||||
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | ||||
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#, TF.NET</PackageTags> | |||||
<PackageTags>TensorFlow, SciSharp, Machine Learning, TensorFlow.NET, TF.NET, AI</PackageTags> | |||||
<Description>Google's TensorFlow full binding in .NET Standard. | <Description>Google's TensorFlow full binding in .NET Standard. | ||||
Building, training and infering deep learning models. | Building, training and infering deep learning models. | ||||
https://tensorflownet.readthedocs.io</Description> | https://tensorflownet.readthedocs.io</Description> | ||||
@@ -293,5 +293,12 @@ namespace Tensorflow | |||||
else | else | ||||
return self; | return self; | ||||
} | } | ||||
public static bool is_value_dtype(this TF_DataType type) | |||||
{ | |||||
return ((int)type >= 1 && (int)type <= 19) | |||||
|| type == TF_DataType.TF_UINT32 | |||||
|| type == TF_DataType.TF_UINT64; | |||||
} | |||||
} | } | ||||
} | } |
@@ -113,13 +113,13 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.executing_eagerly()) | if (tf.executing_eagerly()) | ||||
return eager_tensor; | return eager_tensor; | ||||
/*else | |||||
else | |||||
{ | { | ||||
var graph = get_default_graph(); | var graph = get_default_graph(); | ||||
if (!graph.building_function) | if (!graph.building_function) | ||||
throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); | throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); | ||||
return (graph as FuncGraph).capture(eager_tensor, name: name); | return (graph as FuncGraph).capture(eager_tensor, name: name); | ||||
}*/ | |||||
} | |||||
} | } | ||||
Tensor ret = value switch | Tensor ret = value switch | ||||
@@ -83,6 +83,7 @@ namespace Tensorflow.Keras | |||||
{ | { | ||||
if (_GRAPH == null) | if (_GRAPH == null) | ||||
_GRAPH = new FuncGraph("keras_graph"); | _GRAPH = new FuncGraph("keras_graph"); | ||||
return _GRAPH; | return _GRAPH; | ||||
} | } | ||||
return ops.get_default_graph(); | return ops.get_default_graph(); | ||||
@@ -1,10 +1,12 @@ | |||||
namespace Tensorflow.Keras.Engine | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
public class CallContext | public class CallContext | ||||
{ | { | ||||
public CallContextManager enter() | |||||
public CallContextManager enter(bool build_graph) | |||||
{ | { | ||||
return new CallContextManager(); | |||||
return new CallContextManager(build_graph); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -1,12 +1,20 @@ | |||||
using System; | using System; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
public class CallContextManager : IDisposable | public class CallContextManager : IDisposable | ||||
{ | { | ||||
public void Dispose() | |||||
bool _build_graph; | |||||
public CallContextManager(bool build_graph) | |||||
{ | { | ||||
_build_graph = build_graph; | |||||
} | |||||
public void Dispose() | |||||
{ | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine | |||||
Tensors outputs = null; | Tensors outputs = null; | ||||
var eager = tf.executing_eagerly(); | var eager = tf.executing_eagerly(); | ||||
using var ctxManager = CallContext.enter(); | |||||
using var ctxManager = CallContext.enter(build_graph: false); | |||||
string nameScope = ""; | string nameScope = ""; | ||||
if (eager) | if (eager) | ||||
@@ -33,9 +33,6 @@ namespace Tensorflow.Keras.Engine | |||||
else | else | ||||
nameScope = _name_scope(); | nameScope = _name_scope(); | ||||
if (!inputs.IsEagerTensor) | |||||
tf.Context.graph_mode(); | |||||
tf_with(ops.name_scope(nameScope), scope => | tf_with(ops.name_scope(nameScope), scope => | ||||
{ | { | ||||
if (!built) | if (!built) | ||||
@@ -48,9 +45,6 @@ namespace Tensorflow.Keras.Engine | |||||
_set_mask_metadata(inputs, outputs, null); | _set_mask_metadata(inputs, outputs, null); | ||||
}); | }); | ||||
if (!inputs.IsEagerTensor) | |||||
tf.Context.restore_mode(); | |||||
return outputs; | return outputs; | ||||
} | } | ||||
} | } | ||||
@@ -21,12 +21,10 @@ namespace Tensorflow.Keras.Engine | |||||
base_layer_utils.create_keras_history(inputs); | base_layer_utils.create_keras_history(inputs); | ||||
Tensors outputs = null; | Tensors outputs = null; | ||||
using var ctxManager = CallContext.enter(); | |||||
using var ctxManager = CallContext.enter(build_graph: true); | |||||
// using var graph = keras.backend.get_graph(); | |||||
if (!inputs.IsEagerTensor) | |||||
tf.Context.graph_mode(isFunc: true); | |||||
var graph = keras.backend.get_graph(); | |||||
graph.as_default(); | |||||
tf_with(ops.name_scope(_name_scope()), scope => | tf_with(ops.name_scope(_name_scope()), scope => | ||||
{ | { | ||||
@@ -48,8 +46,7 @@ namespace Tensorflow.Keras.Engine | |||||
_set_mask_metadata(inputs, outputs, null); | _set_mask_metadata(inputs, outputs, null); | ||||
}); | }); | ||||
if (!inputs.IsEagerTensor) | |||||
tf.Context.restore_mode(); | |||||
tf.Context.restore_mode(); | |||||
return outputs; | return outputs; | ||||
} | } | ||||
@@ -180,7 +180,7 @@ namespace Tensorflow.Keras.Engine | |||||
if (inputs.IsEagerTensor || tf.Context.is_build_function()) | if (inputs.IsEagerTensor || tf.Context.is_build_function()) | ||||
{ | { | ||||
need_restore_mode = true; | need_restore_mode = true; | ||||
tf.Context.eager_mode(); | |||||
tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); | |||||
} | } | ||||
build(inputs); | build(inputs); | ||||
@@ -4,6 +4,7 @@ using Tensorflow.Keras.Engine.DataAdapters; | |||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
@@ -51,6 +52,7 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
// Used to cache `trainable` attr of `Layer`s for `fit`. | // Used to cache `trainable` attr of `Layer`s for `fit`. | ||||
_compiled_trainable_state = _get_trainable_state(); | _compiled_trainable_state = _get_trainable_state(); | ||||
keras.backend._GRAPH = null; | |||||
} | } | ||||
void _init_batch_counters() | void _init_batch_counters() | ||||
@@ -24,7 +24,7 @@ | |||||
Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages.</Description> | Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages.</Description> | ||||
<Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
<PackageTags>tensorflow, keras, deep learning, machine learning</PackageTags> | |||||
<PackageTags>tensorflow, keras, deep learning, machine learning, scisharp</PackageTags> | |||||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | ||||
<RepositoryType>Git</RepositoryType> | <RepositoryType>Git</RepositoryType> | ||||
<SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||