@@ -6,6 +6,7 @@ TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tenso | |||
[](https://codecov.io/gh/SciSharp/NumSharp) | |||
[](https://www.nuget.org/packages/TensorFlow.NET) | |||
[](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) | |||
[](https://996.icu/#/en_US) | |||
TensorFlow.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). | |||
@@ -52,7 +52,7 @@ namespace Tensorflow.Clustering | |||
_num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i)[0]).ToArray()); | |||
} | |||
private Tensor[] _initialize() | |||
private Tensor _initialize() | |||
{ | |||
return with(ops.control_dependencies(new Operation[] | |||
{ | |||
@@ -63,24 +63,25 @@ namespace Tensorflow.Clustering | |||
return control_flow_ops.cond(math_ops.equal(num_now_remaining, 0), | |||
() => | |||
{ | |||
return new Tensor[] { state_ops.assign(_cluster_centers_initialized, true) }; | |||
return state_ops.assign(_cluster_centers_initialized, true); | |||
}, | |||
() => | |||
{ | |||
return new Tensor[] { control_flow_ops.no_op().output[0] }; | |||
return control_flow_ops.no_op().output[0]; | |||
}); | |||
}); | |||
} | |||
public Tensor[] op() | |||
public Tensor op() | |||
{ | |||
return control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0), | |||
var x = control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0), | |||
() => | |||
{ | |||
var op = check_ops.assert_equal(_cluster_centers_initialized, true); | |||
return new Tensor[] { op.output[0] }; | |||
return check_ops.assert_equal(_cluster_centers_initialized, true); | |||
}, | |||
_initialize); | |||
return x; | |||
} | |||
private Tensor _add_new_centers() | |||
@@ -93,7 +94,7 @@ namespace Tensorflow.Clustering | |||
// If cluster_centers is empty, it doesn't have the right shape for concat. | |||
var all_centers = control_flow_ops.cond(math_ops.equal(_num_selected, 0), | |||
() => new Tensor[] { new_centers }, | |||
() => new Tensor[] { gen_array_ops.concat(new Tensor[] { _cluster_centers, new_centers }, 0) }); | |||
() => new Tensor[] { array_ops.concat(new Tensor[] { _cluster_centers, new_centers }, 0) }); | |||
var a = state_ops.assign(_cluster_centers, all_centers, validate_shape: false); | |||
@@ -105,16 +106,16 @@ namespace Tensorflow.Clustering | |||
return _greedy_batch_sampler()[0]; | |||
} | |||
private Tensor[] _greedy_batch_sampler() | |||
private Tensor _greedy_batch_sampler() | |||
{ | |||
return control_flow_ops.cond(_num_data <= _num_remaining, | |||
() => | |||
{ | |||
return new Tensor[] { gen_array_ops.concat(_inputs, 0) }; | |||
return array_ops.concat(_inputs, 0); | |||
}, | |||
() => | |||
{ | |||
return new Tensor[] { _random() }; | |||
return _random(); | |||
}); | |||
} | |||
@@ -64,6 +64,25 @@ namespace Tensorflow.Operations | |||
} | |||
} | |||
public (T, Tensor) BuildCondBranch<T>(Func<T> fn) | |||
{ | |||
// Add the subgraph defined by fn() to the graph. | |||
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
var original_result = fn(); | |||
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
switch (original_result) | |||
{ | |||
case Operation[] results: | |||
return (original_result, _BuildCondTensor(results)); | |||
case float[] fv: | |||
var result = ops.convert_to_tensor(fv[0]); | |||
return (original_result, result ); | |||
default: | |||
return (original_result, null); | |||
} | |||
} | |||
public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn) | |||
{ | |||
// Add the subgraph defined by fn() to the graph. | |||
@@ -421,6 +421,26 @@ namespace Tensorflow | |||
public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y) | |||
=> Framework.common_shapes.broadcast_shape(shape_x, shape_y); | |||
/// <summary> | |||
/// Concatenates tensors along one dimension. | |||
/// </summary> | |||
/// <param name="values"></param> | |||
/// <param name="axis"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor concat(Tensor[] values, int axis, string name = "concat") | |||
{ | |||
if(values.Length == 1) // Degenerate case of one tensor. | |||
{ | |||
return with(ops.name_scope(name), scope => { | |||
var t = ops.convert_to_tensor(axis, name: "concat_dim", dtype: TF_DataType.TF_INT32); | |||
return identity(values[0], name = scope); | |||
}); | |||
} | |||
return gen_array_ops.concat_v2(values, axis, name: name); | |||
} | |||
public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0) | |||
=> gen_array_ops.gather_v2(@params, indices, axis, name: name); | |||
@@ -187,6 +187,53 @@ namespace Tensorflow | |||
return @switch(data, pred, name: name); | |||
} | |||
public static Tensor cond(Tensor pred, | |||
Func<ITensorOrOperation> true_fn = null, | |||
Func<ITensorOrOperation> false_fn = null, | |||
bool strict = false, | |||
string name = null) | |||
{ | |||
return with(ops.name_scope(name, "cond", new { pred }), delegate | |||
{ | |||
// Add the Switch to the graph. | |||
var (p_2, p_1) = @switch(pred, pred); | |||
var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | |||
var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | |||
pred = array_ops.identity(pred, name: "pred_id"); | |||
// Disable the fetching of tensors that are only on one branch of cond. | |||
foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) | |||
tensor.op.graph.prevent_fetching(tensor.op); | |||
// Build the graph for the true branch in a new context. | |||
var context_t = new CondContext(pred, pivot_1, branch: 1); | |||
context_t.Enter(); | |||
var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | |||
context_t.Exit(); | |||
// Build the graph for the false branch in a new context. | |||
var context_f = new CondContext(pred, pivot_2, branch: 0); | |||
context_f.Enter(); | |||
var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | |||
context_f.Exit(); | |||
var res_t_flat = res_t; | |||
var res_f_flat = res_f; | |||
return new Tensor(IntPtr.Zero); | |||
/*var merges = zip(res_f_flat, res_t_flat) | |||
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) | |||
.ToArray(); | |||
merges = _convert_flows_to_tensorarrays(orig_res_t, merges); | |||
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); | |||
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); | |||
return merges;*/ | |||
}); | |||
} | |||
public static Tensor[] cond<T>(Tensor pred, | |||
Func<T[]> true_fn = null, | |||
Func<T[]> false_fn = null, | |||
@@ -12,7 +12,14 @@ namespace Tensorflow | |||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
public static Execute _execute = new Execute(); | |||
public static Tensor concat(Tensor[] values, int axis, string name = null) | |||
/// <summary> | |||
/// Concatenates tensors along one dimension. | |||
/// </summary> | |||
/// <param name="values"></param> | |||
/// <param name="axis"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor concat_v2(Tensor[] values, int axis, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis }); | |||
@@ -40,6 +40,7 @@ namespace TensorFlowNET.Examples | |||
} | |||
catch (Exception ex) | |||
{ | |||
errors.Add($"Example {example.Priority}: {example.Name}"); | |||
Console.WriteLine(ex); | |||
} | |||