Browse Source

Refactor execution mode switch in keras layer.

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
05dd652c8b
18 changed files with 107 additions and 74 deletions
  1. +5
    -0
      src/TensorFlowNET.Core/Contexts/Context.cs
  2. +38
    -38
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  3. +1
    -0
      src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
  4. +1
    -0
      src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  5. +2
    -0
      src/TensorFlowNET.Core/Graphs/AutoGraph.cs
  6. +24
    -11
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  8. +2
    -2
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  9. +7
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  10. +2
    -2
      src/TensorFlowNET.Core/ops.cs
  11. +1
    -0
      src/TensorFlowNET.Keras/BackendImpl.cs
  12. +5
    -3
      src/TensorFlowNET.Keras/Engine/CallContext.cs
  13. +9
    -1
      src/TensorFlowNET.Keras/Engine/CallContextManager.cs
  14. +1
    -7
      src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
  15. +4
    -7
      src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs
  16. +1
    -1
      src/TensorFlowNET.Keras/Engine/Layer.cs
  17. +2
    -0
      src/TensorFlowNET.Keras/Engine/Model.cs
  18. +1
    -1
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

+ 5
- 0
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -105,6 +105,11 @@ namespace Tensorflow.Contexts
context_switches.Pop(); context_switches.Pop();
} }


public void reset_context()
{
c_api.TFE_ContextClearCaches(_handle);
}

public void Dispose() public void Dispose()
=> _handle.Dispose(); => _handle.Dispose();
} }


+ 38
- 38
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -34,6 +34,7 @@ namespace Tensorflow.Functions
public ConcreteFunction(string name) public ConcreteFunction(string name)
{ {
func_graph = new FuncGraph(name); func_graph = new FuncGraph(name);
func_graph.as_default();
} }


public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs) public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs)
@@ -48,17 +49,16 @@ namespace Tensorflow.Functions
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";


// IntPtr func_handle; // IntPtr func_handle;
using (var graph = new FuncGraph(func_name))
{
var input = tf.placeholder(dtype);
var output = func(input);

var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
_handle = graph.ToGraph(opers,
new[] { input },
new[] { output },
null);
}
using var graph = new FuncGraph(func_name);
graph.as_default();
var input = tf.placeholder(dtype);
var output = func(input);

var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
_handle = graph.ToGraph(opers,
new[] { input },
new[] { output },
null);
} }


public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
@@ -66,19 +66,19 @@ namespace Tensorflow.Functions
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";


// IntPtr func_handle; // IntPtr func_handle;
using (var graph = new FuncGraph(func_name))
{
var input = tf.placeholder(dtype);
var output = func(input);
using var graph = new FuncGraph(func_name);
graph.as_default();


OutputStructure = output.structure;
var input = tf.placeholder(dtype);
var output = func(input);


var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
_handle = graph.ToGraph(opers,
new[] { input },
new[] { output.variant_tensor },
null);
}
OutputStructure = output.structure;

var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
_handle = graph.ToGraph(opers,
new[] { input },
new[] { output.variant_tensor },
null);
} }


public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func, public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
@@ -87,22 +87,22 @@ namespace Tensorflow.Functions
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";


// IntPtr func_handle; // IntPtr func_handle;
using (var graph = new FuncGraph(func_name))
{
var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args");
var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args");
var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args");
var outputs = func(input1, (input2, input3));
Outputs = new[] { outputs.Item1, outputs.Item2 };
OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() };
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
_handle = graph.ToGraph(opers,
new[] { input1, input2, input3 },
new[] { outputs.Item1, outputs.Item2 },
null);
}
using var graph = new FuncGraph(func_name);
graph.as_default();
var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args");
var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args");
var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args");
var outputs = func(input1, (input2, input3));
Outputs = new[] { outputs.Item1, outputs.Item2 };
OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() };
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
_handle = graph.ToGraph(opers,
new[] { input1, input2, input3 },
new[] { outputs.Item1, outputs.Item2 },
null);
} }


public void ToGraph(Tensors inputs, Tensors outputs) public void ToGraph(Tensors inputs, Tensors outputs)


+ 1
- 0
src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs View File

@@ -26,6 +26,7 @@ namespace Tensorflow.Functions
var output_names = new string[0]; var output_names = new string[0];


_func_graph = new FuncGraph(graph, name, attrs); _func_graph = new FuncGraph(graph, name, attrs);
_func_graph.as_default();
_func_graph.ToGraph(operations, inputs, outputs, output_names); _func_graph.ToGraph(operations, inputs, outputs, output_names);
} }




+ 1
- 0
src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs View File

@@ -85,6 +85,7 @@ namespace Tensorflow.Functions


var gradients_wrt_outputs = new List<Tensor>(); var gradients_wrt_outputs = new List<Tensor>();
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}"); var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}");
backwards_graph.as_default();
foreach (var output in trainable_outputs) foreach (var output in trainable_outputs)
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),


+ 2
- 0
src/TensorFlowNET.Core/Graphs/AutoGraph.cs View File

@@ -13,6 +13,7 @@ namespace Tensorflow.Graphs
// IntPtr func_handle; // IntPtr func_handle;
using (var graph = new FuncGraph(func_name)) using (var graph = new FuncGraph(func_name))
{ {
graph.as_default();
var input = tf.placeholder(tf.int32); var input = tf.placeholder(tf.int32);
var output = func(input); var output = func(input);


@@ -43,6 +44,7 @@ namespace Tensorflow.Graphs
// IntPtr func_handle; // IntPtr func_handle;
using (var graph = new FuncGraph(func_name)) using (var graph = new FuncGraph(func_name))
{ {
graph.as_default();
var input1 = tf.placeholder(tf.int32); var input1 = tf.placeholder(tf.int32);
var input2 = tf.placeholder(tf.int32); var input2 = tf.placeholder(tf.int32);
var output = func(input1, input2); var output = func(input1, input2);


+ 24
- 11
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -30,9 +30,6 @@ namespace Tensorflow.Graphs
public Tensor[] internal_captures() public Tensor[] internal_captures()
=> _captures.Select(x => x.Value.Item2).ToArray(); => _captures.Select(x => x.Value.Item2).ToArray();


// new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>();
// public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray();

/// <summary> /// <summary>
/// Construct a new FuncGraph. /// Construct a new FuncGraph.
/// </summary> /// </summary>
@@ -43,8 +40,6 @@ namespace Tensorflow.Graphs
outer_graph = outer_graph.OuterGraph; outer_graph = outer_graph.OuterGraph;
_graph_key = name; _graph_key = name;
building_function = true; building_function = true;
tf.Context.graph_mode();
as_default();
} }


public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base() public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base()
@@ -58,9 +53,6 @@ namespace Tensorflow.Graphs
// Will to test if FuncGraph has memory leak // Will to test if FuncGraph has memory leak
// c_api.TF_DeleteGraph(_handle); // c_api.TF_DeleteGraph(_handle);
_handle = handle; _handle = handle;

tf.Context.graph_mode();
as_default();
} }


public IntPtr ToGraph(Operation[] opers, public IntPtr ToGraph(Operation[] opers,
@@ -110,11 +102,21 @@ namespace Tensorflow.Graphs
return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device);
} }


public Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataType.DtInvalid)
const int _EAGER_CONST_THRESHOLD = 128;
public Tensor capture(Tensor tensor, string name = null, TensorShape shape = null)
{ {
if(tensor is EagerTensor) if(tensor is EagerTensor)
{ {
throw new NotImplementedException("");
if (name == null)
name = ops.uid().ToString();

// Small EagerTensors are captured with Const ops
if (dtypes.is_value_dtype(tensor.dtype)
&& (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD))
return capture_eager_tensor(tensor, name);

// Large EagerTensors and resources are captured with Placeholder ops
return _capture_helper(tensor, name, shape: shape);
} }


if(tensor.graph != this) if(tensor.graph != this)
@@ -137,6 +139,9 @@ namespace Tensorflow.Graphs
return tensor; return tensor;
} }


Tensor capture_eager_tensor(Tensor tensor, string name)
=> throw new NotImplementedException("");

Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null)
{ {
Tensor placeholder = null; Tensor placeholder = null;
@@ -190,7 +195,8 @@ namespace Tensorflow.Graphs
if (dtype == TF_DataType.DtInvalid) if (dtype == TF_DataType.DtInvalid)
dtype = value.dtype; dtype = value.dtype;


var placeholder = tf_with(ops.control_dependencies(null), ctl => array_ops.placeholder(dtype, shape: shape, name: name));
var placeholder = tf_with(ops.control_dependencies(null), ctl
=> array_ops.placeholder(dtype, shape: shape, name: name));
// custom_gradient.copy_handle_data(value, placeholder) // custom_gradient.copy_handle_data(value, placeholder)
return placeholder; return placeholder;
} }
@@ -211,6 +217,13 @@ namespace Tensorflow.Graphs
} }
} }


public override Graph as_default()
{
tf.Context.graph_mode(isFunc: true);
ops.set_default_graph(this);
return this;
}

protected override void DisposeManagedResources() protected override void DisposeManagedResources()
{ {
base.DisposeManagedResources(); base.DisposeManagedResources();


+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -148,7 +148,7 @@ namespace Tensorflow
/// Returns a context manager that makes this `Graph` the default graph. /// Returns a context manager that makes this `Graph` the default graph.
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
public Graph as_default()
public virtual Graph as_default()
{ {
return ops.set_default_graph(this); return ops.set_default_graph(this);
} }


+ 2
- 2
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName> <AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace> <RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.2.0</TargetTensorFlow> <TargetTensorFlow>2.2.0</TargetTensorFlow>
<Version>0.31.2</Version>
<Version>0.32.0</Version>
<LangVersion>8.0</LangVersion> <LangVersion>8.0</LangVersion>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company> <Company>SciSharp STACK</Company>
@@ -15,7 +15,7 @@
<RepositoryType>git</RepositoryType> <RepositoryType>git</RepositoryType>
<PackageProjectUrl>http://scisharpstack.org</PackageProjectUrl> <PackageProjectUrl>http://scisharpstack.org</PackageProjectUrl>
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl> <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#, TF.NET</PackageTags>
<PackageTags>TensorFlow, SciSharp, Machine Learning, TensorFlow.NET, TF.NET, AI</PackageTags>
<Description>Google's TensorFlow full binding in .NET Standard. <Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models. Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description> https://tensorflownet.readthedocs.io</Description>


+ 7
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -293,5 +293,12 @@ namespace Tensorflow
else else
return self; return self;
} }

public static bool is_value_dtype(this TF_DataType type)
{
return ((int)type >= 1 && (int)type <= 19)
|| type == TF_DataType.TF_UINT32
|| type == TF_DataType.TF_UINT64;
}
} }
} }

+ 2
- 2
src/TensorFlowNET.Core/ops.cs View File

@@ -113,13 +113,13 @@ namespace Tensorflow
{ {
if (tf.executing_eagerly()) if (tf.executing_eagerly())
return eager_tensor; return eager_tensor;
/*else
else
{ {
var graph = get_default_graph(); var graph = get_default_graph();
if (!graph.building_function) if (!graph.building_function)
throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); throw new RuntimeError("Attempting to capture an EagerTensor without building a function.");
return (graph as FuncGraph).capture(eager_tensor, name: name); return (graph as FuncGraph).capture(eager_tensor, name: name);
}*/
}
} }


Tensor ret = value switch Tensor ret = value switch


+ 1
- 0
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -83,6 +83,7 @@ namespace Tensorflow.Keras
{ {
if (_GRAPH == null) if (_GRAPH == null)
_GRAPH = new FuncGraph("keras_graph"); _GRAPH = new FuncGraph("keras_graph");
return _GRAPH; return _GRAPH;
} }
return ops.get_default_graph(); return ops.get_default_graph();


+ 5
- 3
src/TensorFlowNET.Keras/Engine/CallContext.cs View File

@@ -1,10 +1,12 @@
namespace Tensorflow.Keras.Engine
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{ {
public class CallContext public class CallContext
{ {
public CallContextManager enter()
public CallContextManager enter(bool build_graph)
{ {
return new CallContextManager();
return new CallContextManager(build_graph);
} }
} }
} }

+ 9
- 1
src/TensorFlowNET.Keras/Engine/CallContextManager.cs View File

@@ -1,12 +1,20 @@
using System; using System;
using static Tensorflow.Binding;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
{ {
public class CallContextManager : IDisposable public class CallContextManager : IDisposable
{ {
public void Dispose()
bool _build_graph;

public CallContextManager(bool build_graph)
{ {
_build_graph = build_graph;
}


public void Dispose()
{
} }
} }
} }

+ 1
- 7
src/TensorFlowNET.Keras/Engine/Layer.Apply.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine
Tensors outputs = null; Tensors outputs = null;


var eager = tf.executing_eagerly(); var eager = tf.executing_eagerly();
using var ctxManager = CallContext.enter();
using var ctxManager = CallContext.enter(build_graph: false);


string nameScope = ""; string nameScope = "";
if (eager) if (eager)
@@ -33,9 +33,6 @@ namespace Tensorflow.Keras.Engine
else else
nameScope = _name_scope(); nameScope = _name_scope();


if (!inputs.IsEagerTensor)
tf.Context.graph_mode();

tf_with(ops.name_scope(nameScope), scope => tf_with(ops.name_scope(nameScope), scope =>
{ {
if (!built) if (!built)
@@ -48,9 +45,6 @@ namespace Tensorflow.Keras.Engine
_set_mask_metadata(inputs, outputs, null); _set_mask_metadata(inputs, outputs, null);
}); });


if (!inputs.IsEagerTensor)
tf.Context.restore_mode();

return outputs; return outputs;
} }
} }


+ 4
- 7
src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs View File

@@ -21,12 +21,10 @@ namespace Tensorflow.Keras.Engine
base_layer_utils.create_keras_history(inputs); base_layer_utils.create_keras_history(inputs);


Tensors outputs = null; Tensors outputs = null;
using var ctxManager = CallContext.enter();
using var ctxManager = CallContext.enter(build_graph: true);


// using var graph = keras.backend.get_graph();

if (!inputs.IsEagerTensor)
tf.Context.graph_mode(isFunc: true);
var graph = keras.backend.get_graph();
graph.as_default();


tf_with(ops.name_scope(_name_scope()), scope => tf_with(ops.name_scope(_name_scope()), scope =>
{ {
@@ -48,8 +46,7 @@ namespace Tensorflow.Keras.Engine
_set_mask_metadata(inputs, outputs, null); _set_mask_metadata(inputs, outputs, null);
}); });


if (!inputs.IsEagerTensor)
tf.Context.restore_mode();
tf.Context.restore_mode();


return outputs; return outputs;
} }


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -180,7 +180,7 @@ namespace Tensorflow.Keras.Engine
if (inputs.IsEagerTensor || tf.Context.is_build_function()) if (inputs.IsEagerTensor || tf.Context.is_build_function())
{ {
need_restore_mode = true; need_restore_mode = true;
tf.Context.eager_mode();
tf.Context.eager_mode(isFunc: tf.Context.is_build_function());
} }
build(inputs); build(inputs);


+ 2
- 0
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -4,6 +4,7 @@ using Tensorflow.Keras.Engine.DataAdapters;
using Tensorflow.Keras.Losses; using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Optimizers;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using static Tensorflow.KerasApi;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
{ {
@@ -51,6 +52,7 @@ namespace Tensorflow.Keras.Engine
{ {
// Used to cache `trainable` attr of `Layer`s for `fit`. // Used to cache `trainable` attr of `Layer`s for `fit`.
_compiled_trainable_state = _get_trainable_state(); _compiled_trainable_state = _get_trainable_state();
keras.backend._GRAPH = null;
} }


void _init_batch_counters() void _init_batch_counters()


+ 1
- 1
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -24,7 +24,7 @@
Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent &amp; simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear &amp; actionable error messages.</Description> Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent &amp; simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear &amp; actionable error messages.</Description>
<Company>SciSharp STACK</Company> <Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> <GeneratePackageOnBuild>true</GeneratePackageOnBuild>
<PackageTags>tensorflow, keras, deep learning, machine learning</PackageTags>
<PackageTags>tensorflow, keras, deep learning, machine learning, scisharp</PackageTags>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<RepositoryType>Git</RepositoryType> <RepositoryType>Git</RepositoryType>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>


Loading…
Cancel
Save