@@ -34,9 +34,11 @@ namespace Tensorflow | |||
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1) | |||
=> gen_array_ops.squeeze(input, axis, name); | |||
public static Tensor one_hot(Tensor indices, int depth) | |||
{ | |||
throw new NotImplementedException("one_hot"); | |||
} | |||
public static Tensor one_hot(Tensor indices, int depth, | |||
Tensor on_value = null, | |||
Tensor off_value = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
int axis = -1, | |||
string name = null) => array_ops.one_hot(indices, depth, dtype: dtype, axis: axis, name: name); | |||
} | |||
} |
@@ -0,0 +1,12 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public static partial class tf | |||
{ | |||
public static _ControlDependenciesController control_dependencies(Operation[] control_inputs) | |||
=> ops.control_dependencies(control_inputs); | |||
} | |||
} |
@@ -36,6 +36,9 @@ namespace Tensorflow | |||
public static Tensor reduce_sum(Tensor input, int[] axis = null) | |||
=> math_ops.reduce_sum(input); | |||
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) | |||
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name); | |||
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) | |||
=> math_ops.cast(x, dtype, name); | |||
@@ -46,6 +46,18 @@ namespace Tensorflow | |||
public static Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null) | |||
=> gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name); | |||
public static Tensor bias_add(Tensor value, RefVariable bias, string data_format = null, string name = null) | |||
{ | |||
return Python.with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope => | |||
{ | |||
name = scope; | |||
return gen_nn_ops.bias_add(value, bias, data_format: data_format, name: name); | |||
}); | |||
} | |||
public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null) | |||
=> nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name); | |||
} | |||
} | |||
} |
@@ -1,6 +1,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.Layers; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -3,11 +3,10 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.Layers; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public class BatchNormalization : Layer | |||
public class BatchNormalization : Tensorflow.Layers.Layer | |||
{ | |||
private bool _USE_V2_BEHAVIOR = true; | |||
private float momentum; | |||
@@ -4,6 +4,7 @@ using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Operations.Activation; | |||
using static Tensorflow.tf; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
@@ -55,5 +56,26 @@ namespace Tensorflow.Keras.Layers | |||
built = true; | |||
} | |||
protected override Tensor call(Tensor inputs, Tensor training = null) | |||
{ | |||
Tensor outputs = null; | |||
var rank = inputs.rank; | |||
if(rank > 2) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
else | |||
{ | |||
outputs = gen_math_ops.mat_mul(inputs, kernel); | |||
} | |||
if (use_bias) | |||
outputs = nn.bias_add(outputs, bias); | |||
if (activation != null) | |||
return activation.Activate(outputs); | |||
return outputs; | |||
} | |||
} | |||
} |
@@ -2,9 +2,10 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Utils; | |||
namespace Tensorflow.Keras.Engine | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
/// <summary> | |||
/// Base layer class. | |||
@@ -106,7 +107,7 @@ namespace Tensorflow.Keras.Engine | |||
protected virtual Tensor call(Tensor inputs, Tensor training = null) | |||
{ | |||
throw new NotImplementedException("Layer.call"); | |||
return inputs; | |||
} | |||
protected virtual string _name_scope() |
@@ -2,11 +2,10 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow.Layers | |||
{ | |||
public class Layer : Keras.Engine.Layer | |||
public class Layer : Keras.Layers.Layer | |||
{ | |||
protected Graph _graph; | |||
@@ -108,5 +108,23 @@ namespace Tensorflow.Operations | |||
return _op.outputs; | |||
} | |||
/// <summary> | |||
/// Computes softmax cross entropy cost and gradients to backpropagate. | |||
/// </summary> | |||
/// <param name="features"></param> | |||
/// <param name="labels"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static (Tensor, Tensor) softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("SoftmaxCrossEntropyWithLogits", name: name, args: new | |||
{ | |||
features, | |||
labels | |||
}); | |||
return (_op.outputs[0], _op.outputs[1]); | |||
} | |||
} | |||
} |
@@ -46,6 +46,22 @@ namespace Tensorflow | |||
} | |||
} | |||
public static Tensor _autopacking_helper(Tensor[] list_or_tuple, TF_DataType dtype, string name) | |||
{ | |||
var must_pack = false; | |||
var converted_elems = new List<Tensor>(); | |||
return with(ops.name_scope(name), scope => | |||
{ | |||
foreach (var (i, elem) in enumerate(list_or_tuple)) | |||
{ | |||
converted_elems.Add(elem); | |||
must_pack = true; | |||
} | |||
return gen_array_ops.pack(converted_elems.ToArray(), name: scope); | |||
}); | |||
} | |||
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) => expand_dims_v2(input, axis, name); | |||
private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) => gen_array_ops.expand_dims(input, axis, name); | |||
@@ -109,6 +125,44 @@ namespace Tensorflow | |||
}); | |||
} | |||
public static Tensor one_hot(Tensor indices, int depth, | |||
Tensor on_value = null, | |||
Tensor off_value = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
int axis = -1, | |||
string name = null) | |||
{ | |||
return with(ops.name_scope(name, "one_hot", new { indices, depth, dtype }), scope => | |||
{ | |||
name = scope; | |||
var on_exists = false; | |||
var off_exists = false; | |||
var on_dtype = TF_DataType.DtInvalid; | |||
var off_dtype = TF_DataType.DtInvalid; | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = TF_DataType.TF_FLOAT; | |||
if(!on_exists) | |||
{ | |||
on_value = ops.convert_to_tensor(1, dtype, name: "on_value"); | |||
on_dtype = dtype; | |||
} | |||
if (!off_exists) | |||
{ | |||
off_value = ops.convert_to_tensor(0, dtype, name = "off_value"); | |||
off_dtype = dtype; | |||
} | |||
return gen_array_ops.one_hot(indices, depth, | |||
on_value: on_value, | |||
off_value: off_value, | |||
axis: axis, | |||
name: name); | |||
}); | |||
} | |||
public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = null) | |||
{ | |||
if( x == null && y == null) | |||
@@ -298,5 +352,8 @@ namespace Tensorflow | |||
return gen_array_ops.transpose(a, perm, name); | |||
}); | |||
} | |||
public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null) | |||
=> gen_array_ops.slice(input, begin, size, name: name); | |||
} | |||
} |
@@ -40,6 +40,13 @@ namespace Tensorflow | |||
return _op.outputs[0]; | |||
} | |||
public static Tensor pack(Tensor[] values, int axis = 0, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("Pack", name: name, args: new { values, axis }); | |||
return _op.outputs[0]; | |||
} | |||
public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape }); | |||
@@ -126,6 +133,17 @@ namespace Tensorflow | |||
throw new NotImplementedException("where"); | |||
} | |||
public static Tensor one_hot(Tensor indices, int depth, | |||
Tensor on_value = null, | |||
Tensor off_value = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
int axis = -1, | |||
string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("OneHot", name, new { indices, depth, on_value, off_value, axis }); | |||
return _op.outputs[0]; | |||
} | |||
/// <summary> | |||
/// A placeholder op that passes through `input` when its output is not fed. | |||
/// </summary> | |||
@@ -174,12 +192,20 @@ namespace Tensorflow | |||
var _op = _op_def_lib._apply_op_helper("ZerosLike", name, new { x }); | |||
return _op.outputs[0]; | |||
} | |||
public static Tensor stop_gradient(Tensor x, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("StopGradient", name, args: new { input = x, name }); | |||
return _op.outputs[0]; | |||
} | |||
public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); | |||
return _op.outputs[0]; | |||
} | |||
/// <summary> | |||
/// Removes dimensions of size 1 from the shape of a tensor. | |||
/// Given a tensor `input`, this operation returns a tensor of the same type with | |||
@@ -60,6 +60,11 @@ namespace Tensorflow | |||
return gen_math_ops.square(x, name); | |||
} | |||
public static Tensor subtract<Tx, Ty>(Tx x, Ty y, string name = null) | |||
{ | |||
return gen_math_ops.sub(x, y, name); | |||
} | |||
public static Tensor log(Tensor x, string name = null) | |||
{ | |||
return gen_math_ops.log(x, name); | |||
@@ -41,5 +41,38 @@ namespace Tensorflow | |||
return gen_nn_ops.bias_add(value, bias_tensor, data_format: data_format, name: name); | |||
}); | |||
} | |||
public static Tensor softmax_cross_entropy_with_logits_v2_helper(Tensor labels, | |||
Tensor logits, | |||
int axis = -1, | |||
string name = null) | |||
{ | |||
return Python.with(ops.name_scope(name, "softmax_cross_entropy_with_logits", new { }), scope => | |||
{ | |||
var precise_logits = logits; | |||
var input_rank = array_ops.rank(precise_logits); | |||
var shape = logits.getShape(); | |||
if (axis != -1) | |||
throw new NotImplementedException("softmax_cross_entropy_with_logits_v2_helper axis != -1"); | |||
var input_shape = array_ops.shape(precise_logits); | |||
// Do the actual op computation. | |||
// The second output tensor contains the gradients. We use it in | |||
// _CrossEntropyGrad() in nn_grad but not here. | |||
var (cost, unused_backprop) = gen_nn_ops.softmax_cross_entropy_with_logits(precise_logits, labels, name: name); | |||
// The output cost shape should be the input minus axis. | |||
var output_shape = array_ops.slice(input_shape, | |||
new int[] { 0 }, | |||
new Tensor[] { math_ops.subtract(input_rank, 1) }); | |||
cost = array_ops.reshape(cost, output_shape); | |||
return cost; | |||
}); | |||
} | |||
} | |||
} |
@@ -434,7 +434,8 @@ namespace Tensorflow | |||
public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, | |||
string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid, | |||
bool as_ref = false) | |||
bool as_ref = false, | |||
string scope = null) | |||
{ | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = preferred_dtype; | |||
@@ -443,6 +444,8 @@ namespace Tensorflow | |||
{ | |||
case Tensor tensor: | |||
return tensor; | |||
case Tensor[] tensors: | |||
return array_ops._autopacking_helper(tensors, dtype, name); | |||
case string str: | |||
return constant_op.constant(str, dtype: dtype, name: name); | |||
case string[] strArray: | |||
@@ -2,6 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
using static Tensorflow.ops; | |||
namespace Tensorflow | |||
{ | |||
@@ -22,6 +22,9 @@ namespace TensorFlowNET.Examples.TextClassification | |||
private RefVariable embeddings; | |||
private Tensor x_emb; | |||
private Tensor x_expanded; | |||
private Tensor logits; | |||
private Tensor predictions; | |||
private Tensor loss; | |||
public VdCnn(int alphabet_size, int document_max_len, int num_class) | |||
{ | |||
@@ -55,8 +58,6 @@ namespace TensorFlowNET.Examples.TextClassification | |||
Tensor h_flat = null; | |||
Tensor fc1_out = null; | |||
Tensor fc2_out = null; | |||
Tensor logits = null; | |||
Tensor predictions = null; | |||
// First Convolution Layer | |||
with(tf.variable_scope("conv-0"), delegate | |||
@@ -116,6 +117,13 @@ namespace TensorFlowNET.Examples.TextClassification | |||
with(tf.name_scope("loss"), delegate | |||
{ | |||
var y_one_hot = tf.one_hot(y, num_class); | |||
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot)); | |||
var update_ops = tf.get_collection(ops.GraphKeys.UPDATE_OPS) as List<Operation>; | |||
with(tf.control_dependencies(update_ops.ToArray()), delegate | |||
{ | |||
}); | |||
}); | |||
} | |||