@@ -55,6 +55,9 @@ namespace Tensorflow
              * is more than one.
              **/
             var grads = new Dictionary<string, List<List<Tensor>>>();
+            Operation[] reachable_to_ops = null;
+            ControlFlowState loop_state = null;
+            Dictionary<string, int> pending_count = null;
             tf_with(ops.name_scope(name, "gradients",
                 values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope =>
@@ -81,7 +84,7 @@ namespace Tensorflow
                 var to_ops = ys.Select(x => x.op).ToList();
                 var from_ops = xs.Select(x => x.op).ToList();
                 var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
-                var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
+                (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);

                 // Add the initial gradients for the ys.
                 foreach (var (y, grad_y) in zip(ys, grad_ys))
@@ -120,126 +123,135 @@ namespace Tensorflow
                 {
                     // generate gradient subgraph for op.
                     var op = queue.Dequeue();
-                    if (op.name == "rnn/while/basic_rnn_cell/Tanh")
+                    if (op.name == "rnn/while/Exit")
                     {
                     }
                     _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
-                    //if (loop_state != null)
-                    //loop_state.EnterGradWhileContext(op, before: true);
+                    if (loop_state != null)
+                        loop_state.EnterGradWhileContext(op, before: true);
                     var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
+                    if (loop_state != null)
+                        loop_state.ExitGradWhileContext(op, before: true);

                     Tensor[] in_grads = null;
                     var is_partitioned_call = _IsPartitionedCall(op);
                     var is_func_call = false;
                     var has_out_grads = out_grads.Exists(x => x != null);
                     if (has_out_grads && !stop_ops.Contains(op))
                     {
                         // A grad_fn must be defined, either as a function or as None
                         // for ops that do not have gradients.
                         Func<Operation, Tensor[], Tensor[]> grad_fn = null;
                         try
                         {
                             grad_fn = ops.get_gradient_function(op);
                         }
                         catch (LookupError)
                         {
                             if (is_func_call)
                             {
                                 if (is_partitioned_call)
                                 {
                                 }
                                 else
                                 {
                                 }
                             }
                             else
                             {
                                 throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
                             }
                         }

                         if (loop_state != null)
                             loop_state.EnterGradWhileContext(op, before: false);

                         if ((is_func_call || grad_fn != null) && has_out_grads)
                         {
                             // NOTE: If _AggregatedGrads didn't compute a value for the i'th
                             // output, it means that the cost does not depend on output[i],
                             // therefore dC/doutput[i] is 0.
                             foreach (var (i, out_grad) in enumerate(out_grads))
                             {
                                 if (out_grad == null &&
                                     (grad_fn == null || _IsTrainable(op.outputs[i])))
                                 {
                                     // Only trainable outputs or outputs for a function call that
                                     // will use SymbolicGradient get a zero gradient. Gradient
                                     // functions should ignore the gradient for other outputs.
                                     if (loop_state != null)
                                         out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
                                     else
                                         out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
                                 }
                             }

                             tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
                             {
                                 if (grad_fn != null)
                                 {
                                     in_grads = _MaybeCompile(grad_scope,
                                         op,
                                         out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
                                         null,
                                         grad_fn);
                                 }
                                 else
                                 {
                                     throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
                                 }
                                 _VerifyGeneratedGradients(in_grads, op);
                                 if (gate_gradients && in_grads.Count(x => x != null) > 1)
                                 {
                                     ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
                                     in_grads = control_flow_ops.tuple(in_grads);
                                 }
                             });
                         }
                         else
                         {
                             // If no grad_fn is defined or none of out_grads is available,
                             // just propagate a list of None backwards.
                             in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                         }
                     }
                     else
                     {
                         in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                     }

                     var inputs = _NonEagerInputs(op, xs).ToList();
                     foreach (var (t_in, in_grad) in zip(inputs, in_grads))
                     {
                         if (in_grad != null)
                         {
                             if (!(in_grad is null) &&
                                 in_grad.Tag == null && // maybe an IndexedSlices
                                 t_in.dtype != TF_DataType.TF_RESOURCE)
                             {
                                 in_grad.set_shape(t_in.TensorShape);
                             }
                             _SetGrad(grads, t_in, in_grad);
                         }
                     }

+                    if (loop_state != null)
+                        loop_state.ExitGradWhileContext(op, before: false);

                     // Update pending count for the inputs of op and enqueue ready ops.
                     _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
                 }
             });

+            if (loop_state != null)
+                loop_state.PostProcessing();

             return xs.Select(x => _GetGrad(grads, x)).ToArray();
         }
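Note: the first two hunks hoist `reachable_to_ops`, `loop_state`, and `pending_count` out of the `tf_with` scope and drop the `var` from the `_PendingCount` destructuring, presumably so that `loop_state` survives for the `PostProcessing()` call added after the scope closes. A minimal sketch of the C# scoping rule involved (placeholder body, not the PR's code):

    ControlFlowState loop_state = null;              // declared in the enclosing method
    tf_with(ops.name_scope("gradients"), scope =>
    {
        // plain assignment: writing `var (.., loop_state, ..) = ...` here would
        // introduce new lambda-local variables and the outer ones would stay null
        loop_state = new ControlFlowState();
    });
    if (loop_state != null)
        loop_state.PostProcessing();                 // still reachable after the scope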
@@ -50,10 +50,11 @@ namespace Tensorflow.Layers
         public virtual (Tensor, Tensor) apply(Tensor inputs, Tensor training = null)
         {
-            return __call__(inputs, training: training);
+            var results = __call__(inputs, training: training);
+            return (results[0], results[1]);
         }

-        public (Tensor, Tensor) __call__(Tensor inputs,
+        public Tensor[] __call__(Tensor inputs,
             Tensor training = null,
             Tensor state = null,
             VariableScope scope = null)
@@ -73,7 +74,7 @@ namespace Tensorflow.Layers | |||||
auxiliary_name_scope: false); | auxiliary_name_scope: false); | ||||
} | } | ||||
(Tensor, Tensor) outputs = (null, null); | |||||
Tensor[] outputs = null; | |||||
tf_with(scope_context_manager, scope2 => | tf_with(scope_context_manager, scope2 => | ||||
{ | { | ||||
_current_scope = scope2; | _current_scope = scope2; | ||||
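Note: switching `__call__` from a fixed `(Tensor, Tensor)` tuple to `Tensor[]` lets a layer return a variable number of outputs (e.g. an RNN cell's output plus its state) behind one type; `apply` above unpacks the first two. A hypothetical call site, assuming `cell` is a layer that yields an output and a carried state:

    var results = cell.__call__(inputs, state: prev_state);
    Tensor output = results[0];      // layer output
    Tensor new_state = results[1];   // carried state, when the layer produces one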
@@ -151,27 +151,50 @@ namespace Tensorflow
         /// <param name="colocate_gradients_with_ops"></param>
         public static ControlFlowState MaybeCreateControlFlowState(List<Operation> between_op_list, List<Operation> between_ops, bool colocate_gradients_with_ops)
         {
+            var flag = new List<Operation>();
             ControlFlowState loop_state = null;
-            foreach (var op in between_op_list)
+            int pos = 0;
+            while (pos < between_op_list.Count)
             {
+                var op = between_op_list[pos];
                 if (IsLoopExit(op))
                 {
                     if (loop_state == null)
                     {
                         loop_state = new ControlFlowState();
                     }
+                    if (colocate_gradients_with_ops)
+                        ops.colocate_with(op);
+                    loop_state.AddWhileContext(op, between_op_list, between_ops);
                 }
+                pos++;
             }
             return loop_state;
         }
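Note on the foreach-to-indexed-loop change above: `AddWhileContext` receives `between_op_list` itself and can presumably append to it during the walk, which would invalidate a `List<T>` enumerator mid-iteration. A minimal standalone sketch of the difference:

    using System;
    using System.Collections.Generic;

    class ForeachVsIndexed
    {
        static void Main()
        {
            var ops = new List<string> { "Exit", "Add" };

            // foreach: the enumerator throws InvalidOperationException
            // ("Collection was modified") as soon as the list grows.
            try
            {
                foreach (var op in ops)
                    if (op == "Exit") ops.Add("Enter");
            }
            catch (InvalidOperationException e)
            {
                Console.WriteLine(e.Message);
            }

            // indexed loop: Count is re-evaluated every iteration, so items
            // appended mid-walk are simply visited later.
            for (int pos = 0; pos < ops.Count; pos++)
                if (ops[pos] == "Add" && ops.Count < 4) ops.Add("Mul");
        }
    }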
         public static bool IsLoopExit(Operation op)
-        {
-            return op.OpType == "Exit" || op.OpType == "RefExit";
-        }
+            => op.OpType == "Exit" || op.OpType == "RefExit";
+
+        public static bool IsLoopSwitch(Operation op)
+        {
+            if (IsSwitch(op))
+            {
+                var ctxt = op._get_control_flow_context();
+                return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op);
+            }
+            return false;
+        }
+
+        public static bool IsCondSwitch(Operation op)
+        {
+            throw new NotImplementedException("IsCondSwitch");
+        }
+
+        public static bool IsSwitch(Operation op)
+            => op.type == "Switch" || op.type == "RefSwitch";

         public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] control_inputs = null)
         {
             return tf_with(ops.name_scope(name, "tuple", tensors), scope =>
@@ -224,15 +247,10 @@ namespace Tensorflow
             //TODO: missing original code
             //if context.executing_eagerly():
             //    return output_tensor
-            var values = new List<object>();
-            values.AddRange(dependencies);
-            values.Add(output_tensor);
-            return tf_with(ops.name_scope(name, "control_dependency", values), scope =>
+            return tf_with(ops.name_scope(name, "control_dependency", new { dependencies, output_tensor }), scope =>
             {
                 name = scope;
-                // TODO: missing original code
-                //with ops.colocate_with(output_tensor):
+                ops.colocate_with(output_tensor);
                 {
                     return tf_with(ops.control_dependencies(dependencies), ctl =>
                     {
@@ -431,6 +449,7 @@ namespace Tensorflow
             var merges = zip(res_f_flat, res_t_flat)
                 .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
+                .Select(m => (Tensor)m)
                 .ToArray();

             var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges);
@@ -479,6 +498,7 @@ namespace Tensorflow
             var merges = zip(res_f_flat, res_t_flat)
                 .Select(pair => merge(new [] { pair.Item1, pair.Item2 }))
+                .Select(m => (Tensor)m)
                 .ToArray();

             var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges);
@@ -519,7 +539,7 @@ namespace Tensorflow
         /// <param name="inputs">The input tensors, at most one of which is available.</param>
         /// <param name="name">A name for this operation (optional).</param>
         /// <returns></returns>
-        public static Tensor merge(Tensor[] inputs, string name = null)
+        public static MergeOutput merge(Tensor[] inputs, string name = null)
         {
             if (inputs.Any(x => x == null))
                 throw new ValueError($"At least one of the merge inputs is null: {inputs}");
@@ -529,7 +549,7 @@ namespace Tensorflow
                 inputs = inputs.Select(inp =>
                     ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true))
                     .ToArray();
-                return gen_control_flow_ops.merge(inputs, name)[0];
+                return gen_control_flow_ops.merge(inputs, name);
             });
         }
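Note: `merge` now surfaces the full result of the underlying Merge op instead of only its first output, which is why the two `.Select(m => (Tensor)m)` casts were added above. The `MergeOutput` type itself is not shown in this diff; a minimal sketch of what the cast relies on, with member names assumed to mirror TF's Merge op (chosen value plus the index of the input that produced it):

    // Assumed shape of MergeOutput -- not part of this diff.
    public class MergeOutput
    {
        public Tensor output { get; }       // value of the (single) available input
        public Tensor value_index { get; }  // index of the input that was taken

        public MergeOutput(Tensor output, Tensor value_index)
        {
            this.output = output;
            this.value_index = value_index;
        }

        // Lets existing call sites cast a merge result back to a plain Tensor.
        public static explicit operator Tensor(MergeOutput m) => m.output;
    }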
@@ -602,7 +622,7 @@ namespace Tensorflow
         /// <param name="body"></param>
         /// <param name="loop_vars"></param>
         /// <param name="i"></param>
-        public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars,
+        public static TItem while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars,
             TensorShape[] shape_invariants = null,
             int parallel_iterations = 10,
             bool back_prop = true,
@@ -611,7 +631,7 @@ namespace Tensorflow
             Tensor maximum_iterations = null,
             bool return_same_structure = false)
         {
-            tf_with(ops.name_scope(name, "while", loop_vars), scope =>
+            return tf_with(ops.name_scope(name, "while", loop_vars), scope =>
             {
                 if (loop_vars == null)
                     throw new ValueError("No loop variables provided");
@@ -666,13 +686,11 @@ namespace Tensorflow
                 var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars_1, shape_invariants,
                     return_same_structure);
-                if (maximum_iterations != null)
-                    return results[1];
-                else
-                    return results[0];
+                //if (maximum_iterations != null)
+                return results.Item;
+                //else
+                //return results;
             });
-
-            throw new NotImplementedException("while_loop");
         }
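Note: with `while_loop` now generic in its return type, a caller looping over a single tensor gets a `Tensor` back directly instead of a fixed tuple. A hypothetical usage sketch with `TItem = Tensor` (assuming the usual `tf.constant` and `gen_math_ops` bindings; not taken from this PR):

    // Builds: i = 0; while (i < 10) i = i + 1;  -- result is a Tensor.
    var ten = tf.constant(10);
    Tensor result = control_flow_ops.while_loop(
        cond: (Tensor i) => gen_math_ops.less(i, ten),
        body: (Tensor i) => gen_math_ops.add(i, tf.constant(1)),
        loop_vars: tf.constant(0));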
         /// <summary>
@@ -15,6 +15,7 @@
  ******************************************************************************/

 using System;
+using System.Linq;
 using Tensorflow.Operations;
 using static Tensorflow.Binding;
@@ -60,6 +61,45 @@ namespace Tensorflow
         public static bool IsSwitch(Operation op)
         {
             return op.type == "Switch" || op.type == "RefSwitch";
+        }
+
+        public static WhileContext GetWhileContext(Operation op)
+            => op.GetWhileContext();
+
+        public static bool IsCondSwitch(Operation op)
+        {
+            if (!IsSwitch(op))
+                return false;
+            if (op.outputs == null || op.outputs.Length == 0)
+                return false;
+
+            // Switch nodes are not part of the cond control flow context that they
+            // represent, so consider the consumers of its outputs to determine if it
+            // is a cond switch or not. A switch is a cond switch iff all its consumers
+            // are in cond contexts.
+            var is_cond_switch = true;
+            foreach (var o in op.outputs)
+            {
+                foreach (var c in o.consumers())
+                {
+                    var ctxt = c._get_control_flow_context();
+                    if (IsLoopEnter(c))
+                        ctxt = ctxt.outer_context;
+                    is_cond_switch = is_cond_switch && (ctxt != null && ctxt.IsCondContext());
+                }
+            }
+            return is_cond_switch;
+        }
+
+        public static bool IsLoopSwitch(Operation op)
+        {
+            if (IsSwitch(op))
+            {
+                var ctxt = op._get_control_flow_context();
+                return ctxt != null && ctxt.IsWhileContext() && !IsCondSwitch(op);
+            }
+            return false;
         }

         /// <summary>
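Note: to make the consumer-based rule in `IsCondSwitch` concrete, a hypothetical sketch (not from this PR; `cond`-style bindings assumed):

    // A cond builds Switch nodes whose consumers all live in CondContexts;
    // a while_loop builds Switch nodes consumed inside its WhileContext.
    var z = control_flow_ops.cond(pred,
        () => gen_math_ops.add(x, one),    // true branch, built in a CondContext
        () => gen_math_ops.add(x, two));   // false branch, likewise

    // For each Switch op created by the cond above:
    //   IsCondSwitch(op) == true   (every consumer sits in a cond context)
    //   IsLoopSwitch(op) == false  (its own context is not a while context)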
@@ -87,13 +127,64 @@ namespace Tensorflow
                 valid = true;
             else
             {
-                throw new NotImplementedException("");
+                var while_ctxt = GetContainingWhileContext(op_ctxt);
+                var input_while_ctxt = GetContainingWhileContext(input_ctxt);
+                if (while_ctxt == null)
+                {
+                    throw new NotImplementedException("CheckInputFromValidContext");
+                }
+                else if (IsContainingContext(while_ctxt, input_while_ctxt))
+                {
+                    // input_op is in a while loop which contains op's while loop (or not
+                    // in a while loop at all).
+                    valid = true;
+                }
+                else if (while_ctxt.grad_state != null &&
+                    IsContainingContext(while_ctxt.grad_state.forward_context,
+                        input_while_ctxt))
+                {
+                    valid = true;
+                }
+                else
+                    throw new NotImplementedException("CheckInputFromValidContext");
             }

             if (!valid)
             {
-                throw new NotImplementedException("");
+                throw new NotImplementedException("CheckInputFromValidContext");
             }
         }
+
+        public static Operation GetLoopConstantEnter(Tensor value)
+        {
+            var id_ops = new string[] { "Switch", "RefSwitch", "Identity", "RefIdentity" };
+            var op = value.op;
+            while (id_ops.Contains(op.type))
+                op = op.inputs[0].op;
+            return IsLoopConstantEnter(op) ? op : null;
+        }
+
+        public static bool IsContainingContext(WhileContext ctxt, WhileContext maybe_containing_ctxt)
+        {
+            // Walk the full outer-context chain as a ControlFlowContext: casting each
+            // step to WhileContext would bail out early at an intervening cond context.
+            ControlFlowContext cur = ctxt;
+            while (cur != maybe_containing_ctxt)
+            {
+                if (cur == null)
+                    return false;
+                cur = cur.outer_context;
+            }
+            return true;
+        }
+
+        public static WhileContext GetContainingWhileContext(ControlFlowContext ctxt, ControlFlowContext stop_ctxt = null)
+        {
+            while (ctxt != null)
+            {
+                if (ctxt.IsWhileContext() || ctxt == stop_ctxt)
+                    return ctxt as WhileContext;
+                ctxt = ctxt.outer_context;
+            }
+            return null;
+        }
     }
 }
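Note: a worked example of the two context walkers, under an assumed nesting of contexts (not from this PR):

    // Hypothetical nesting: while_A  >  cond_B  >  while_C (innermost).
    //
    // GetContainingWhileContext(while_C)     -> while_C  (a while context answers itself)
    // GetContainingWhileContext(cond_B)      -> while_A  (nearest enclosing while)
    // IsContainingContext(while_C, while_A)  -> true     (outer chain reaches while_A via cond_B)
    // IsContainingContext(while_A, while_C)  -> false    (while_C is not on while_A's outer chain)
    // IsContainingContext(while_C, null)     -> true     (the "not in a while loop at all" case)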
@@ -159,6 +159,8 @@ namespace Tensorflow
             });
         }

+        public static Tensor greater_equal<Tx, Ty>(Tx x, Ty y, string name = null)
+            => gen_math_ops.greater_equal<Tx, Ty>(x, y, name: name);
         public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null)
             => gen_math_ops.equal(x, y, name: name);
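Note: the new wrapper simply forwards to `gen_math_ops.greater_equal`. A hypothetical usage sketch (input values assumed):

    // Element-wise x >= y, yielding a boolean tensor.
    var x = tf.constant(new[] { 1, 5, 3 });
    var y = tf.constant(new[] { 2, 2, 3 });
    Tensor ge = math_ops.greater_equal(x, y);   // -> [false, true, true]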