From 694f99cf362d55a2ee591982ab3c0ddf6e819d78 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 8 Apr 2019 20:23:38 -0500 Subject: [PATCH] overload the cond function --- README.md | 1 + .../_InitializeClustersOpFactory.cs | 23 ++++----- .../Operations/ControlFlows/CondContext.cs | 19 ++++++++ .../Operations/array_ops.py.cs | 20 ++++++++ .../Operations/control_flow_ops.py.cs | 47 +++++++++++++++++++ .../Operations/gen_array_ops.cs | 9 +++- test/TensorFlowNET.Examples/Program.cs | 1 + 7 files changed, 108 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index da0bf7f6..bdd09070 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tenso [![codecov](https://codecov.io/gh/SciSharp/NumSharp/branch/master/graph/badge.svg)](https://codecov.io/gh/SciSharp/NumSharp) [![NuGet](https://img.shields.io/nuget/dt/TensorFlow.NET.svg)](https://www.nuget.org/packages/TensorFlow.NET) [![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) +[![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US) TensorFlow.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). diff --git a/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs b/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs index b1167e3e..1cfb85a1 100644 --- a/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs +++ b/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs @@ -52,7 +52,7 @@ namespace Tensorflow.Clustering _num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i)[0]).ToArray()); } - private Tensor[] _initialize() + private Tensor _initialize() { return with(ops.control_dependencies(new Operation[] { @@ -63,24 +63,25 @@ namespace Tensorflow.Clustering return control_flow_ops.cond(math_ops.equal(num_now_remaining, 0), () => { - return new Tensor[] { state_ops.assign(_cluster_centers_initialized, true) }; + return state_ops.assign(_cluster_centers_initialized, true); }, () => { - return new Tensor[] { control_flow_ops.no_op().output[0] }; + return control_flow_ops.no_op().output[0]; }); }); } - public Tensor[] op() + public Tensor op() { - return control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0), + var x = control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0), () => { - var op = check_ops.assert_equal(_cluster_centers_initialized, true); - return new Tensor[] { op.output[0] }; + return check_ops.assert_equal(_cluster_centers_initialized, true); }, _initialize); + + return x; } private Tensor _add_new_centers() @@ -93,7 +94,7 @@ namespace Tensorflow.Clustering // If cluster_centers is empty, it doesn't have the right shape for concat. var all_centers = control_flow_ops.cond(math_ops.equal(_num_selected, 0), () => new Tensor[] { new_centers }, - () => new Tensor[] { gen_array_ops.concat(new Tensor[] { _cluster_centers, new_centers }, 0) }); + () => new Tensor[] { array_ops.concat(new Tensor[] { _cluster_centers, new_centers }, 0) }); var a = state_ops.assign(_cluster_centers, all_centers, validate_shape: false); @@ -105,16 +106,16 @@ namespace Tensorflow.Clustering return _greedy_batch_sampler()[0]; } - private Tensor[] _greedy_batch_sampler() + private Tensor _greedy_batch_sampler() { return control_flow_ops.cond(_num_data <= _num_remaining, () => { - return new Tensor[] { gen_array_ops.concat(_inputs, 0) }; + return array_ops.concat(_inputs, 0); }, () => { - return new Tensor[] { _random() }; + return _random(); }); } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index 15d6f2e1..8ed46036 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -64,6 +64,25 @@ namespace Tensorflow.Operations } } + public (T, Tensor) BuildCondBranch(Func fn) + { + // Add the subgraph defined by fn() to the graph. + var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + var original_result = fn(); + var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + + switch (original_result) + { + case Operation[] results: + return (original_result, _BuildCondTensor(results)); + case float[] fv: + var result = ops.convert_to_tensor(fv[0]); + return (original_result, result ); + default: + return (original_result, null); + } + } + public (T[], Tensor[]) BuildCondBranch(Func fn) { // Add the subgraph defined by fn() to the graph. diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index e5f68573..5b6af742 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -421,6 +421,26 @@ namespace Tensorflow public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y) => Framework.common_shapes.broadcast_shape(shape_x, shape_y); + /// + /// Concatenates tensors along one dimension. + /// + /// + /// + /// + /// + public static Tensor concat(Tensor[] values, int axis, string name = "concat") + { + if(values.Length == 1) // Degenerate case of one tensor. + { + return with(ops.name_scope(name), scope => { + var t = ops.convert_to_tensor(axis, name: "concat_dim", dtype: TF_DataType.TF_INT32); + return identity(values[0], name = scope); + }); + } + + return gen_array_ops.concat_v2(values, axis, name: name); + } + public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0) => gen_array_ops.gather_v2(@params, indices, axis, name: name); diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index a8e25c08..db015443 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -187,6 +187,53 @@ namespace Tensorflow return @switch(data, pred, name: name); } + public static Tensor cond(Tensor pred, + Func true_fn = null, + Func false_fn = null, + bool strict = false, + string name = null) + { + return with(ops.name_scope(name, "cond", new { pred }), delegate + { + // Add the Switch to the graph. + var (p_2, p_1) = @switch(pred, pred); + var pivot_1 = array_ops.identity(p_1, name: "switch_t"); + var pivot_2 = array_ops.identity(p_2, name: "switch_f"); + pred = array_ops.identity(pred, name: "pred_id"); + + // Disable the fetching of tensors that are only on one branch of cond. + foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) + tensor.op.graph.prevent_fetching(tensor.op); + + // Build the graph for the true branch in a new context. + var context_t = new CondContext(pred, pivot_1, branch: 1); + context_t.Enter(); + var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); + context_t.Exit(); + + // Build the graph for the false branch in a new context. + var context_f = new CondContext(pred, pivot_2, branch: 0); + context_f.Enter(); + var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); + context_f.Exit(); + + var res_t_flat = res_t; + var res_f_flat = res_f; + + return new Tensor(IntPtr.Zero); + /*var merges = zip(res_f_flat, res_t_flat) + .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) + .ToArray(); + + merges = _convert_flows_to_tensorarrays(orig_res_t, merges); + + ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); + ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); + + return merges;*/ + }); + } + public static Tensor[] cond(Tensor pred, Func true_fn = null, Func false_fn = null, diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 56b47ed6..07525318 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -12,7 +12,14 @@ namespace Tensorflow public static OpDefLibrary _op_def_lib = new OpDefLibrary(); public static Execute _execute = new Execute(); - public static Tensor concat(Tensor[] values, int axis, string name = null) + /// + /// Concatenates tensors along one dimension. + /// + /// + /// + /// + /// + public static Tensor concat_v2(Tensor[] values, int axis, string name = null) { var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis }); diff --git a/test/TensorFlowNET.Examples/Program.cs b/test/TensorFlowNET.Examples/Program.cs index 44448caf..69fd1592 100644 --- a/test/TensorFlowNET.Examples/Program.cs +++ b/test/TensorFlowNET.Examples/Program.cs @@ -40,6 +40,7 @@ namespace TensorFlowNET.Examples } catch (Exception ex) { + errors.Add($"Example {example.Priority}: {example.Name}"); Console.WriteLine(ex); }