Browse Source

add FuncGraph.

v0.20-tensorflow2.3
Oceania2018 Haiping 5 years ago
parent
commit
e7da957bcb
12 changed files with 279 additions and 22 deletions
  1. +1
    -1
      src/TensorFlowNET.Console/TensorFlowNET.Console.csproj
  2. +26
    -0
      src/TensorFlowNET.Core/APIs/tf.autograph.cs
  3. +29
    -15
      src/TensorFlowNET.Core/APIs/tf.control_flow.cs
  4. +31
    -0
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Functions/c_api.function.cs
  6. +47
    -0
      src/TensorFlowNET.Core/Graphs/AutoGraph.cs
  7. +56
    -0
      src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs
  8. +54
    -0
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  9. +25
    -2
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  10. +7
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  11. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  12. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs

+ 1
- 1
src/TensorFlowNET.Console/TensorFlowNET.Console.csproj View File

@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<OutputType>Exe</OutputType>


+ 26
- 0
src/TensorFlowNET.Core/APIs/tf.autograph.cs View File

@@ -0,0 +1,26 @@
/*****************************************************************************
Copyright 2020 Haiping Chen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using Tensorflow.Graphs;
using Tensorflow.Operations;

namespace Tensorflow
{
public partial class tensorflow
{
public AutoGraph autograph = new AutoGraph();
}
}

+ 29
- 15
src/TensorFlowNET.Core/APIs/tf.control_flow.cs View File

@@ -15,17 +15,22 @@
******************************************************************************/

using System;
using static Tensorflow.Binding;

namespace Tensorflow
{
public partial class tensorflow
{
public Tensor cond(Tensor pred,
Tensor true_value,
Tensor false_false)
=> control_flow_ops.cond(pred, () => true_value, () => false_false);

public Tensor cond(Tensor pred,
Func<ITensorOrOperation> true_fn = null,
Func<ITensorOrOperation> false_fn = null,
bool strict = false,
string name = null)
=> control_flow_ops.cond(pred, true_fn, false_fn, strict: strict, name: name);
=> control_flow_ops.cond(pred, true_fn, false_fn, name: name);

/// <summary>
/// Create an op that groups multiple operations.
@@ -37,22 +42,31 @@ namespace Tensorflow
public Operation group<T>(T[] inputs, string name = null) where T : ITensorOrOperation
=> control_flow_ops.group(inputs, name: name);

/*public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
TensorShape shape_invariants = null,
public Tensor while_loop(Func<Tensor, Tensor> cond,
Func<Tensor, Tensor> body,
Tensor loop_vars,
int parallel_iterations = 10)
{
Func<Tensor[], Tensor> cond1 = x
=> cond(x[0]);

Func<Tensor[], Tensor[]> body1 = x
=> new[] { body(x[0]) };

var results = control_flow_ops.while_loop(cond1,
body1,
new[] { loop_vars });
return results[0];
}

public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
Func<Tensor[], Tensor[]> body,
Tensor[] loop_vars,
int parallel_iterations = 10,
bool back_prop = true,
bool swap_memory = false,
string name = null,
int? maximum_iterations = null,
bool return_same_structure = false)
string name = null)
=> control_flow_ops.while_loop(cond, body, loop_vars,
shape_invariants: shape_invariants,
parallel_iterations: parallel_iterations,
back_prop: back_prop,
swap_memory: swap_memory,
name: name,
maximum_iterations: maximum_iterations,
return_same_structure: return_same_structure);*/
name: name);

public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
=> ops.control_dependencies(control_inputs);


+ 31
- 0
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -78,6 +78,37 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status);

/// <summary>
/// Adds a function (created from TF_GraphToFunction or
/// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with
/// TFE_Execute by creating an op with the same name as the function.
/// </summary>
/// <param name="ctx"></param>
/// <param name="function"></param>
/// <param name="status"></param>
[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, IntPtr function, SafeStatusHandle status);

/// <summary>
/// Removes a function from the context. Once removed, you can no longer
/// TFE_Execute it or TFE_Execute any TFE_Op which has it as an attribute or any
/// other function which calls it as an attribute.
/// </summary>
/// <param name="ctx"></param>
/// <param name="name"></param>
/// <param name="status"></param>
[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextRemoveFunction(SafeContextHandle ctx, string name, SafeStatusHandle status);

/// <summary>
/// Checks whether a function is registered under `name`.
/// </summary>
/// <param name="ctx"></param>
/// <param name="name"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern bool TFE_ContextHasFunction(SafeContextHandle ctx, string name);

[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextStartStep(SafeContextHandle ctx);



+ 1
- 1
src/TensorFlowNET.Core/Functions/c_api.function.cs View File

@@ -39,7 +39,7 @@ namespace Tensorflow
int num_opers, IntPtr[] opers,
int ninputs, TF_Output[] inputs,
int noutputs, TF_Output[] outputs,
IntPtr output_names,
string[] output_names,
IntPtr opts,
string description,
SafeStatusHandle status);


+ 47
- 0
src/TensorFlowNET.Core/Graphs/AutoGraph.cs View File

@@ -0,0 +1,47 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs
{
public class AutoGraph
{
public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func)
{
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
tf.compat.v1.disable_eager_execution();
// IntPtr func_handle;
using(var graph = new FuncGraph(func_name))
{
graph.as_default();
var input1 = tf.placeholder(tf.int32);
var input2 = tf.placeholder(tf.int32);
var output = func(input1, input2);

var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
var func_handle = graph.ToGraph(opers,
new Operation[] { input1, input2 },
new Operation[] { output },
null);

c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, tf.Status.Handle);
}

tf.enable_eager_execution();

return (Tensor a, Tensor b) =>
{
var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
func_name,
new[] { a, b },
null,
1);
return result[0];
};
}
}
}

+ 56
- 0
src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs View File

@@ -0,0 +1,56 @@
/*using MethodBoundaryAspect.Fody.Attributes;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs
{
public sealed class AutoGraphAspect : OnMethodBoundaryAspect
{
FuncGraph graph;
IntPtr func_handle;

public override void OnEntry(MethodExecutionArgs args)
{
tf.compat.v1.disable_eager_execution();
// convert args to placeholder
for (var i = 0; i < args.Arguments.Length; i++)
{
if (args.Arguments[i] is EagerTensor tensor)
args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape);
}

// make function as an Operation by autograph
graph = new FuncGraph("autograph_add");
graph.as_default();
}

public override void OnExit(MethodExecutionArgs args)
{
var output = (Tensor)args.Method.Invoke(args.Instance, args.Arguments);
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
func_handle = graph.ToGraph(opers,
new Operation[] { },
new Operation[] { },
null);


c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, tf.Status.Handle);

var a1 = tf.constant(1);
var b1 = tf.constant(2);

var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
"autograph_add",
new[] { a1, b1 },
null,
1);
graph.Dispose();
}
}
}*/

+ 54
- 0
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -0,0 +1,54 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs
{
/// <summary>
/// Graph representing a function body.
/// </summary>
public class FuncGraph : Graph
{
List<Operation> inputs;
List<Operation> outputs;
Graph outer_graph;
string func_name;
IntPtr func_handle;
public string FuncName => c_api.StringPiece(c_api.TF_FunctionName(func_handle));

/// <summary>
/// Construct a new FuncGraph.
/// </summary>
public FuncGraph(string name) : base()
{
outer_graph = ops.get_default_graph();
func_name = name;
}

public IntPtr ToGraph(Operation[] opers,
Operation[] inputs, Operation[] outputs,
string[] output_names)
{
using var status = new Status();
func_handle = c_api.TF_GraphToFunction(_handle,
func_name,
false,
opers.Length,
opers.Select(x => (IntPtr)x).ToArray(),
inputs.Length,
inputs.Select(x => new TF_Output(x, 0)).ToArray(),
outputs.Length,
outputs.Select(x => new TF_Output(x, 0)).ToArray(),
output_names == null || output_names.Length == 0 ? null : output_names,
IntPtr.Zero,
null,
status.Handle);

c_api.TF_GraphCopyFunction(outer_graph, func_handle, IntPtr.Zero, status.Handle);

return func_handle;
}
}
}

+ 25
- 2
src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -22,6 +22,7 @@ using Tensorflow.Operations.ControlFlows;
using util = Tensorflow.control_flow_util;
using static Tensorflow.Binding;
using Tensorflow.Util;
using System.Data;

namespace Tensorflow
{
@@ -420,14 +421,13 @@ namespace Tensorflow
public static Tensor cond(Tensor pred,
Func<ITensorOrOperation> true_fn = null,
Func<ITensorOrOperation> false_fn = null,
bool strict = false,
string name = null)
{
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate
{
if (tf.Context.executing_eagerly())
{
if (pred.ToArray<bool>()[0])
if ((bool)pred)
return true_fn() as Tensor;
else
return false_fn() as Tensor;
@@ -676,6 +676,29 @@ namespace Tensorflow
}
}

public static Tensor[] while_loop(Func<Tensor[], Tensor> cond,
Func<Tensor[], Tensor[]> body,
Tensor[] loop_vars,
int parallel_iterations = 10,
string name = null)
{
var executing_eagerly = tf.Context.executing_eagerly();
if (!executing_eagerly)
{
throw new NotImplementedException("");
}

return tf_with(ops.name_scope("name", "while"), delegate
{
while ((bool)cond(loop_vars))
{
loop_vars = body(loop_vars);
}

return loop_vars;
});
}

/// <summary>
/// Repeat `body` while the condition `cond` is true.
/// </summary>


+ 7
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -28,7 +28,7 @@ https://tensorflownet.readthedocs.io</Description>
<FileVersion>0.20.1.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>
<SignAssembly>false</SignAssembly>
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>
@@ -83,4 +83,10 @@ https://tensorflownet.readthedocs.io</Description>
<ItemGroup>
<Folder Include="Keras\Initializers\" />
</ItemGroup>

<ItemGroup>
<None Update="FodyWeavers.xml">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>
</Project>

+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -50,7 +50,7 @@ namespace Tensorflow
/// </summary>
public AllocationType AllocationType { get; protected set; }

public IntPtr TensorDataPointer => TF_TensorData(_handle);
public IntPtr TensorDataPointer => _handle == IntPtr.Zero ? IntPtr.Zero : TF_TensorData(_handle);

/// <summary>
/// Create a Tensor object from an existing TF handle


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_BOOL);
return *(bool*) tensor.buffer;
return *(bool*)tensor.buffer;
}
}



Loading…
Cancel
Save