@@ -1,4 +1,4 @@ | |||
<Project Sdk="Microsoft.NET.Sdk"> | |||
<Project Sdk="Microsoft.NET.Sdk"> | |||
<PropertyGroup> | |||
<OutputType>Exe</OutputType> | |||
@@ -0,0 +1,26 @@ | |||
/***************************************************************************** | |||
Copyright 2020 Haiping Chen. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Operations; | |||
namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
public AutoGraph autograph = new AutoGraph(); | |||
} | |||
} |
@@ -15,17 +15,22 @@ | |||
******************************************************************************/ | |||
using System; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
public Tensor cond(Tensor pred, | |||
Tensor true_value, | |||
Tensor false_false) | |||
=> control_flow_ops.cond(pred, () => true_value, () => false_false); | |||
public Tensor cond(Tensor pred, | |||
Func<ITensorOrOperation> true_fn = null, | |||
Func<ITensorOrOperation> false_fn = null, | |||
bool strict = false, | |||
string name = null) | |||
=> control_flow_ops.cond(pred, true_fn, false_fn, strict: strict, name: name); | |||
=> control_flow_ops.cond(pred, true_fn, false_fn, name: name); | |||
/// <summary> | |||
/// Create an op that groups multiple operations. | |||
@@ -37,22 +42,31 @@ namespace Tensorflow | |||
public Operation group<T>(T[] inputs, string name = null) where T : ITensorOrOperation | |||
=> control_flow_ops.group(inputs, name: name); | |||
/*public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||
TensorShape shape_invariants = null, | |||
public Tensor while_loop(Func<Tensor, Tensor> cond, | |||
Func<Tensor, Tensor> body, | |||
Tensor loop_vars, | |||
int parallel_iterations = 10) | |||
{ | |||
Func<Tensor[], Tensor> cond1 = x | |||
=> cond(x[0]); | |||
Func<Tensor[], Tensor[]> body1 = x | |||
=> new[] { body(x[0]) }; | |||
var results = control_flow_ops.while_loop(cond1, | |||
body1, | |||
new[] { loop_vars }); | |||
return results[0]; | |||
} | |||
public Tensor[] while_loop(Func<Tensor[], Tensor> cond, | |||
Func<Tensor[], Tensor[]> body, | |||
Tensor[] loop_vars, | |||
int parallel_iterations = 10, | |||
bool back_prop = true, | |||
bool swap_memory = false, | |||
string name = null, | |||
int? maximum_iterations = null, | |||
bool return_same_structure = false) | |||
string name = null) | |||
=> control_flow_ops.while_loop(cond, body, loop_vars, | |||
shape_invariants: shape_invariants, | |||
parallel_iterations: parallel_iterations, | |||
back_prop: back_prop, | |||
swap_memory: swap_memory, | |||
name: name, | |||
maximum_iterations: maximum_iterations, | |||
return_same_structure: return_same_structure);*/ | |||
name: name); | |||
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | |||
=> ops.control_dependencies(control_inputs); | |||
@@ -78,6 +78,37 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status); | |||
/// <summary> | |||
/// Adds a function (created from TF_GraphToFunction or | |||
/// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with | |||
/// TFE_Execute by creating an op with the same name as the function. | |||
/// </summary> | |||
/// <param name="ctx"></param> | |||
/// <param name="function"></param> | |||
/// <param name="status"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, IntPtr function, SafeStatusHandle status); | |||
/// <summary> | |||
/// Removes a function from the context. Once removed, you can no longer | |||
/// TFE_Execute it or TFE_Execute any TFE_Op which has it as an attribute or any | |||
/// other function which calls it as an attribute. | |||
/// </summary> | |||
/// <param name="ctx"></param> | |||
/// <param name="name"></param> | |||
/// <param name="status"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_ContextRemoveFunction(SafeContextHandle ctx, string name, SafeStatusHandle status); | |||
/// <summary> | |||
/// Checks whether a function is registered under `name`. | |||
/// </summary> | |||
/// <param name="ctx"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern bool TFE_ContextHasFunction(SafeContextHandle ctx, string name); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_ContextStartStep(SafeContextHandle ctx); | |||
@@ -39,7 +39,7 @@ namespace Tensorflow | |||
int num_opers, IntPtr[] opers, | |||
int ninputs, TF_Output[] inputs, | |||
int noutputs, TF_Output[] outputs, | |||
IntPtr output_names, | |||
string[] output_names, | |||
IntPtr opts, | |||
string description, | |||
SafeStatusHandle status); | |||
@@ -0,0 +1,47 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Linq.Expressions; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Graphs | |||
{ | |||
public class AutoGraph | |||
{ | |||
public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func) | |||
{ | |||
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||
tf.compat.v1.disable_eager_execution(); | |||
// IntPtr func_handle; | |||
using(var graph = new FuncGraph(func_name)) | |||
{ | |||
graph.as_default(); | |||
var input1 = tf.placeholder(tf.int32); | |||
var input2 = tf.placeholder(tf.int32); | |||
var output = func(input1, input2); | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
var func_handle = graph.ToGraph(opers, | |||
new Operation[] { input1, input2 }, | |||
new Operation[] { output }, | |||
null); | |||
c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, tf.Status.Handle); | |||
} | |||
tf.enable_eager_execution(); | |||
return (Tensor a, Tensor b) => | |||
{ | |||
var result = tf.Runner.TFE_Execute(tf.Context, | |||
tf.Context.DeviceName, | |||
func_name, | |||
new[] { a, b }, | |||
null, | |||
1); | |||
return result[0]; | |||
}; | |||
} | |||
} | |||
} |
@@ -0,0 +1,56 @@ | |||
/*using MethodBoundaryAspect.Fody.Attributes; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Graphs | |||
{ | |||
public sealed class AutoGraphAspect : OnMethodBoundaryAspect | |||
{ | |||
FuncGraph graph; | |||
IntPtr func_handle; | |||
public override void OnEntry(MethodExecutionArgs args) | |||
{ | |||
tf.compat.v1.disable_eager_execution(); | |||
// convert args to placeholder | |||
for (var i = 0; i < args.Arguments.Length; i++) | |||
{ | |||
if (args.Arguments[i] is EagerTensor tensor) | |||
args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape); | |||
} | |||
// make function as an Operation by autograph | |||
graph = new FuncGraph("autograph_add"); | |||
graph.as_default(); | |||
} | |||
public override void OnExit(MethodExecutionArgs args) | |||
{ | |||
var output = (Tensor)args.Method.Invoke(args.Instance, args.Arguments); | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
func_handle = graph.ToGraph(opers, | |||
new Operation[] { }, | |||
new Operation[] { }, | |||
null); | |||
c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, tf.Status.Handle); | |||
var a1 = tf.constant(1); | |||
var b1 = tf.constant(2); | |||
var result = tf.Runner.TFE_Execute(tf.Context, | |||
tf.Context.DeviceName, | |||
"autograph_add", | |||
new[] { a1, b1 }, | |||
null, | |||
1); | |||
graph.Dispose(); | |||
} | |||
} | |||
}*/ |
@@ -0,0 +1,54 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Graphs | |||
{ | |||
/// <summary> | |||
/// Graph representing a function body. | |||
/// </summary> | |||
public class FuncGraph : Graph | |||
{ | |||
List<Operation> inputs; | |||
List<Operation> outputs; | |||
Graph outer_graph; | |||
string func_name; | |||
IntPtr func_handle; | |||
public string FuncName => c_api.StringPiece(c_api.TF_FunctionName(func_handle)); | |||
/// <summary> | |||
/// Construct a new FuncGraph. | |||
/// </summary> | |||
public FuncGraph(string name) : base() | |||
{ | |||
outer_graph = ops.get_default_graph(); | |||
func_name = name; | |||
} | |||
public IntPtr ToGraph(Operation[] opers, | |||
Operation[] inputs, Operation[] outputs, | |||
string[] output_names) | |||
{ | |||
using var status = new Status(); | |||
func_handle = c_api.TF_GraphToFunction(_handle, | |||
func_name, | |||
false, | |||
opers.Length, | |||
opers.Select(x => (IntPtr)x).ToArray(), | |||
inputs.Length, | |||
inputs.Select(x => new TF_Output(x, 0)).ToArray(), | |||
outputs.Length, | |||
outputs.Select(x => new TF_Output(x, 0)).ToArray(), | |||
output_names == null || output_names.Length == 0 ? null : output_names, | |||
IntPtr.Zero, | |||
null, | |||
status.Handle); | |||
c_api.TF_GraphCopyFunction(outer_graph, func_handle, IntPtr.Zero, status.Handle); | |||
return func_handle; | |||
} | |||
} | |||
} |
@@ -22,6 +22,7 @@ using Tensorflow.Operations.ControlFlows; | |||
using util = Tensorflow.control_flow_util; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Util; | |||
using System.Data; | |||
namespace Tensorflow | |||
{ | |||
@@ -420,14 +421,13 @@ namespace Tensorflow | |||
public static Tensor cond(Tensor pred, | |||
Func<ITensorOrOperation> true_fn = null, | |||
Func<ITensorOrOperation> false_fn = null, | |||
bool strict = false, | |||
string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
if (pred.ToArray<bool>()[0]) | |||
if ((bool)pred) | |||
return true_fn() as Tensor; | |||
else | |||
return false_fn() as Tensor; | |||
@@ -676,6 +676,29 @@ namespace Tensorflow | |||
} | |||
} | |||
public static Tensor[] while_loop(Func<Tensor[], Tensor> cond, | |||
Func<Tensor[], Tensor[]> body, | |||
Tensor[] loop_vars, | |||
int parallel_iterations = 10, | |||
string name = null) | |||
{ | |||
var executing_eagerly = tf.Context.executing_eagerly(); | |||
if (!executing_eagerly) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
return tf_with(ops.name_scope("name", "while"), delegate | |||
{ | |||
while ((bool)cond(loop_vars)) | |||
{ | |||
loop_vars = body(loop_vars); | |||
} | |||
return loop_vars; | |||
}); | |||
} | |||
/// <summary> | |||
/// Repeat `body` while the condition `cond` is true. | |||
/// </summary> | |||
@@ -28,7 +28,7 @@ https://tensorflownet.readthedocs.io</Description> | |||
<FileVersion>0.20.1.0</FileVersion> | |||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
<SignAssembly>true</SignAssembly> | |||
<SignAssembly>false</SignAssembly> | |||
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
</PropertyGroup> | |||
@@ -83,4 +83,10 @@ https://tensorflownet.readthedocs.io</Description> | |||
<ItemGroup> | |||
<Folder Include="Keras\Initializers\" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<None Update="FodyWeavers.xml"> | |||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
</None> | |||
</ItemGroup> | |||
</Project> |
@@ -50,7 +50,7 @@ namespace Tensorflow | |||
/// </summary> | |||
public AllocationType AllocationType { get; protected set; } | |||
public IntPtr TensorDataPointer => TF_TensorData(_handle); | |||
public IntPtr TensorDataPointer => _handle == IntPtr.Zero ? IntPtr.Zero : TF_TensorData(_handle); | |||
/// <summary> | |||
/// Create a Tensor object from an existing TF handle | |||
@@ -11,7 +11,7 @@ namespace Tensorflow | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_BOOL); | |||
return *(bool*) tensor.buffer; | |||
return *(bool*)tensor.buffer; | |||
} | |||
} | |||