@@ -6,6 +6,7 @@ TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tenso | |||||
[](https://codecov.io/gh/SciSharp/NumSharp) | [](https://codecov.io/gh/SciSharp/NumSharp) | ||||
[](https://www.nuget.org/packages/TensorFlow.NET) | [](https://www.nuget.org/packages/TensorFlow.NET) | ||||
[](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) | [](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) | ||||
[](https://996.icu/#/en_US) | |||||
TensorFlow.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). | TensorFlow.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). | ||||
@@ -52,7 +52,7 @@ namespace Tensorflow.Clustering | |||||
_num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i)[0]).ToArray()); | _num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i)[0]).ToArray()); | ||||
} | } | ||||
private Tensor[] _initialize() | |||||
private Tensor _initialize() | |||||
{ | { | ||||
return with(ops.control_dependencies(new Operation[] | return with(ops.control_dependencies(new Operation[] | ||||
{ | { | ||||
@@ -63,24 +63,25 @@ namespace Tensorflow.Clustering | |||||
return control_flow_ops.cond(math_ops.equal(num_now_remaining, 0), | return control_flow_ops.cond(math_ops.equal(num_now_remaining, 0), | ||||
() => | () => | ||||
{ | { | ||||
return new Tensor[] { state_ops.assign(_cluster_centers_initialized, true) }; | |||||
return state_ops.assign(_cluster_centers_initialized, true); | |||||
}, | }, | ||||
() => | () => | ||||
{ | { | ||||
return new Tensor[] { control_flow_ops.no_op().output[0] }; | |||||
return control_flow_ops.no_op().output[0]; | |||||
}); | }); | ||||
}); | }); | ||||
} | } | ||||
public Tensor[] op() | |||||
public Tensor op() | |||||
{ | { | ||||
return control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0), | |||||
var x = control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0), | |||||
() => | () => | ||||
{ | { | ||||
var op = check_ops.assert_equal(_cluster_centers_initialized, true); | |||||
return new Tensor[] { op.output[0] }; | |||||
return check_ops.assert_equal(_cluster_centers_initialized, true); | |||||
}, | }, | ||||
_initialize); | _initialize); | ||||
return x; | |||||
} | } | ||||
private Tensor _add_new_centers() | private Tensor _add_new_centers() | ||||
@@ -93,7 +94,7 @@ namespace Tensorflow.Clustering | |||||
// If cluster_centers is empty, it doesn't have the right shape for concat. | // If cluster_centers is empty, it doesn't have the right shape for concat. | ||||
var all_centers = control_flow_ops.cond(math_ops.equal(_num_selected, 0), | var all_centers = control_flow_ops.cond(math_ops.equal(_num_selected, 0), | ||||
() => new Tensor[] { new_centers }, | () => new Tensor[] { new_centers }, | ||||
() => new Tensor[] { gen_array_ops.concat(new Tensor[] { _cluster_centers, new_centers }, 0) }); | |||||
() => new Tensor[] { array_ops.concat(new Tensor[] { _cluster_centers, new_centers }, 0) }); | |||||
var a = state_ops.assign(_cluster_centers, all_centers, validate_shape: false); | var a = state_ops.assign(_cluster_centers, all_centers, validate_shape: false); | ||||
@@ -105,16 +106,16 @@ namespace Tensorflow.Clustering | |||||
return _greedy_batch_sampler()[0]; | return _greedy_batch_sampler()[0]; | ||||
} | } | ||||
private Tensor[] _greedy_batch_sampler() | |||||
private Tensor _greedy_batch_sampler() | |||||
{ | { | ||||
return control_flow_ops.cond(_num_data <= _num_remaining, | return control_flow_ops.cond(_num_data <= _num_remaining, | ||||
() => | () => | ||||
{ | { | ||||
return new Tensor[] { gen_array_ops.concat(_inputs, 0) }; | |||||
return array_ops.concat(_inputs, 0); | |||||
}, | }, | ||||
() => | () => | ||||
{ | { | ||||
return new Tensor[] { _random() }; | |||||
return _random(); | |||||
}); | }); | ||||
} | } | ||||
@@ -64,6 +64,25 @@ namespace Tensorflow.Operations | |||||
} | } | ||||
} | } | ||||
public (T, Tensor) BuildCondBranch<T>(Func<T> fn) | |||||
{ | |||||
// Add the subgraph defined by fn() to the graph. | |||||
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
var original_result = fn(); | |||||
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
switch (original_result) | |||||
{ | |||||
case Operation[] results: | |||||
return (original_result, _BuildCondTensor(results)); | |||||
case float[] fv: | |||||
var result = ops.convert_to_tensor(fv[0]); | |||||
return (original_result, result ); | |||||
default: | |||||
return (original_result, null); | |||||
} | |||||
} | |||||
public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn) | public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn) | ||||
{ | { | ||||
// Add the subgraph defined by fn() to the graph. | // Add the subgraph defined by fn() to the graph. | ||||
@@ -421,6 +421,26 @@ namespace Tensorflow | |||||
public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y) | public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y) | ||||
=> Framework.common_shapes.broadcast_shape(shape_x, shape_y); | => Framework.common_shapes.broadcast_shape(shape_x, shape_y); | ||||
/// <summary> | |||||
/// Concatenates tensors along one dimension. | |||||
/// </summary> | |||||
/// <param name="values"></param> | |||||
/// <param name="axis"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor concat(Tensor[] values, int axis, string name = "concat") | |||||
{ | |||||
if(values.Length == 1) // Degenerate case of one tensor. | |||||
{ | |||||
return with(ops.name_scope(name), scope => { | |||||
var t = ops.convert_to_tensor(axis, name: "concat_dim", dtype: TF_DataType.TF_INT32); | |||||
return identity(values[0], name = scope); | |||||
}); | |||||
} | |||||
return gen_array_ops.concat_v2(values, axis, name: name); | |||||
} | |||||
public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0) | public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0) | ||||
=> gen_array_ops.gather_v2(@params, indices, axis, name: name); | => gen_array_ops.gather_v2(@params, indices, axis, name: name); | ||||
@@ -187,6 +187,53 @@ namespace Tensorflow | |||||
return @switch(data, pred, name: name); | return @switch(data, pred, name: name); | ||||
} | } | ||||
public static Tensor cond(Tensor pred, | |||||
Func<ITensorOrOperation> true_fn = null, | |||||
Func<ITensorOrOperation> false_fn = null, | |||||
bool strict = false, | |||||
string name = null) | |||||
{ | |||||
return with(ops.name_scope(name, "cond", new { pred }), delegate | |||||
{ | |||||
// Add the Switch to the graph. | |||||
var (p_2, p_1) = @switch(pred, pred); | |||||
var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | |||||
var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | |||||
pred = array_ops.identity(pred, name: "pred_id"); | |||||
// Disable the fetching of tensors that are only on one branch of cond. | |||||
foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) | |||||
tensor.op.graph.prevent_fetching(tensor.op); | |||||
// Build the graph for the true branch in a new context. | |||||
var context_t = new CondContext(pred, pivot_1, branch: 1); | |||||
context_t.Enter(); | |||||
var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | |||||
context_t.Exit(); | |||||
// Build the graph for the false branch in a new context. | |||||
var context_f = new CondContext(pred, pivot_2, branch: 0); | |||||
context_f.Enter(); | |||||
var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | |||||
context_f.Exit(); | |||||
var res_t_flat = res_t; | |||||
var res_f_flat = res_f; | |||||
return new Tensor(IntPtr.Zero); | |||||
/*var merges = zip(res_f_flat, res_t_flat) | |||||
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) | |||||
.ToArray(); | |||||
merges = _convert_flows_to_tensorarrays(orig_res_t, merges); | |||||
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t); | |||||
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f); | |||||
return merges;*/ | |||||
}); | |||||
} | |||||
public static Tensor[] cond<T>(Tensor pred, | public static Tensor[] cond<T>(Tensor pred, | ||||
Func<T[]> true_fn = null, | Func<T[]> true_fn = null, | ||||
Func<T[]> false_fn = null, | Func<T[]> false_fn = null, | ||||
@@ -12,7 +12,14 @@ namespace Tensorflow | |||||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | ||||
public static Execute _execute = new Execute(); | public static Execute _execute = new Execute(); | ||||
public static Tensor concat(Tensor[] values, int axis, string name = null) | |||||
/// <summary> | |||||
/// Concatenates tensors along one dimension. | |||||
/// </summary> | |||||
/// <param name="values"></param> | |||||
/// <param name="axis"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor concat_v2(Tensor[] values, int axis, string name = null) | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis }); | var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis }); | ||||
@@ -40,6 +40,7 @@ namespace TensorFlowNET.Examples | |||||
} | } | ||||
catch (Exception ex) | catch (Exception ex) | ||||
{ | { | ||||
errors.Add($"Example {example.Priority}: {example.Name}"); | |||||
Console.WriteLine(ex); | Console.WriteLine(ex); | ||||
} | } | ||||