Browse Source

CondContext, BatchNormalization.

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
48a11d4710
16 changed files with 276 additions and 13 deletions
  1. +14
    -0
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  2. +4
    -1
      src/TensorFlowNET.Core/Framework/smart_module.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
  5. +17
    -4
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  6. +4
    -1
      src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs
  7. +76
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  8. +46
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  9. +10
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
  10. +10
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
  11. +25
    -0
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  12. +8
    -0
      src/TensorFlowNET.Core/Operations/Operation.Control.cs
  13. +20
    -3
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  14. +32
    -0
      src/TensorFlowNET.Core/Operations/nn_impl.py.cs
  15. +3
    -0
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  16. +3
    -0
      src/TensorFlowNET.Core/ops.GraphKeys.cs

+ 14
- 0
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -26,6 +26,20 @@ namespace Tensorflow
name: name);

public static IActivation relu => new relu();

public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x,
RefVariable scale,
RefVariable offset,
Tensor mean = null,
Tensor variance = null,
float epsilon = 0.001f,
string data_format = "NHWC",
bool is_training = true,
string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance,
epsilon: epsilon,
data_format: data_format,
is_training: is_training,
name: name);
}
}
}

+ 4
- 1
src/TensorFlowNET.Core/Framework/smart_module.cs View File

@@ -6,7 +6,10 @@ namespace Tensorflow.Framework
{
public class smart_module
{
public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null)
public static object smart_cond(Tensor pred,
Func<(Tensor, Tensor, Tensor)> true_fn = null,
Func<(Tensor, Tensor, Tensor)> false_fn = null,
string name = null)
{
return control_flow_ops.cond(pred,
true_fn: true_fn,


+ 3
- 3
src/TensorFlowNET.Core/Graphs/Graph.Control.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow
{
public partial class Graph
{
public Context _control_flow_context;
public IControlFlowContext _control_flow_context;

private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>();
public Queue<_ControlDependenciesController> _control_dependencies_stack
@@ -72,7 +72,7 @@ namespace Tensorflow
/// Returns the current control flow context.
/// </summary>
/// <returns>A context object.</returns>
public Context _get_control_flow_context()
public IControlFlowContext _get_control_flow_context()
{
return _control_flow_context;
}
@@ -81,7 +81,7 @@ namespace Tensorflow
/// Sets the current control flow context.
/// </summary>
/// <param name="ctx">a context object.</param>
public void _set_control_flow_context(Context ctx)
public void _set_control_flow_context(IControlFlowContext ctx)
{
_control_flow_context = ctx;
}


+ 1
- 1
src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow
private List<ITensorOrOperation> _seen_nodes;
private Queue<_ControlDependenciesController> _old_stack;
private bool _new_stack;
private Context _old_control_flow_context;
private IControlFlowContext _old_control_flow_context;

public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray();



+ 17
- 4
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -142,14 +142,27 @@ namespace Tensorflow.Keras.Layers
var beta = this.beta;
var gamma = this.gamma;

Action _fused_batch_norm_training = () =>
Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_training = () =>
{

return tf.nn.fused_batch_norm(
inputs,
gamma,
beta,
epsilon: epsilon,
data_format: _data_format);
};

Action _fused_batch_norm_inference = () =>
Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_inference = () =>
{

return tf.nn.fused_batch_norm(
inputs,
gamma,
beta,
mean: moving_mean,
variance: moving_variance,
epsilon: epsilon,
is_training: false,
data_format: _data_format);
};

tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference);


+ 4
- 1
src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs View File

@@ -18,7 +18,10 @@ namespace Tensorflow.Keras.Utils
return true;
}

public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null)
public static object smart_cond(Tensor pred,
Func<(Tensor, Tensor, Tensor)> true_fn = null,
Func<(Tensor, Tensor, Tensor)> false_fn = null,
string name = null)
{
return smart_module.smart_cond(pred,
true_fn: true_fn,


+ 76
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -0,0 +1,76 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
/// <summary>
/// The context for the conditional construct.
/// </summary>
public class CondContext : ControlFlowContext
{
private string _name;
/// <summary>
/// The boolean tensor for the cond predicate
/// </summary>
private Tensor _pred;
/// <summary>
/// The predicate tensor in this branch
/// </summary>
private Tensor _pivot;
/// <summary>
/// 0 or 1 representing this branch
/// </summary>
private int _branch;
/// <summary>
///
/// </summary>
private List<string> _values = new List<string>();
private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>();

/// <summary>
///
/// </summary>
/// <param name="pred">The `boolean` tensor for the conditional predicate.</param>
/// <param name="pivot">The predicate tensor in this branch.</param>
/// <param name="branch">0 or 1 representing this branch.</param>
/// <param name="name">Name of the `CondContext` python object.</param>
/// <param name="context_def"></param>
/// <param name="import_scope"></param>
public CondContext(Tensor pred,
Tensor pivot,
int branch,
string name = "cond_text",
object context_def = null,
string import_scope = null)
{
_name = ops.get_default_graph().unique_name(name);
if (context_def != null)
throw new NotImplementedException("CondContext context_def is not null");
else
{
// Initializes the default fields.
base.__init__();
_pred = pred;
_pivot = pivot;

// Values considered to have been already seen in this context. pred is not
// included in this context.
_values.Add(pred.name);
_external_values[pred.name] = pred;
_values.Add(pivot.name);
pivot.op._set_control_flow_context(this);
}
}

public (Tensor, Tensor, Tensor) BuildCondBranch(Func<(Tensor, Tensor, Tensor)> fn)
{
// Add the subgraph defined by fn() to the graph.
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);

return original_result;
}
}
}

+ 46
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -0,0 +1,46 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
public abstract class ControlFlowContext : IPython, IControlFlowContext
{
protected Stack<IControlFlowContext> _context_stack;
public ControlFlowContext()
{
_context_stack = new Stack<IControlFlowContext>();
}

public void __init__()
{

}

public void __enter__()
{
}

public virtual void Enter()
{
var graph = ops.get_default_graph();
_context_stack.Push(graph._get_control_flow_context());
graph._set_control_flow_context(this);
}

public void Exit()
{
var graph = ops.get_default_graph();
var last_context = _context_stack.Pop();
graph._set_control_flow_context(last_context);
}

public void __exit__()
{
}

public void Dispose()
{
}
}
}

+ 10
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public interface IControlFlowContext
{
}
}

+ 10
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
public class WhileContext : ControlFlowContext
{
}
}

+ 25
- 0
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -52,5 +52,30 @@ namespace Tensorflow.Operations

return _op.outputs[0];
}

public static (Tensor, Tensor, Tensor) _fused_batch_norm(Tensor x,
Tensor scale,
Tensor offset,
Tensor mean,
Tensor variance,
float epsilon = 0.0001f,
string data_format = "NHWC",
bool is_training = true,
string name = null)
{
var _op = _op_def_lib._apply_op_helper("FusedBatchNorm", name: name, args: new
{
x,
scale,
offset,
mean,
variance,
epsilon,
data_format,
is_training
});

return (_op.outputs[0], _op.outputs[1], _op.outputs[2]);
}
}
}

+ 8
- 0
src/TensorFlowNET.Core/Operations/Operation.Control.cs View File

@@ -1,11 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;

namespace Tensorflow
{
public partial class Operation
{
private CondContext _control_flow_context;

/// <summary>
/// Add this op to its control flow context.
/// </summary>
@@ -24,5 +27,10 @@ namespace Tensorflow
c_api.TF_AddControlInput(graph, op);
}
}

public void _set_control_flow_context(CondContext ctx)
{
_control_flow_context = ctx;
}
}
}

+ 20
- 3
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations;

namespace Tensorflow
{
@@ -136,9 +137,9 @@ namespace Tensorflow
return gen_array_ops.identity(data, name: name);
}

public static (Tensor, Tensor) cond(Tensor pred,
Action true_fn = null,
Action false_fn = null,
public static (Tensor, Tensor) cond(Tensor pred,
Func<(Tensor, Tensor, Tensor)> true_fn = null,
Func<(Tensor, Tensor, Tensor)> false_fn = null,
bool strict = false,
string name = null)
{
@@ -154,6 +155,22 @@ namespace Tensorflow
foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
tensor.op.graph.prevent_fetching(tensor.op);

// Build the graph for the true branch in a new context.
var context_t = new CondContext(pred, pivot_1, branch: 1);
context_t.Enter();
var res_t = context_t.BuildCondBranch(true_fn);
context_t.Exit();

// Build the graph for the false branch in a new context.
var context_f = new CondContext(pred, pivot_2, branch: 0);
context_f.Enter();
var res_f = context_f.BuildCondBranch(false_fn);
context_f.Exit();

var res_t_flat = new Tensor[] { res_t.Item1, res_t.Item2, res_t.Item3 };
var res_f_flat = new Tensor[] { res_f.Item1, res_f.Item2, res_f.Item3 };


return (p_2, p_1);
});
}


+ 32
- 0
src/TensorFlowNET.Core/Operations/nn_impl.py.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;

namespace Tensorflow
{
@@ -44,5 +45,36 @@ namespace Tensorflow
return (mean, variance);
});
}

public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x,
RefVariable scale,
RefVariable offset,
Tensor mean,
Tensor variance,
float epsilon = 0.001f,
string data_format = "NHWC",
bool is_training = true,
string name = null)
{
x = ops.convert_to_tensor(x, name: "input");
var scale_tensor = ops.convert_to_tensor(scale, name: "scale");
var offset_tensor = ops.convert_to_tensor(offset, name: "offset");
if (mean == null)
mean = constant_op.constant(new float[0]);
if(variance == null)
variance = constant_op.constant(new float[0]);
var min_epsilon = 1.001e-5f;
epsilon = epsilon > min_epsilon ? epsilon : min_epsilon;

return gen_nn_ops._fused_batch_norm(x,
scale_tensor,
offset_tensor,
mean,
variance,
epsilon,
data_format,
is_training,
name);
}
}
}

+ 3
- 0
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -107,6 +107,9 @@ namespace Tensorflow
case float floatVal:
nparray = floatVal;
break;
case float[] floatVals:
nparray = floatVals;
break;
case double doubleVal:
nparray = doubleVal;
break;


+ 3
- 0
src/TensorFlowNET.Core/ops.GraphKeys.cs View File

@@ -44,6 +44,9 @@ namespace Tensorflow
/// Key to collect update_ops
/// </summary>
public static string UPDATE_OPS = "update_ops";

// Used to store v2 summary names.
public static string _SUMMARY_COLLECTION = "_SUMMARY_V2";
}
}
}

Loading…
Cancel
Save