Browse Source

overload the cond function

tags/v0.9
Oceania2018 6 years ago
parent
commit
694f99cf36
7 changed files with 108 additions and 12 deletions
  1. +1
    -0
      README.md
  2. +12
    -11
      src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs
  3. +19
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  4. +20
    -0
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  5. +47
    -0
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  6. +8
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  7. +1
    -0
      test/TensorFlowNET.Examples/Program.cs

+ 1
- 0
README.md View File

@@ -6,6 +6,7 @@ TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tenso
[![codecov](https://codecov.io/gh/SciSharp/NumSharp/branch/master/graph/badge.svg)](https://codecov.io/gh/SciSharp/NumSharp)
[![NuGet](https://img.shields.io/nuget/dt/TensorFlow.NET.svg)](https://www.nuget.org/packages/TensorFlow.NET)
[![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest)
[![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US)

TensorFlow.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).



+ 12
- 11
src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs View File

@@ -52,7 +52,7 @@ namespace Tensorflow.Clustering
_num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i)[0]).ToArray());
}

private Tensor[] _initialize()
private Tensor _initialize()
{
return with(ops.control_dependencies(new Operation[]
{
@@ -63,24 +63,25 @@ namespace Tensorflow.Clustering
return control_flow_ops.cond(math_ops.equal(num_now_remaining, 0),
() =>
{
return new Tensor[] { state_ops.assign(_cluster_centers_initialized, true) };
return state_ops.assign(_cluster_centers_initialized, true);
},
() =>
{
return new Tensor[] { control_flow_ops.no_op().output[0] };
return control_flow_ops.no_op().output[0];
});
});
}

public Tensor[] op()
public Tensor op()
{
return control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0),
var x = control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0),
() =>
{
var op = check_ops.assert_equal(_cluster_centers_initialized, true);
return new Tensor[] { op.output[0] };
return check_ops.assert_equal(_cluster_centers_initialized, true);
},
_initialize);

return x;
}

private Tensor _add_new_centers()
@@ -93,7 +94,7 @@ namespace Tensorflow.Clustering
// If cluster_centers is empty, it doesn't have the right shape for concat.
var all_centers = control_flow_ops.cond(math_ops.equal(_num_selected, 0),
() => new Tensor[] { new_centers },
() => new Tensor[] { gen_array_ops.concat(new Tensor[] { _cluster_centers, new_centers }, 0) });
() => new Tensor[] { array_ops.concat(new Tensor[] { _cluster_centers, new_centers }, 0) });

var a = state_ops.assign(_cluster_centers, all_centers, validate_shape: false);

@@ -105,16 +106,16 @@ namespace Tensorflow.Clustering
return _greedy_batch_sampler()[0];
}

private Tensor[] _greedy_batch_sampler()
private Tensor _greedy_batch_sampler()
{
return control_flow_ops.cond(_num_data <= _num_remaining,
() =>
{
return new Tensor[] { gen_array_ops.concat(_inputs, 0) };
return array_ops.concat(_inputs, 0);
},
() =>
{
return new Tensor[] { _random() };
return _random();
});
}



+ 19
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -64,6 +64,25 @@ namespace Tensorflow.Operations
}
}

public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
{
// Add the subgraph defined by fn() to the graph.
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);

switch (original_result)
{
case Operation[] results:
return (original_result, _BuildCondTensor(results));
case float[] fv:
var result = ops.convert_to_tensor(fv[0]);
return (original_result, result );
default:
return (original_result, null);
}
}

public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn)
{
// Add the subgraph defined by fn() to the graph.


+ 20
- 0
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -421,6 +421,26 @@ namespace Tensorflow
public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y)
=> Framework.common_shapes.broadcast_shape(shape_x, shape_y);
/// <summary>
/// Concatenates tensors along one dimension.
/// </summary>
/// <param name="values"></param>
/// <param name="axis"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor concat(Tensor[] values, int axis, string name = "concat")
{
if(values.Length == 1) // Degenerate case of one tensor.
{
return with(ops.name_scope(name), scope => {
var t = ops.convert_to_tensor(axis, name: "concat_dim", dtype: TF_DataType.TF_INT32);
return identity(values[0], name = scope);
});
}
return gen_array_ops.concat_v2(values, axis, name: name);
}
public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
=> gen_array_ops.gather_v2(@params, indices, axis, name: name);


+ 47
- 0
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -187,6 +187,53 @@ namespace Tensorflow
return @switch(data, pred, name: name);
}

public static Tensor cond(Tensor pred,
Func<ITensorOrOperation> true_fn = null,
Func<ITensorOrOperation> false_fn = null,
bool strict = false,
string name = null)
{
return with(ops.name_scope(name, "cond", new { pred }), delegate
{
// Add the Switch to the graph.
var (p_2, p_1) = @switch(pred, pred);
var pivot_1 = array_ops.identity(p_1, name: "switch_t");
var pivot_2 = array_ops.identity(p_2, name: "switch_f");
pred = array_ops.identity(pred, name: "pred_id");

// Disable the fetching of tensors that are only on one branch of cond.
foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
tensor.op.graph.prevent_fetching(tensor.op);

// Build the graph for the true branch in a new context.
var context_t = new CondContext(pred, pivot_1, branch: 1);
context_t.Enter();
var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
context_t.Exit();

// Build the graph for the false branch in a new context.
var context_f = new CondContext(pred, pivot_2, branch: 0);
context_f.Enter();
var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
context_f.Exit();

var res_t_flat = res_t;
var res_f_flat = res_f;

return new Tensor(IntPtr.Zero);
/*var merges = zip(res_f_flat, res_t_flat)
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
.ToArray();

merges = _convert_flows_to_tensorarrays(orig_res_t, merges);

ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f);

return merges;*/
});
}

public static Tensor[] cond<T>(Tensor pred,
Func<T[]> true_fn = null,
Func<T[]> false_fn = null,


+ 8
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -12,7 +12,14 @@ namespace Tensorflow
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
public static Execute _execute = new Execute();

public static Tensor concat(Tensor[] values, int axis, string name = null)
/// <summary>
/// Concatenates tensors along one dimension.
/// </summary>
/// <param name="values"></param>
/// <param name="axis"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor concat_v2(Tensor[] values, int axis, string name = null)
{
var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis });



+ 1
- 0
test/TensorFlowNET.Examples/Program.cs View File

@@ -40,6 +40,7 @@ namespace TensorFlowNET.Examples
}
catch (Exception ex)
{
errors.Add($"Example {example.Priority}: {example.Name}");
Console.WriteLine(ex);
}



Loading…
Cancel
Save