@@ -1,5 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Text; | |||
using NumSharp; | |||
using Tensorflow; | |||
@@ -21,7 +22,12 @@ namespace Tensorflow.Hub | |||
images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | |||
images.astype(dataType); | |||
// for debug np.multiply performance | |||
var sw = new Stopwatch(); | |||
sw.Start(); | |||
images = np.multiply(images, 1.0f / 255.0f); | |||
sw.Stop(); | |||
Console.WriteLine($"{sw.ElapsedMilliseconds}ms"); | |||
Data = images; | |||
labels.astype(dataType); | |||
@@ -14,10 +14,29 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
namespace Tensorflow | |||
{ | |||
public static partial class tf | |||
{ | |||
public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||
TensorShape shape_invariants = null, | |||
int parallel_iterations = 10, | |||
bool back_prop = true, | |||
bool swap_memory = false, | |||
string name = null, | |||
int? maximum_iterations = null, | |||
bool return_same_structure = false) | |||
=> control_flow_ops.while_loop(cond, body, loop_vars, | |||
shape_invariants: shape_invariants, | |||
parallel_iterations: parallel_iterations, | |||
back_prop: back_prop, | |||
swap_memory: swap_memory, | |||
name: name, | |||
maximum_iterations: maximum_iterations, | |||
return_same_structure: return_same_structure); | |||
public static _ControlDependenciesController control_dependencies(Operation[] control_inputs) | |||
=> ops.control_dependencies(control_inputs); | |||
} | |||
@@ -39,8 +39,8 @@ namespace Tensorflow | |||
public static Tensor asin(Tensor x, string name = null) | |||
=> gen_math_ops.asin(x, name); | |||
public static Tensor add<Tx, Ty>(Tx a, Ty b) | |||
=> gen_math_ops.add(a, b); | |||
public static Tensor add<Tx, Ty>(Tx a, Ty b, string name = null) | |||
=> gen_math_ops.add(a, b, name: name); | |||
/// <summary> | |||
/// Computes atan of x element-wise. | |||
@@ -33,7 +33,7 @@ namespace Tensorflow | |||
/// </summary> | |||
/// <param name="input_ops">The data input ops for an op to be created.</param> | |||
/// <returns>A list of control inputs for the op to be created.</returns> | |||
private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) | |||
public ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) | |||
{ | |||
var ret = new List<ITensorOrOperation>(); | |||
@@ -53,6 +53,11 @@ namespace Tensorflow.Operations | |||
protected Stack<ControlFlowContext> _context_stack; | |||
protected ControlFlowContext _outer_context; | |||
/// <summary> | |||
/// The keys are the names of tensors referenced by but external to this | |||
/// context. Each value is the Tensor that should be used by this context to | |||
/// access the key value (e.g. a switch output guarding a cond input value). | |||
/// </summary> | |||
protected Dictionary<string, ITensorOrOperation> _external_values; | |||
public ControlFlowContext() | |||
@@ -68,6 +73,12 @@ namespace Tensorflow.Operations | |||
_outer_context = ops.get_default_graph()._get_control_flow_context(); | |||
if (values_def != null) | |||
_init_values_from_proto(values_def, import_scope: import_scope); | |||
else | |||
{ | |||
_values = new HashSet<string>(); | |||
_external_values = new Dictionary<string, ITensorOrOperation>(); | |||
} | |||
} | |||
public void __enter__() | |||
@@ -114,6 +125,27 @@ namespace Tensorflow.Operations | |||
graph._set_control_flow_context(this); | |||
} | |||
protected virtual Tensor _Enter(Tensor data, string frame_name, | |||
bool is_constant = false, | |||
int parallel_iterations = 10, | |||
bool use_ref = true, | |||
bool use_input_shape = true, | |||
string name = null) | |||
{ | |||
Tensor result; | |||
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); | |||
if (data.dtype.is_ref_dtype() && use_ref) | |||
throw new NotImplementedException("_Enter"); | |||
else | |||
result = gen_control_flow_ops.enter( | |||
data, frame_name, is_constant, parallel_iterations, name: name); | |||
if (use_input_shape) | |||
result.SetShape(data.TensorShape); | |||
return result; | |||
} | |||
/// <summary> | |||
/// Exit this control flow context. | |||
/// </summary> | |||
@@ -184,6 +216,10 @@ namespace Tensorflow.Operations | |||
return true; | |||
} | |||
protected virtual bool _IsInOuterContext(Operation op) | |||
{ | |||
throw new NotImplementedException("_IsInOuterContext"); | |||
} | |||
protected virtual void _RemoveExternalControlEdges(Operation op) | |||
{ | |||
@@ -15,8 +15,12 @@ | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Operations.ControlFlows; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Python; | |||
using static Tensorflow.control_flow_ops; | |||
namespace Tensorflow.Operations | |||
{ | |||
@@ -32,10 +36,14 @@ namespace Tensorflow.Operations | |||
bool _swap_memory; | |||
Tensor _pivot_for_pred; | |||
Tensor _pivot_for_body; | |||
Tensor[] _loop_exits; | |||
Tensor[] _loop_enters; | |||
List<Tensor> _loop_exits; | |||
List<Tensor> _loop_enters; | |||
Graph _graph; | |||
public override GradLoopState grad_state => _grad_state; | |||
public override bool back_prop => _back_prop; | |||
public WhileContext(int parallel_iterations = 10, | |||
public WhileContext(int? maximum_iterations = null, | |||
int parallel_iterations = 10, | |||
bool back_prop = true, | |||
bool swap_memory = false, | |||
string name = "while_context", | |||
@@ -49,12 +57,27 @@ namespace Tensorflow.Operations | |||
} | |||
else | |||
{ | |||
__init__(); | |||
_init_from_args(maximum_iterations, parallel_iterations, back_prop, swap_memory, name); | |||
} | |||
_grad_state = grad_state; | |||
} | |||
private void _init_from_args(int? maximum_iterations, | |||
int parallel_iterations, | |||
bool back_prop, | |||
bool swap_memory, | |||
string name) | |||
{ | |||
_name = ops.get_default_graph().unique_name(name); | |||
_back_prop = back_prop; | |||
_swap_memory = swap_memory; | |||
_loop_exits = new List<Tensor>(); | |||
_loop_enters = new List<Tensor>(); | |||
_graph = ops.get_default_graph(); | |||
} | |||
private void _init_from_proto(WhileContextDef context_def, string import_scope = null) | |||
{ | |||
var g = ops.get_default_graph(); | |||
@@ -70,26 +93,156 @@ namespace Tensorflow.Operations | |||
// The boolean tensor for loop termination condition. | |||
_pivot = g.as_graph_element(ops.prepend_name_scope(context_def.PivotName, import_scope)) as Tensor; | |||
// The list of exit tensors for loop variables. | |||
_loop_exits = new Tensor[context_def.LoopExitNames.Count]; | |||
_loop_exits = new List<Tensor>(); | |||
foreach (var (i, exit_name) in enumerate(context_def.LoopExitNames)) | |||
_loop_exits[i] = g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor; | |||
_loop_exits.Add(g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor); | |||
// The list of enter tensors for loop variables. | |||
_loop_enters = new Tensor[context_def.LoopEnterNames.Count]; | |||
_loop_enters = new List<Tensor>(); | |||
foreach (var (i, enter_name) in enumerate(context_def.LoopEnterNames)) | |||
_loop_enters[i] = g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor; | |||
_loop_enters.Add(g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor); | |||
__init__(values_def: context_def.ValuesDef, import_scope: import_scope); | |||
} | |||
public override WhileContext GetWhileContext() | |||
/// <summary> | |||
/// Add the loop termination condition and body to the graph. | |||
/// </summary> | |||
public Tensor[] BuildLoop(Func<Tensor, Tensor> pred, | |||
Func<Tensor, Tensor> body, | |||
Tensor[] loop_vars, | |||
TensorShape shape_invariants, | |||
bool return_same_structure) | |||
{ | |||
return this; | |||
// Keep original_loop_vars to identify which are TensorArrays | |||
var original_loop_vars = loop_vars; | |||
// Convert TensorArrays to their flow variables | |||
Enter(); | |||
var(original_body_result, exit_vars) = _BuildLoop( | |||
pred, body, original_loop_vars, loop_vars, shape_invariants); | |||
Exit(); | |||
var flat_result = original_body_result; | |||
var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars); | |||
var packed_exit_vars = nest.pack_sequence_as( | |||
structure: original_body_result, | |||
flat_sequence: exit_vars_with_tensor_arrays); | |||
return packed_exit_vars as Tensor[]; | |||
} | |||
private (Tensor[], Tensor[]) _BuildLoop(Func<Tensor, Tensor> pred, | |||
Func<Tensor, Tensor> body, | |||
Tensor[] original_loop_vars, | |||
Tensor[] loop_vars, | |||
TensorShape shape_invariants) | |||
{ | |||
var flat_loop_vars = original_loop_vars; | |||
public override GradLoopState grad_state => _grad_state; | |||
// Let the context know the loop variables so the loop variables | |||
// would be added in the outer contexts properly. | |||
_InitializeValues(loop_vars); | |||
var real_vars = loop_vars; | |||
Tensor[] enter_vars = null; | |||
tf_with(ops.control_dependencies(null), delegate | |||
{ | |||
enter_vars = real_vars.Select(x => _Enter(x, | |||
_name, | |||
is_constant: false, | |||
parallel_iterations: _parallel_iterations, | |||
use_input_shape: shape_invariants == null)) | |||
.ToArray(); | |||
public override bool back_prop => _back_prop; | |||
foreach(var x in enter_vars) | |||
{ | |||
x.graph.prevent_feeding(x); | |||
if (_outer_context != null) | |||
_outer_context.AddInnerOp(x.op); | |||
} | |||
}); | |||
// Finds the closest enclosing non-None control pivot. | |||
var outer_context = _outer_context; | |||
while (outer_context != null) | |||
{ | |||
} | |||
_SetShapeInvariants(real_vars, enter_vars, shape_invariants); | |||
// Fix the control inputs and control flow context of these enter ops. | |||
_FixControlInputsAndContext(enter_vars); | |||
_InitializeValues(enter_vars); | |||
_loop_enters = enter_vars.ToList(); | |||
var merge_vars = enter_vars | |||
.Select(x => merge(new[] { x, x })) | |||
.ToArray(); | |||
_pivot_for_pred = merge_vars[0]; | |||
// Build the graph for pred. | |||
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); | |||
// var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); | |||
var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0])); | |||
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); | |||
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) | |||
.ToArray(); | |||
// Build the graph for body. | |||
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); | |||
// Convert TensorArray flow variables inside the context back into | |||
// their associated TensorArrays for calling the body. | |||
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||
var body_result = body(packed_vars_for_body[0]); | |||
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
// Store body_result to keep track of TensorArrays returned by body | |||
var original_body_result = new[] { body_result }; | |||
// Convert TensorArrays returned by body into their flow variables | |||
var result = new[] { body_result }; | |||
var next_vars = new List<Tensor>(); | |||
foreach (var (m, v) in zip(merge_vars, result)) | |||
next_vars.Add(_AddNextAndBackEdge(m, v)); | |||
// Add the exit ops. | |||
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); | |||
_loop_exits = exit_vars; | |||
// Exit the loop. | |||
// ExitResult(exit_vars); | |||
return (original_body_result, exit_vars.ToArray()); | |||
} | |||
private void _FixControlInputsAndContext(Tensor[] enters) | |||
{ | |||
var graph = ops.get_default_graph(); | |||
foreach(var e in enters) | |||
{ | |||
var inp_op = e.op.inputs[0].op; | |||
var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op }); | |||
// op for op in control_inputs if self._IsInOuterContext(op) | |||
var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) | |||
.Select(x => x.op) | |||
.ToArray(); | |||
e.op._set_control_flow_context(this); | |||
e.op._add_control_inputs(outer_control_inputs); | |||
graph._record_op_seen_by_control_dependencies(e.op); | |||
} | |||
} | |||
private void _InitializeValues(Tensor[] values) | |||
{ | |||
_values = new HashSet<string>(); | |||
foreach(var x in values) | |||
_values.Add(x.name); | |||
} | |||
public override WhileContext GetWhileContext() | |||
{ | |||
return this; | |||
} | |||
public WhileContext from_proto(WhileContextDef proto, string import_scope) | |||
{ | |||
@@ -141,30 +141,57 @@ namespace Tensorflow.Operations | |||
string base_name = null; | |||
tf_with(ops.name_scope("dynamic_rnn"), scope => base_name = scope); | |||
Func<string, TensorShape, TF_DataType, Tensor> _create_ta = (name, element_shape, dtype_) => | |||
Func<string, TensorShape, TF_DataType, TensorArray> _create_ta = (name, element_shape, dtype_) => | |||
{ | |||
new TensorArray(dtype: dtype_, | |||
var ta = new TensorArray(dtype: dtype_, | |||
size: time_steps, | |||
element_shape: element_shape, | |||
tensor_array_name: base_name + name); | |||
throw new NotImplementedException(""); | |||
return ta; | |||
}; | |||
bool in_graph_mode = true; | |||
var output_ta = new List<TensorArray>(); | |||
var input_ta = new List<TensorArray>(); | |||
if (in_graph_mode) | |||
{ | |||
foreach(var (i, out_size) in enumerate(flat_output_size)) | |||
foreach (var (i, out_size) in enumerate(flat_output_size)) | |||
{ | |||
_create_ta($"output_{i}", | |||
output_ta.Add(_create_ta($"output_{i}", | |||
new TensorShape(const_batch_size).concatenate( | |||
_maybe_tensor_shape_from_tensor(out_size)), | |||
_infer_state_dtype(dtype, state)); | |||
_infer_state_dtype(dtype, state))); | |||
} | |||
foreach (var (i, flat_input_i) in enumerate(flat_input)) | |||
{ | |||
input_ta.Add(_create_ta($"input_{i}", | |||
new TensorShape(flat_input_i.dims.Skip(1).ToArray()), | |||
flat_input_i.dtype)); | |||
} | |||
for (int i = 0; i < input_ta.Count; i++) | |||
{ | |||
var (ta, input_) = (input_ta[0], flat_input[0]); | |||
} | |||
} | |||
// Make sure that we run at least 1 step, if necessary, to ensure | |||
// the TensorArrays pick up the dynamic shape. | |||
Tensor loop_bound; | |||
if (in_graph_mode) | |||
loop_bound = math_ops.minimum( | |||
time_steps, math_ops.maximum(1, max_sequence_length)); | |||
/*Func<Tensor, Tensor> cond = (ctime) => | |||
{ | |||
return null; | |||
}; | |||
control_flow_ops.while_loop( | |||
cond: cond, | |||
body = );*/ | |||
throw new NotImplementedException(""); | |||
} | |||
@@ -26,6 +26,44 @@ namespace Tensorflow | |||
{ | |||
public class control_flow_ops | |||
{ | |||
public static Tensor _AddNextAndBackEdge(Tensor m, Tensor v, bool enforce_shape_invariant = true) | |||
{ | |||
v = ops.convert_to_tensor(v); | |||
v = _NextIteration(v); | |||
if (enforce_shape_invariant) | |||
_EnforceShapeInvariant(m, v); | |||
m.op._update_input(1, v); | |||
return v; | |||
} | |||
/// <summary> | |||
/// Check if the shapes of the loops variables are invariants. | |||
/// </summary> | |||
/// <param name="merge_var"></param> | |||
/// <param name="next_var"></param> | |||
public static void _EnforceShapeInvariant(Tensor merge_var, Tensor next_var) | |||
{ | |||
} | |||
public static Tensor exit(Tensor data, string name = null) | |||
{ | |||
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); | |||
if (data.dtype.is_ref_dtype()) | |||
return gen_control_flow_ops.ref_exit(data, name: name); | |||
else | |||
return gen_control_flow_ops._exit(data, name: name); | |||
} | |||
public static Tensor _NextIteration(Tensor data, string name = null) | |||
{ | |||
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); | |||
if (data.dtype.is_ref_dtype()) | |||
return gen_control_flow_ops.ref_next_iteration(data, name: name); | |||
else | |||
return gen_control_flow_ops.next_iteration(data, name: name); | |||
} | |||
public static Operation Assert(Tensor condition, object[] data, int? summarize = null, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "Assert", new { condition, data }), scope => | |||
@@ -213,6 +251,14 @@ namespace Tensorflow | |||
return gen_array_ops.identity(data, name: name); | |||
} | |||
public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape shapes = null) | |||
{ | |||
if (shapes == null) | |||
return; | |||
throw new NotImplementedException("_SetShapeInvariants"); | |||
} | |||
/// <summary> | |||
/// Forwards `data` to an output determined by `pred`. | |||
/// If `pred` is false, the `data` input is forwarded to the first output. | |||
@@ -516,10 +562,52 @@ namespace Tensorflow | |||
throw new NotImplementedException("ZerosLikeOutsideLoop"); | |||
} | |||
// TODO | |||
public static void while_loop(Func<Tensor, Tensor> func, Func<Tensor, Tensor> func1, Tensor[] tensors, int? i) | |||
/// <summary> | |||
/// Repeat `body` while the condition `cond` is true. | |||
/// </summary> | |||
/// <param name="cond"></param> | |||
/// <param name="body"></param> | |||
/// <param name="loop_vars"></param> | |||
/// <param name="i"></param> | |||
public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||
TensorShape shape_invariants = null, | |||
int parallel_iterations = 10, | |||
bool back_prop = true, | |||
bool swap_memory = false, | |||
string name = null, | |||
int? maximum_iterations = null, | |||
bool return_same_structure = false) | |||
{ | |||
throw new NotImplementedException(); | |||
tf_with(ops.name_scope(name, "while", loop_vars), scope => | |||
{ | |||
if (loop_vars == null || loop_vars.Length == 0) | |||
throw new ValueError("No loop variables provided"); | |||
if (cond == null) | |||
throw new ValueError("cond must be callable."); | |||
if (body == null) | |||
throw new ValueError("body must be callable."); | |||
if (parallel_iterations < 1) | |||
throw new ValueError("parallel_iterations must be a positive integer."); | |||
var loop_context = new WhileContext( | |||
maximum_iterations: maximum_iterations, | |||
parallel_iterations: parallel_iterations, | |||
back_prop: back_prop, | |||
swap_memory: swap_memory); | |||
if (loop_context.outer_context == null) | |||
ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context); | |||
var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, | |||
return_same_structure); | |||
if (maximum_iterations != null) | |||
return results[1]; | |||
else | |||
return results[0]; | |||
}); | |||
throw new NotImplementedException("while_loop"); | |||
} | |||
} | |||
@@ -20,6 +20,93 @@ namespace Tensorflow | |||
{ | |||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
/// <summary> | |||
/// Creates or finds a child frame, and makes `data` available to the child frame. | |||
/// </summary> | |||
/// <param name="data"></param> | |||
/// <param name="frame_name"></param> | |||
/// <param name="is_constant"></param> | |||
/// <param name="parallel_iterations"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor enter(Tensor data, string frame_name = "frame_name", bool is_constant = false, int parallel_iterations = 10, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("Enter", name, new | |||
{ | |||
data, | |||
frame_name, | |||
is_constant, | |||
parallel_iterations | |||
}); | |||
return _op.output; | |||
} | |||
/// <summary> | |||
/// Forwards the input to the output. | |||
/// </summary> | |||
/// <param name="input"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor loop_cond(Tensor input, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("LoopCond", name, new { input }); | |||
return _op.output; | |||
} | |||
/// <summary> | |||
/// Makes its input available to the next iteration. | |||
/// </summary> | |||
/// <param name="data"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor ref_next_iteration(Tensor data, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("RefNextIteration", name, new { data }); | |||
return _op; | |||
} | |||
/// <summary> | |||
/// Makes its input available to the next iteration. | |||
/// </summary> | |||
/// <param name="data"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor next_iteration(Tensor data, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("NextIteration", name, new { data }); | |||
return _op; | |||
} | |||
/// <summary> | |||
/// Exits the current frame to its parent frame. | |||
/// </summary> | |||
/// <param name="data"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor ref_exit(Tensor data, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("RefExit", name, new { data }); | |||
return _op; | |||
} | |||
/// <summary> | |||
/// Exits the current frame to its parent frame. | |||
/// </summary> | |||
/// <param name="data"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor _exit(Tensor data, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("Exit", name, new { data }); | |||
return _op; | |||
} | |||
public static Operation no_op(string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("NoOp", name, null); | |||
@@ -516,6 +516,9 @@ namespace Tensorflow | |||
}); | |||
} | |||
public static Tensor minimum<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.minimum(x, y, name: name); | |||
public static Tensor maximum<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.maximum(x, y, name: name); | |||
@@ -416,5 +416,6 @@ namespace Tensorflow | |||
} | |||
} | |||
public int tensor_int_val { get; set; } | |||
} | |||
} |
@@ -8,6 +8,18 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
[TestClass] | |||
public class WhileContextTestCase : PythonTest | |||
{ | |||
/// <summary> | |||
/// https://www.tensorflow.org/api_docs/python/tf/while_loop | |||
/// </summary> | |||
[TestMethod] | |||
public void SimpleWhileLoop() | |||
{ | |||
var i = constant_op.constant(0, name: "i"); | |||
var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); | |||
var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); | |||
var r = control_flow_ops.while_loop(c, b, new[] { i }); | |||
} | |||
private void _testWhileContextHelper(int? maximum_iterations = null) | |||
{ | |||
// TODO: implement missing code dependencies | |||
@@ -17,7 +29,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | |||
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | |||
control_flow_ops.while_loop( | |||
c, b, new[] { i }, maximum_iterations); | |||
c, b, new[] { i }, maximum_iterations: maximum_iterations); | |||
foreach (Operation op in sess.graph.get_operations()) | |||
{ | |||
var control_flow_context = op._get_control_flow_context(); | |||