@@ -1,5 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using System.Text; | using System.Text; | ||||
using NumSharp; | using NumSharp; | ||||
using Tensorflow; | using Tensorflow; | ||||
@@ -21,7 +22,12 @@ namespace Tensorflow.Hub | |||||
images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | ||||
images.astype(dataType); | images.astype(dataType); | ||||
// for debug np.multiply performance | |||||
var sw = new Stopwatch(); | |||||
sw.Start(); | |||||
images = np.multiply(images, 1.0f / 255.0f); | images = np.multiply(images, 1.0f / 255.0f); | ||||
sw.Stop(); | |||||
Console.WriteLine($"{sw.ElapsedMilliseconds}ms"); | |||||
Data = images; | Data = images; | ||||
labels.astype(dataType); | labels.astype(dataType); | ||||
@@ -14,10 +14,29 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static partial class tf | public static partial class tf | ||||
{ | { | ||||
public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||||
TensorShape shape_invariants = null, | |||||
int parallel_iterations = 10, | |||||
bool back_prop = true, | |||||
bool swap_memory = false, | |||||
string name = null, | |||||
int? maximum_iterations = null, | |||||
bool return_same_structure = false) | |||||
=> control_flow_ops.while_loop(cond, body, loop_vars, | |||||
shape_invariants: shape_invariants, | |||||
parallel_iterations: parallel_iterations, | |||||
back_prop: back_prop, | |||||
swap_memory: swap_memory, | |||||
name: name, | |||||
maximum_iterations: maximum_iterations, | |||||
return_same_structure: return_same_structure); | |||||
public static _ControlDependenciesController control_dependencies(Operation[] control_inputs) | public static _ControlDependenciesController control_dependencies(Operation[] control_inputs) | ||||
=> ops.control_dependencies(control_inputs); | => ops.control_dependencies(control_inputs); | ||||
} | } | ||||
@@ -39,8 +39,8 @@ namespace Tensorflow | |||||
public static Tensor asin(Tensor x, string name = null) | public static Tensor asin(Tensor x, string name = null) | ||||
=> gen_math_ops.asin(x, name); | => gen_math_ops.asin(x, name); | ||||
public static Tensor add<Tx, Ty>(Tx a, Ty b) | |||||
=> gen_math_ops.add(a, b); | |||||
public static Tensor add<Tx, Ty>(Tx a, Ty b, string name = null) | |||||
=> gen_math_ops.add(a, b, name: name); | |||||
/// <summary> | /// <summary> | ||||
/// Computes atan of x element-wise. | /// Computes atan of x element-wise. | ||||
@@ -33,7 +33,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
/// <param name="input_ops">The data input ops for an op to be created.</param> | /// <param name="input_ops">The data input ops for an op to be created.</param> | ||||
/// <returns>A list of control inputs for the op to be created.</returns> | /// <returns>A list of control inputs for the op to be created.</returns> | ||||
private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) | |||||
public ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) | |||||
{ | { | ||||
var ret = new List<ITensorOrOperation>(); | var ret = new List<ITensorOrOperation>(); | ||||
@@ -53,6 +53,11 @@ namespace Tensorflow.Operations | |||||
protected Stack<ControlFlowContext> _context_stack; | protected Stack<ControlFlowContext> _context_stack; | ||||
protected ControlFlowContext _outer_context; | protected ControlFlowContext _outer_context; | ||||
/// <summary> | |||||
/// The keys are the names of tensors referenced by but external to this | |||||
/// context. Each value is the Tensor that should be used by this context to | |||||
/// access the key value (e.g. a switch output guarding a cond input value). | |||||
/// </summary> | |||||
protected Dictionary<string, ITensorOrOperation> _external_values; | protected Dictionary<string, ITensorOrOperation> _external_values; | ||||
public ControlFlowContext() | public ControlFlowContext() | ||||
@@ -68,6 +73,12 @@ namespace Tensorflow.Operations | |||||
_outer_context = ops.get_default_graph()._get_control_flow_context(); | _outer_context = ops.get_default_graph()._get_control_flow_context(); | ||||
if (values_def != null) | if (values_def != null) | ||||
_init_values_from_proto(values_def, import_scope: import_scope); | _init_values_from_proto(values_def, import_scope: import_scope); | ||||
else | |||||
{ | |||||
_values = new HashSet<string>(); | |||||
_external_values = new Dictionary<string, ITensorOrOperation>(); | |||||
} | |||||
} | } | ||||
public void __enter__() | public void __enter__() | ||||
@@ -114,6 +125,27 @@ namespace Tensorflow.Operations | |||||
graph._set_control_flow_context(this); | graph._set_control_flow_context(this); | ||||
} | } | ||||
protected virtual Tensor _Enter(Tensor data, string frame_name, | |||||
bool is_constant = false, | |||||
int parallel_iterations = 10, | |||||
bool use_ref = true, | |||||
bool use_input_shape = true, | |||||
string name = null) | |||||
{ | |||||
Tensor result; | |||||
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); | |||||
if (data.dtype.is_ref_dtype() && use_ref) | |||||
throw new NotImplementedException("_Enter"); | |||||
else | |||||
result = gen_control_flow_ops.enter( | |||||
data, frame_name, is_constant, parallel_iterations, name: name); | |||||
if (use_input_shape) | |||||
result.SetShape(data.TensorShape); | |||||
return result; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Exit this control flow context. | /// Exit this control flow context. | ||||
/// </summary> | /// </summary> | ||||
@@ -184,6 +216,10 @@ namespace Tensorflow.Operations | |||||
return true; | return true; | ||||
} | } | ||||
protected virtual bool _IsInOuterContext(Operation op) | |||||
{ | |||||
throw new NotImplementedException("_IsInOuterContext"); | |||||
} | |||||
protected virtual void _RemoveExternalControlEdges(Operation op) | protected virtual void _RemoveExternalControlEdges(Operation op) | ||||
{ | { | ||||
@@ -15,8 +15,12 @@ | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | using System; | ||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using Tensorflow.Operations.ControlFlows; | using Tensorflow.Operations.ControlFlows; | ||||
using Tensorflow.Util; | |||||
using static Tensorflow.Python; | using static Tensorflow.Python; | ||||
using static Tensorflow.control_flow_ops; | |||||
namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
{ | { | ||||
@@ -32,10 +36,14 @@ namespace Tensorflow.Operations | |||||
bool _swap_memory; | bool _swap_memory; | ||||
Tensor _pivot_for_pred; | Tensor _pivot_for_pred; | ||||
Tensor _pivot_for_body; | Tensor _pivot_for_body; | ||||
Tensor[] _loop_exits; | |||||
Tensor[] _loop_enters; | |||||
List<Tensor> _loop_exits; | |||||
List<Tensor> _loop_enters; | |||||
Graph _graph; | |||||
public override GradLoopState grad_state => _grad_state; | |||||
public override bool back_prop => _back_prop; | |||||
public WhileContext(int parallel_iterations = 10, | |||||
public WhileContext(int? maximum_iterations = null, | |||||
int parallel_iterations = 10, | |||||
bool back_prop = true, | bool back_prop = true, | ||||
bool swap_memory = false, | bool swap_memory = false, | ||||
string name = "while_context", | string name = "while_context", | ||||
@@ -49,12 +57,27 @@ namespace Tensorflow.Operations | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
__init__(); | |||||
_init_from_args(maximum_iterations, parallel_iterations, back_prop, swap_memory, name); | |||||
} | } | ||||
_grad_state = grad_state; | _grad_state = grad_state; | ||||
} | } | ||||
private void _init_from_args(int? maximum_iterations, | |||||
int parallel_iterations, | |||||
bool back_prop, | |||||
bool swap_memory, | |||||
string name) | |||||
{ | |||||
_name = ops.get_default_graph().unique_name(name); | |||||
_back_prop = back_prop; | |||||
_swap_memory = swap_memory; | |||||
_loop_exits = new List<Tensor>(); | |||||
_loop_enters = new List<Tensor>(); | |||||
_graph = ops.get_default_graph(); | |||||
} | |||||
private void _init_from_proto(WhileContextDef context_def, string import_scope = null) | private void _init_from_proto(WhileContextDef context_def, string import_scope = null) | ||||
{ | { | ||||
var g = ops.get_default_graph(); | var g = ops.get_default_graph(); | ||||
@@ -70,26 +93,156 @@ namespace Tensorflow.Operations | |||||
// The boolean tensor for loop termination condition. | // The boolean tensor for loop termination condition. | ||||
_pivot = g.as_graph_element(ops.prepend_name_scope(context_def.PivotName, import_scope)) as Tensor; | _pivot = g.as_graph_element(ops.prepend_name_scope(context_def.PivotName, import_scope)) as Tensor; | ||||
// The list of exit tensors for loop variables. | // The list of exit tensors for loop variables. | ||||
_loop_exits = new Tensor[context_def.LoopExitNames.Count]; | |||||
_loop_exits = new List<Tensor>(); | |||||
foreach (var (i, exit_name) in enumerate(context_def.LoopExitNames)) | foreach (var (i, exit_name) in enumerate(context_def.LoopExitNames)) | ||||
_loop_exits[i] = g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor; | |||||
_loop_exits.Add(g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor); | |||||
// The list of enter tensors for loop variables. | // The list of enter tensors for loop variables. | ||||
_loop_enters = new Tensor[context_def.LoopEnterNames.Count]; | |||||
_loop_enters = new List<Tensor>(); | |||||
foreach (var (i, enter_name) in enumerate(context_def.LoopEnterNames)) | foreach (var (i, enter_name) in enumerate(context_def.LoopEnterNames)) | ||||
_loop_enters[i] = g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor; | |||||
_loop_enters.Add(g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor); | |||||
__init__(values_def: context_def.ValuesDef, import_scope: import_scope); | __init__(values_def: context_def.ValuesDef, import_scope: import_scope); | ||||
} | } | ||||
public override WhileContext GetWhileContext() | |||||
/// <summary> | |||||
/// Add the loop termination condition and body to the graph. | |||||
/// </summary> | |||||
public Tensor[] BuildLoop(Func<Tensor, Tensor> pred, | |||||
Func<Tensor, Tensor> body, | |||||
Tensor[] loop_vars, | |||||
TensorShape shape_invariants, | |||||
bool return_same_structure) | |||||
{ | { | ||||
return this; | |||||
// Keep original_loop_vars to identify which are TensorArrays | |||||
var original_loop_vars = loop_vars; | |||||
// Convert TensorArrays to their flow variables | |||||
Enter(); | |||||
var(original_body_result, exit_vars) = _BuildLoop( | |||||
pred, body, original_loop_vars, loop_vars, shape_invariants); | |||||
Exit(); | |||||
var flat_result = original_body_result; | |||||
var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars); | |||||
var packed_exit_vars = nest.pack_sequence_as( | |||||
structure: original_body_result, | |||||
flat_sequence: exit_vars_with_tensor_arrays); | |||||
return packed_exit_vars as Tensor[]; | |||||
} | } | ||||
private (Tensor[], Tensor[]) _BuildLoop(Func<Tensor, Tensor> pred, | |||||
Func<Tensor, Tensor> body, | |||||
Tensor[] original_loop_vars, | |||||
Tensor[] loop_vars, | |||||
TensorShape shape_invariants) | |||||
{ | |||||
var flat_loop_vars = original_loop_vars; | |||||
public override GradLoopState grad_state => _grad_state; | |||||
// Let the context know the loop variables so the loop variables | |||||
// would be added in the outer contexts properly. | |||||
_InitializeValues(loop_vars); | |||||
var real_vars = loop_vars; | |||||
Tensor[] enter_vars = null; | |||||
tf_with(ops.control_dependencies(null), delegate | |||||
{ | |||||
enter_vars = real_vars.Select(x => _Enter(x, | |||||
_name, | |||||
is_constant: false, | |||||
parallel_iterations: _parallel_iterations, | |||||
use_input_shape: shape_invariants == null)) | |||||
.ToArray(); | |||||
public override bool back_prop => _back_prop; | |||||
foreach(var x in enter_vars) | |||||
{ | |||||
x.graph.prevent_feeding(x); | |||||
if (_outer_context != null) | |||||
_outer_context.AddInnerOp(x.op); | |||||
} | |||||
}); | |||||
// Finds the closest enclosing non-None control pivot. | |||||
var outer_context = _outer_context; | |||||
while (outer_context != null) | |||||
{ | |||||
} | |||||
_SetShapeInvariants(real_vars, enter_vars, shape_invariants); | |||||
// Fix the control inputs and control flow context of these enter ops. | |||||
_FixControlInputsAndContext(enter_vars); | |||||
_InitializeValues(enter_vars); | |||||
_loop_enters = enter_vars.ToList(); | |||||
var merge_vars = enter_vars | |||||
.Select(x => merge(new[] { x, x })) | |||||
.ToArray(); | |||||
_pivot_for_pred = merge_vars[0]; | |||||
// Build the graph for pred. | |||||
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); | |||||
// var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); | |||||
var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0])); | |||||
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); | |||||
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) | |||||
.ToArray(); | |||||
// Build the graph for body. | |||||
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); | |||||
// Convert TensorArray flow variables inside the context back into | |||||
// their associated TensorArrays for calling the body. | |||||
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||||
var body_result = body(packed_vars_for_body[0]); | |||||
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
// Store body_result to keep track of TensorArrays returned by body | |||||
var original_body_result = new[] { body_result }; | |||||
// Convert TensorArrays returned by body into their flow variables | |||||
var result = new[] { body_result }; | |||||
var next_vars = new List<Tensor>(); | |||||
foreach (var (m, v) in zip(merge_vars, result)) | |||||
next_vars.Add(_AddNextAndBackEdge(m, v)); | |||||
// Add the exit ops. | |||||
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); | |||||
_loop_exits = exit_vars; | |||||
// Exit the loop. | |||||
// ExitResult(exit_vars); | |||||
return (original_body_result, exit_vars.ToArray()); | |||||
} | |||||
private void _FixControlInputsAndContext(Tensor[] enters) | |||||
{ | |||||
var graph = ops.get_default_graph(); | |||||
foreach(var e in enters) | |||||
{ | |||||
var inp_op = e.op.inputs[0].op; | |||||
var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op }); | |||||
// op for op in control_inputs if self._IsInOuterContext(op) | |||||
var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) | |||||
.Select(x => x.op) | |||||
.ToArray(); | |||||
e.op._set_control_flow_context(this); | |||||
e.op._add_control_inputs(outer_control_inputs); | |||||
graph._record_op_seen_by_control_dependencies(e.op); | |||||
} | |||||
} | |||||
private void _InitializeValues(Tensor[] values) | |||||
{ | |||||
_values = new HashSet<string>(); | |||||
foreach(var x in values) | |||||
_values.Add(x.name); | |||||
} | |||||
public override WhileContext GetWhileContext() | |||||
{ | |||||
return this; | |||||
} | |||||
public WhileContext from_proto(WhileContextDef proto, string import_scope) | public WhileContext from_proto(WhileContextDef proto, string import_scope) | ||||
{ | { | ||||
@@ -141,30 +141,57 @@ namespace Tensorflow.Operations | |||||
string base_name = null; | string base_name = null; | ||||
tf_with(ops.name_scope("dynamic_rnn"), scope => base_name = scope); | tf_with(ops.name_scope("dynamic_rnn"), scope => base_name = scope); | ||||
Func<string, TensorShape, TF_DataType, Tensor> _create_ta = (name, element_shape, dtype_) => | |||||
Func<string, TensorShape, TF_DataType, TensorArray> _create_ta = (name, element_shape, dtype_) => | |||||
{ | { | ||||
new TensorArray(dtype: dtype_, | |||||
var ta = new TensorArray(dtype: dtype_, | |||||
size: time_steps, | size: time_steps, | ||||
element_shape: element_shape, | element_shape: element_shape, | ||||
tensor_array_name: base_name + name); | tensor_array_name: base_name + name); | ||||
throw new NotImplementedException(""); | |||||
return ta; | |||||
}; | }; | ||||
bool in_graph_mode = true; | bool in_graph_mode = true; | ||||
var output_ta = new List<TensorArray>(); | |||||
var input_ta = new List<TensorArray>(); | |||||
if (in_graph_mode) | if (in_graph_mode) | ||||
{ | { | ||||
foreach(var (i, out_size) in enumerate(flat_output_size)) | |||||
foreach (var (i, out_size) in enumerate(flat_output_size)) | |||||
{ | { | ||||
_create_ta($"output_{i}", | |||||
output_ta.Add(_create_ta($"output_{i}", | |||||
new TensorShape(const_batch_size).concatenate( | new TensorShape(const_batch_size).concatenate( | ||||
_maybe_tensor_shape_from_tensor(out_size)), | _maybe_tensor_shape_from_tensor(out_size)), | ||||
_infer_state_dtype(dtype, state)); | |||||
_infer_state_dtype(dtype, state))); | |||||
} | |||||
foreach (var (i, flat_input_i) in enumerate(flat_input)) | |||||
{ | |||||
input_ta.Add(_create_ta($"input_{i}", | |||||
new TensorShape(flat_input_i.dims.Skip(1).ToArray()), | |||||
flat_input_i.dtype)); | |||||
} | |||||
for (int i = 0; i < input_ta.Count; i++) | |||||
{ | |||||
var (ta, input_) = (input_ta[0], flat_input[0]); | |||||
} | } | ||||
} | } | ||||
// Make sure that we run at least 1 step, if necessary, to ensure | |||||
// the TensorArrays pick up the dynamic shape. | |||||
Tensor loop_bound; | |||||
if (in_graph_mode) | |||||
loop_bound = math_ops.minimum( | |||||
time_steps, math_ops.maximum(1, max_sequence_length)); | |||||
/*Func<Tensor, Tensor> cond = (ctime) => | |||||
{ | |||||
return null; | |||||
}; | |||||
control_flow_ops.while_loop( | |||||
cond: cond, | |||||
body = );*/ | |||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
@@ -26,6 +26,44 @@ namespace Tensorflow | |||||
{ | { | ||||
public class control_flow_ops | public class control_flow_ops | ||||
{ | { | ||||
public static Tensor _AddNextAndBackEdge(Tensor m, Tensor v, bool enforce_shape_invariant = true) | |||||
{ | |||||
v = ops.convert_to_tensor(v); | |||||
v = _NextIteration(v); | |||||
if (enforce_shape_invariant) | |||||
_EnforceShapeInvariant(m, v); | |||||
m.op._update_input(1, v); | |||||
return v; | |||||
} | |||||
/// <summary> | |||||
/// Check if the shapes of the loops variables are invariants. | |||||
/// </summary> | |||||
/// <param name="merge_var"></param> | |||||
/// <param name="next_var"></param> | |||||
public static void _EnforceShapeInvariant(Tensor merge_var, Tensor next_var) | |||||
{ | |||||
} | |||||
public static Tensor exit(Tensor data, string name = null) | |||||
{ | |||||
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); | |||||
if (data.dtype.is_ref_dtype()) | |||||
return gen_control_flow_ops.ref_exit(data, name: name); | |||||
else | |||||
return gen_control_flow_ops._exit(data, name: name); | |||||
} | |||||
public static Tensor _NextIteration(Tensor data, string name = null) | |||||
{ | |||||
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); | |||||
if (data.dtype.is_ref_dtype()) | |||||
return gen_control_flow_ops.ref_next_iteration(data, name: name); | |||||
else | |||||
return gen_control_flow_ops.next_iteration(data, name: name); | |||||
} | |||||
public static Operation Assert(Tensor condition, object[] data, int? summarize = null, string name = null) | public static Operation Assert(Tensor condition, object[] data, int? summarize = null, string name = null) | ||||
{ | { | ||||
return tf_with(ops.name_scope(name, "Assert", new { condition, data }), scope => | return tf_with(ops.name_scope(name, "Assert", new { condition, data }), scope => | ||||
@@ -213,6 +251,14 @@ namespace Tensorflow | |||||
return gen_array_ops.identity(data, name: name); | return gen_array_ops.identity(data, name: name); | ||||
} | } | ||||
public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape shapes = null) | |||||
{ | |||||
if (shapes == null) | |||||
return; | |||||
throw new NotImplementedException("_SetShapeInvariants"); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Forwards `data` to an output determined by `pred`. | /// Forwards `data` to an output determined by `pred`. | ||||
/// If `pred` is false, the `data` input is forwarded to the first output. | /// If `pred` is false, the `data` input is forwarded to the first output. | ||||
@@ -516,10 +562,52 @@ namespace Tensorflow | |||||
throw new NotImplementedException("ZerosLikeOutsideLoop"); | throw new NotImplementedException("ZerosLikeOutsideLoop"); | ||||
} | } | ||||
// TODO | |||||
public static void while_loop(Func<Tensor, Tensor> func, Func<Tensor, Tensor> func1, Tensor[] tensors, int? i) | |||||
/// <summary> | |||||
/// Repeat `body` while the condition `cond` is true. | |||||
/// </summary> | |||||
/// <param name="cond"></param> | |||||
/// <param name="body"></param> | |||||
/// <param name="loop_vars"></param> | |||||
/// <param name="i"></param> | |||||
public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||||
TensorShape shape_invariants = null, | |||||
int parallel_iterations = 10, | |||||
bool back_prop = true, | |||||
bool swap_memory = false, | |||||
string name = null, | |||||
int? maximum_iterations = null, | |||||
bool return_same_structure = false) | |||||
{ | { | ||||
throw new NotImplementedException(); | |||||
tf_with(ops.name_scope(name, "while", loop_vars), scope => | |||||
{ | |||||
if (loop_vars == null || loop_vars.Length == 0) | |||||
throw new ValueError("No loop variables provided"); | |||||
if (cond == null) | |||||
throw new ValueError("cond must be callable."); | |||||
if (body == null) | |||||
throw new ValueError("body must be callable."); | |||||
if (parallel_iterations < 1) | |||||
throw new ValueError("parallel_iterations must be a positive integer."); | |||||
var loop_context = new WhileContext( | |||||
maximum_iterations: maximum_iterations, | |||||
parallel_iterations: parallel_iterations, | |||||
back_prop: back_prop, | |||||
swap_memory: swap_memory); | |||||
if (loop_context.outer_context == null) | |||||
ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context); | |||||
var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, | |||||
return_same_structure); | |||||
if (maximum_iterations != null) | |||||
return results[1]; | |||||
else | |||||
return results[0]; | |||||
}); | |||||
throw new NotImplementedException("while_loop"); | |||||
} | } | ||||
} | } | ||||
@@ -20,6 +20,93 @@ namespace Tensorflow | |||||
{ | { | ||||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | ||||
/// <summary> | |||||
/// Creates or finds a child frame, and makes `data` available to the child frame. | |||||
/// </summary> | |||||
/// <param name="data"></param> | |||||
/// <param name="frame_name"></param> | |||||
/// <param name="is_constant"></param> | |||||
/// <param name="parallel_iterations"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor enter(Tensor data, string frame_name = "frame_name", bool is_constant = false, int parallel_iterations = 10, string name = null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("Enter", name, new | |||||
{ | |||||
data, | |||||
frame_name, | |||||
is_constant, | |||||
parallel_iterations | |||||
}); | |||||
return _op.output; | |||||
} | |||||
/// <summary> | |||||
/// Forwards the input to the output. | |||||
/// </summary> | |||||
/// <param name="input"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor loop_cond(Tensor input, string name = null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("LoopCond", name, new { input }); | |||||
return _op.output; | |||||
} | |||||
/// <summary> | |||||
/// Makes its input available to the next iteration. | |||||
/// </summary> | |||||
/// <param name="data"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor ref_next_iteration(Tensor data, string name = null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("RefNextIteration", name, new { data }); | |||||
return _op; | |||||
} | |||||
/// <summary> | |||||
/// Makes its input available to the next iteration. | |||||
/// </summary> | |||||
/// <param name="data"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor next_iteration(Tensor data, string name = null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("NextIteration", name, new { data }); | |||||
return _op; | |||||
} | |||||
/// <summary> | |||||
/// Exits the current frame to its parent frame. | |||||
/// </summary> | |||||
/// <param name="data"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor ref_exit(Tensor data, string name = null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("RefExit", name, new { data }); | |||||
return _op; | |||||
} | |||||
/// <summary> | |||||
/// Exits the current frame to its parent frame. | |||||
/// </summary> | |||||
/// <param name="data"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor _exit(Tensor data, string name = null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("Exit", name, new { data }); | |||||
return _op; | |||||
} | |||||
public static Operation no_op(string name = null) | public static Operation no_op(string name = null) | ||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("NoOp", name, null); | var _op = _op_def_lib._apply_op_helper("NoOp", name, null); | ||||
@@ -516,6 +516,9 @@ namespace Tensorflow | |||||
}); | }); | ||||
} | } | ||||
public static Tensor minimum<Tx, Ty>(Tx x, Ty y, string name = null) | |||||
=> gen_math_ops.minimum(x, y, name: name); | |||||
public static Tensor maximum<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor maximum<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
=> gen_math_ops.maximum(x, y, name: name); | => gen_math_ops.maximum(x, y, name: name); | ||||
@@ -416,5 +416,6 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public int tensor_int_val { get; set; } | |||||
} | } | ||||
} | } |
@@ -8,6 +8,18 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
[TestClass] | [TestClass] | ||||
public class WhileContextTestCase : PythonTest | public class WhileContextTestCase : PythonTest | ||||
{ | { | ||||
/// <summary> | |||||
/// https://www.tensorflow.org/api_docs/python/tf/while_loop | |||||
/// </summary> | |||||
[TestMethod] | |||||
public void SimpleWhileLoop() | |||||
{ | |||||
var i = constant_op.constant(0, name: "i"); | |||||
var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); | |||||
var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); | |||||
var r = control_flow_ops.while_loop(c, b, new[] { i }); | |||||
} | |||||
private void _testWhileContextHelper(int? maximum_iterations = null) | private void _testWhileContextHelper(int? maximum_iterations = null) | ||||
{ | { | ||||
// TODO: implement missing code dependencies | // TODO: implement missing code dependencies | ||||
@@ -17,7 +29,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | ||||
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | ||||
control_flow_ops.while_loop( | control_flow_ops.while_loop( | ||||
c, b, new[] { i }, maximum_iterations); | |||||
c, b, new[] { i }, maximum_iterations: maximum_iterations); | |||||
foreach (Operation op in sess.graph.get_operations()) | foreach (Operation op in sess.graph.get_operations()) | ||||
{ | { | ||||
var control_flow_context = op._get_control_flow_context(); | var control_flow_context = op._get_control_flow_context(); | ||||