@@ -19,5 +19,16 @@ namespace Tensorflow
/// </returns>
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
=> array_ops.expand_dims(input, axis, name, dim);
/// <summary>
/// Transposes `a`. Permutes the dimensions according to `perm`.
/// </summary>
/// <param name="a">A `Tensor` to transpose.</param>
/// <param name="perm">A permutation of the dimensions of `a`; defaults to reversing all dimensions.</param>
/// <param name="name">A name for the operation.</param>
/// <param name="conjugate">If true and `a` is a complex tensor, the values are conjugated as well as transposed.</param>
/// <returns>A transposed `Tensor`.</returns>
public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false)
=> array_ops.transpose(a, perm, name, conjugate);
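// Hypothetical usage sketch (not part of this change): with perm left null
// all dimensions are reversed, so a [2, 3] tensor becomes [3, 2]; an
// explicit perm selects the axis order, e.g.
//   var y = tf.transpose(x, new int[] { 1, 0 }); // swap the two axes of x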
}
}
@@ -46,6 +46,45 @@ namespace Tensorflow
return layer.apply(inputs);
}
/// <summary>
/// Functional interface for the batch normalization layer.
/// http://arxiv.org/abs/1502.03167
/// </summary>
/// <param name="inputs">Tensor input.</param>
/// <param name="axis">The axis that should be normalized (typically the features axis).</param>
/// <param name="momentum">Momentum for the moving average.</param>
/// <param name="epsilon">Small float added to variance to avoid dividing by zero.</param>
/// <param name="center">If true, add the offset `beta` to the normalized tensor.</param>
/// <param name="scale">If true, multiply by `gamma`.</param>
/// <param name="beta_initializer">Initializer for the beta weight.</param>
/// <param name="gamma_initializer">Initializer for the gamma weight.</param>
/// <param name="moving_mean_initializer">Initializer for the moving mean.</param>
/// <param name="moving_variance_initializer">Initializer for the moving variance.</param>
/// <param name="training">A boolean `Tensor` indicating whether to normalize with the statistics of the current batch (training) or with the moving statistics (inference).</param>
/// <param name="trainable">If true, the variables are added to the TRAINABLE_VARIABLES collection.</param>
/// <param name="name">String, the name of the layer.</param>
/// <param name="renorm">Whether to use batch renormalization (https://arxiv.org/abs/1702.03275).</param>
/// <param name="renorm_momentum">Momentum used to update the moving means and standard deviations with renorm.</param>
/// <returns>Output tensor.</returns>
public static Tensor batch_normalization(Tensor inputs,
int axis = -1,
float momentum = 0.99f,
float epsilon = 0.001f,
bool center = true,
bool scale = true,
IInitializer beta_initializer = null,
IInitializer gamma_initializer = null,
IInitializer moving_mean_initializer = null,
IInitializer moving_variance_initializer = null,
Tensor training = null,
bool trainable = true,
string name = null,
bool renorm = false,
float renorm_momentum = 0.99f)
{
throw new NotImplementedException("batch_normalization");
}
}
}
}
@@ -30,6 +30,7 @@ namespace Tensorflow.Keras.Engine
VariableScope scope = null)
{
var input_list = new Tensor[] { inputs };
Tensor outputs = null;
// We will attempt to build a TF graph if & only if all inputs are symbolic.
// This is always the case in graph mode. It can also be the case in eager
@@ -45,9 +46,42 @@ namespace Tensorflow.Keras.Engine
_maybe_build(inputs);
built = true;
}
if (build_graph)
{
// Symbolic execution on symbolic tensors. We will attempt to build
// the corresponding TF subgraph inside `backend.get_graph()`
var graph = backend.get_graph();
outputs = call(inputs);
_handle_activity_regularization(inputs, outputs);
_set_mask_metadata(inputs, outputs, null);
}
});
return outputs;
}
private void _handle_activity_regularization(Tensor inputs, Tensor outputs)
{
//if (_activity_regularizer != null)
{
}
}
private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask)
{
}
private Tensor compute_mask(Tensor inputs, Tensor mask = null)
{
return null;
}
protected virtual Tensor call(Tensor inputs)
{
throw new NotImplementedException("Layer.call");
}
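// Subclasses define the forward pass by overriding `call`; Conv.call in
// the next hunk is one such override (convolution, then optional bias
// and activation).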
protected virtual string _name_scope()
@@ -90,5 +90,26 @@ namespace Tensorflow.Keras.Layers
built = true;
}
protected override Tensor call(Tensor inputs)
{
var outputs = _convolution_op.__call__(inputs, kernel);
if (use_bias)
{
if (data_format == "channels_first")
{
throw new NotImplementedException("call channels_first");
}
else
{
outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC");
}
}
if (activation != null)
return activation.Activate(outputs);
return outputs;
}
}
}
@@ -10,5 +10,10 @@ namespace Tensorflow.Keras
{
}
public static Graph get_graph()
{
return ops.get_default_graph();
}
}
}
@@ -65,7 +65,10 @@ namespace Tensorflow.Layers
// Actually call layer
var outputs = base.__call__(inputs);
// Update global default collections.
//_add_elements_to_collection(updates, ops.GraphKeys.UPDATE_OPS);
return outputs;
}
protected virtual RefVariable add_weight(string name,
@@ -6,6 +6,6 @@ namespace Tensorflow.Operations.Activation
{
public interface IActivation
{
Tensor Activate(Tensor features, string name = null);
}
}
@@ -6,6 +6,16 @@ namespace Tensorflow.Operations.Activation
{
public class relu : IActivation
{
public Tensor Activate(Tensor features, string name = null)
{
OpDefLibrary _op_def_lib = new OpDefLibrary();
var _op = _op_def_lib._apply_op_helper("Relu", name: name, args: new
{
features
});
return _op.outputs[0];
}
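// Usage sketch: tf.nn.relu is exposed as an IActivation instance, so it
// is either handed to a layer (activation: tf.nn.relu) or applied
// directly, as the VD-CNN example below does:
//   var y = tf.nn.relu.Activate(x); // emits a "Relu" op into the graph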
}
}
@@ -62,5 +62,10 @@ namespace Tensorflow.Operations
strides: strides,
name: name);
}
public Tensor __call__(Tensor inp, RefVariable filter)
{
return conv_op.__call__(inp, filter);
}
}
}
@@ -52,5 +52,18 @@ namespace Tensorflow.Operations
throw new NotImplementedException("_NonAtrousConvolution conv_dims 3");
}
}
public Tensor __call__(Tensor inp, RefVariable filter)
{
return conv_op(new
{
input = inp,
filter,
strides,
padding,
data_format,
name
});
}
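// For the 2-D case conv_op is expected to be bound to gen_nn_ops.conv2d
// (next hunks), which consumes exactly this anonymous-object argument
// bag; that binding is an assumption read from the surrounding code, not
// part of this hunk.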
}
}
@@ -51,5 +51,10 @@ namespace Tensorflow.Operations
}
}
}
public Tensor __call__(Tensor inp, RefVariable filter)
{
return call.__call__(inp, filter);
}
}
}
@@ -6,9 +6,51 @@ namespace Tensorflow.Operations
{
public class gen_nn_ops
{
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
public static Tensor conv2d(object parameters)
{
var args = Python.ConvertToDict(parameters);
var input = args["input"];
var filter = args["filter"];
var strides = args["strides"];
var padding = args["padding"];
var name = args["name"];
var data_format = args.ContainsKey("data_format") ? args["data_format"] : "NHWC";
var use_cudnn_on_gpu = args.ContainsKey("use_cudnn_on_gpu") ? args["use_cudnn_on_gpu"] : true;
var dilations = args.ContainsKey("dilations") ? args["dilations"] : new int[] { 1, 1, 1, 1 };
var _op = _op_def_lib._apply_op_helper("Conv2D", name: name?.ToString(), args: new
{
input,
filter,
strides,
padding,
use_cudnn_on_gpu,
data_format,
dilations
});
return _op.outputs[0];
}
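// The parameters bag is expected to carry at least input, filter,
// strides, padding and name; data_format, use_cudnn_on_gpu and dilations
// fall back to "NHWC", true and {1, 1, 1, 1} when absent, matching the
// Conv2D op defaults.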
public static Tensor bias_add(Tensor value,
Tensor bias,
string data_format = null,
string name = null)
{
if (data_format == null)
data_format = "NHWC";
var _op = _op_def_lib._apply_op_helper("BiasAdd", name: name, args: new
{
value,
bias,
data_format
});
return _op.outputs[0];
}
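// BiasAdd broadcasts the 1-D bias over the channel dimension of value;
// with the default "NHWC" layout that is the last axis.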
}
}
@@ -272,5 +272,14 @@ namespace Tensorflow
{
return gen_array_ops.gather_v2(@params, indices, axis, name: name);
}
public static Tensor transpose(Tensor a, int[] perm = null, string name = "transpose", bool conjugate = false)
{
return with(ops.name_scope(name, "transpose", new { a }), scope =>
{
name = scope;
return gen_array_ops.transpose(a, perm, name);
});
}
}
}
@@ -157,6 +157,12 @@ namespace Tensorflow
return _op.outputs[0];
}
public static Tensor transpose(Tensor x, int[] perm, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Transpose", name, new { x, perm });
return _op.outputs[0];
}
public static Tensor zeros_like(Tensor x, string name = null)
{
var _op = _op_def_lib._apply_op_helper("ZerosLike", name, new { x });
@@ -20,5 +20,26 @@ namespace Tensorflow
dilation_rate,
name: name,
data_format: data_format);
/// <summary>
/// Adds `bias` to `value`.
/// </summary>
/// <param name="value">A `Tensor`.</param>
/// <param name="bias">A 1-D variable with size matching the last dimension of `value`.</param>
/// <param name="data_format">A string, "NHWC" or "NCHW".</param>
/// <param name="name">A name for the operation.</param>
/// <returns>A `Tensor` with the same type as `value`.</returns>
public static Tensor bias_add(Tensor value,
RefVariable bias,
string data_format = null,
string name = null)
{
return Python.with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope =>
{
value = ops.convert_to_tensor(value, name: "input");
var bias_tensor = ops.convert_to_tensor(bias, dtype: value.dtype, name: "bias");
return gen_nn_ops.bias_add(value, bias_tensor, data_format: data_format, name: name);
});
}
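// This is the overload Conv.call uses above:
//   outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC");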
}
}
@@ -40,6 +40,10 @@ namespace Tensorflow
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
/// </summary>
public static string SAVEABLE_OBJECTS = "saveable_objects";
/// <summary>
/// Key to collect update_ops
/// </summary>
public static string UPDATE_OPS = "update_ops";
}
}
}
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow;
@@ -43,14 +44,61 @@ namespace TensorFlowNET.Examples.TextClassification
x_expanded = tf.expand_dims(x_emb, -1);
});
Tensor conv0 = null;
Tensor conv1 = null;
// First Convolution Layer
with(tf.variable_scope("conv-0"), delegate
{
conv0 = tf.layers.conv2d(x_expanded,
filters: num_filters[0],
kernel_size: new int[] { filter_sizes[0], embedding_size },
kernel_initializer: cnn_initializer,
activation: tf.nn.relu);
conv0 = tf.transpose(conv0, new int[] { 0, 1, 3, 2 });
});
with(tf.name_scope("conv-block-1"), delegate {
conv1 = conv_block(conv0, 1);
});
}
private Tensor conv_block(Tensor input, int i, bool max_pool = true)
{
return with(tf.variable_scope($"conv-block-{i}"), delegate
{
Tensor conv = null;
// Two "conv-batch_norm-relu" layers.
foreach (var j in Enumerable.Range(0, 2))
{
with(tf.variable_scope($"conv-{j}"), delegate
{
// convolution
conv = tf.layers.conv2d(
input,
filters: num_filters[i],
kernel_size: new int[] { filter_sizes[i], num_filters[i - 1] },
kernel_initializer: cnn_initializer,
activation: null);
// batch normalization
conv = tf.layers.batch_normalization(conv, training: is_training);
// relu
conv = tf.nn.relu.Activate(conv);
conv = tf.transpose(conv, new int[] { 0, 1, 3, 2 });
});
}
if (max_pool)
{
// Max pooling
throw new NotImplementedException("conv_block");
}
else
{
return conv;
}
});
}
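// Note on the [0, 1, 3, 2] transpose above: it moves the new filter axis
// back into the "width" position, which is what lets a conv-block's
// kernels of width num_filters[i - 1] line up with the previous block's
// output. This reading is inferred from the shapes, not stated in the
// original example.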
}