diff --git a/README.md b/README.md
index cc5c717f..f2f72d75 100644
--- a/README.md
+++ b/README.md
@@ -140,6 +140,7 @@ Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflow
* [CNN Text Classification](test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs)
* [Named Entity Recognition](test/TensorFlowNET.Examples/TextProcess/NER)
+* [Transfer Learning for Image Classification in InceptionV3](test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs)
### Contribute:
diff --git a/src/KerasNET.Core/Layers/Dense.cs b/src/KerasNET.Core/Layers/Dense.cs
index 66569882..c6c086a4 100644
--- a/src/KerasNET.Core/Layers/Dense.cs
+++ b/src/KerasNET.Core/Layers/Dense.cs
@@ -40,7 +40,7 @@ namespace Keras.Layers
var dot = tf.matmul(x, W);
if (this.activation != null)
dot = activation.Activate(dot);
- Console.WriteLine("Calling Layer \"" + name + "(" + np.array(dot.getShape().Dimensions).ToString() + ")\" ...");
+ Console.WriteLine("Calling Layer \"" + name + "(" + np.array(dot.GetShape().Dimensions).ToString() + ")\" ...");
return dot;
}
public TensorShape __shape__()
diff --git a/src/KerasNET.Core/Model.cs b/src/KerasNET.Core/Model.cs
index d1d889fc..cb960169 100644
--- a/src/KerasNET.Core/Model.cs
+++ b/src/KerasNET.Core/Model.cs
@@ -65,7 +65,7 @@ namespace Keras
#endregion
#region Model Graph Form Layer Stack
- var flow_shape = features.getShape();
+ var flow_shape = features.GetShape();
Flow = features;
for (int i = 0; i < layer_stack.Count; i++)
{
diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index 899aead5..a97a8dda 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -49,6 +49,20 @@ namespace Tensorflow
Tensor off_value = null,
TF_DataType dtype = TF_DataType.DtInvalid,
int axis = -1,
- string name = null) => array_ops.one_hot(indices, depth, dtype: dtype, axis: axis, name: name);
+ string name = null) => array_ops.one_hot(indices, depth, dtype: dtype, axis: axis, name: name);
+
+ /// <summary>
+ /// A placeholder op that passes through `input` when its output is not fed.
+ /// </summary>
+ /// <typeparam name="T">Type of the default value handed to gen_array_ops.placeholder_with_default.</typeparam>
+ /// <param name="input">A `Tensor`. The default value to produce when output is not fed.</param>
+ /// <param name="shape">
+ /// A `tf.TensorShape` or list of `int`s. The (possibly partial) shape of
+ /// the tensor.
+ /// </param>
+ /// <param name="name">A name for the operation (optional).</param>
+ /// <returns>A `Tensor`. Has the same type as `input`.</returns>
+ public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = null)
+ => gen_array_ops.placeholder_with_default(input, shape, name: name);
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index 53a3b097..bad41103 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -277,9 +277,21 @@ namespace Tensorflow
}
+ /// <summary>
+ /// Forwards to math_ops.reduce_sum for `input` along `axis`.
+ /// </summary>
+ /// <param name="input">Tensor to reduce.</param>
+ /// <param name="axis">Axis argument handed to math_ops.reduce_sum.</param>
+ /// <param name="reduction_indices">Accepted for signature compatibility; not forwarded to math_ops.reduce_sum.</param>
public static Tensor reduce_sum(Tensor input, int axis, int? reduction_indices = null)
- {
- return math_ops.reduce_sum(input, axis);
- }
+ => math_ops.reduce_sum(input, axis);
+
+ /// <summary>
+ /// Computes the maximum of elements across dimensions of a tensor.
+ /// </summary>
+ /// <param name="input_tensor">Tensor to reduce.</param>
+ /// <param name="axis">Dimensions to reduce; forwarded unchanged (null by default).</param>
+ /// <param name="keepdims">Forwarded unchanged to math_ops.reduce_max.</param>
+ /// <param name="name">Optional name for the operation.</param>
+ /// <returns>The tensor returned by math_ops.reduce_max.</returns>
+ public static Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ => math_ops.reduce_max(input_tensor, axis, keepdims, name);
+
+ /// <summary>
+ /// Counterpart of reduce_max: forwards all arguments to math_ops.reduce_min.
+ /// </summary>
+ /// <param name="input_tensor">Tensor to reduce.</param>
+ /// <param name="axis">Dimensions to reduce; forwarded unchanged (null by default).</param>
+ /// <param name="keepdims">Forwarded unchanged to math_ops.reduce_min.</param>
+ /// <param name="name">Optional name for the operation.</param>
+ /// <returns>The tensor returned by math_ops.reduce_min.</returns>
+ public static Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ {
+ return math_ops.reduce_min(input_tensor, axis, keepdims, name);
+ }
public static Tensor sigmoid(T x, string name = null)
=> math_ops.sigmoid(x, name: name);
diff --git a/src/TensorFlowNET.Core/Framework/common_shapes.py.cs b/src/TensorFlowNET.Core/Framework/common_shapes.py.cs
index 70bea7b0..e0a08184 100644
--- a/src/TensorFlowNET.Core/Framework/common_shapes.py.cs
+++ b/src/TensorFlowNET.Core/Framework/common_shapes.py.cs
@@ -37,7 +37,7 @@ namespace Tensorflow.Framework
+ /// <summary>
+ /// Reports whether the shape returned by tensor.GetShape() is fully defined.
+ /// </summary>
+ /// <param name="tensor">Tensor whose static shape is inspected.</param>
+ /// <returns>The result of is_fully_defined() on the tensor's shape.</returns>
public static bool has_fully_defined_shape(Tensor tensor)
{
- return tensor.getShape().is_fully_defined();
+ return tensor.GetShape().is_fully_defined();
}
}
}
diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs
index 799af2fa..64a1f5d9 100644
--- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs
+++ b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs
@@ -19,7 +19,7 @@ namespace Tensorflow
return meta_graph_def;
}
- public static (Dictionary, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
+ public static (Dictionary, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
bool clear_devices = false,
string import_scope = "",
Dictionary input_map = null,
@@ -61,7 +61,7 @@ namespace Tensorflow
return_elements: return_elements);
// Restores all the other collections.
- var variable_objects = new Dictionary();
+ var variable_objects = new Dictionary();
foreach(var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
{
// Don't add unbound_inputs to the new graph.
@@ -83,11 +83,14 @@ namespace Tensorflow
{
foreach (var value in col.Value.BytesList.Value)
{
- RefVariable variable = null;
+ VariableV1 variable = null;
if (!variable_objects.ContainsKey(value))
{
var proto = VariableDef.Parser.ParseFrom(value);
- variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
+ if (proto.IsResource)
+ variable = new ResourceVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
+ else
+ variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
variable_objects[value] = variable;
}
variable = variable_objects[value];
@@ -126,9 +129,9 @@ namespace Tensorflow
}
}
- var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
+ var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
scope: scope_to_prepend_to_names);
- var var_list = new Dictionary();
+ var var_list = new Dictionary();
variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v);
return (var_list, imported_return_elements);
diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs
index cac1c85e..5b5f6d4c 100644
--- a/src/TensorFlowNET.Core/Gradients/math_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs
@@ -32,6 +32,27 @@ namespace Tensorflow.Gradients
return new Tensor[] { r1, r2 };
}
+ /// <summary>
+ /// Gradient function registered for the "DivNoNan" op.
+ /// </summary>
+ /// <param name="op">The forward operation; inputs[0] is x and inputs[1] is y.</param>
+ /// <param name="grads">Incoming gradients; only grads[0] is used.</param>
+ /// <returns>
+ /// Two tensors: the gradient for x reshaped to the shape of x, followed by
+ /// the gradient for y reshaped to the shape of y.
+ /// </returns>
+ public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
+ {
+ var grad = grads[0];
+ var x = op.inputs[0];
+ var y = op.inputs[1];
+ var sx = array_ops.shape(x);
+ var sy = array_ops.shape(y);
+ // rx / ry are the axes each partial gradient is summed over below.
+ var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+ x = math_ops.conj(x);
+ y = math_ops.conj(y);
+
+ // Gradient w.r.t. x: grad / y, summed over rx.
+ var reduce_sum1 = math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx);
+ // Gradient w.r.t. y: grad * ((-x / y) / y), summed over ry.
+ var reduce_sum2 = math_ops.reduce_sum(grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), ry);
+
+ return new Tensor[]
+ {
+ array_ops.reshape(reduce_sum1, sx),
+ array_ops.reshape(reduce_sum2, sy)
+ };
+ }
+
///
/// Returns grad * exp(x).
///
diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
index 9a70b90d..a28d1bc5 100644
--- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
@@ -74,6 +74,23 @@ namespace Tensorflow.Gradients
};
}
+ /// <summary>
+ /// Gradient function registered for the "SparseSoftmaxCrossEntropyWithLogits" op.
+ /// </summary>
+ /// <param name="op">The forward operation; its second output (outputs[1]) is reused as the gradient factor.</param>
+ /// <param name="grads">Incoming gradients; only grads[0] is used.</param>
+ /// <returns>
+ /// A two-element array: grads[0] combined with outputs[1] through _BroadcastMul,
+ /// and null for the second input.
+ /// </returns>
+ public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
+ {
+ // Wrap outputs[1] in a PreventGradient op so a second derivative fails with the message below.
+ var sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
+ op.outputs[1],
+ message: "Currently there is no way to take the second " +
+ "derivative of sparse_softmax_cross_entropy_with_logits due to the fused " +
+ "implementation's interaction with tf.gradients()");
+
+ var grad_0 = grads[0];
+
+ return new Tensor[]
+ {
+ _BroadcastMul(grad_0, sparse_softmax_grad_without_gradient),
+ null
+ };
+ }
+
private static bool IsZero(Tensor g)
{
if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type))
diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
index eab9948b..98574339 100644
--- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
+++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
@@ -24,6 +24,8 @@ namespace Tensorflow
return nn_grad._BiasAddGrad(oper, out_grads);
case "ConcatV2":
return array_grad._ConcatGradV2(oper, out_grads);
+ case "DivNoNan":
+ return math_grad._DivNoNanGrad(oper, out_grads);
case "Exp":
return math_grad._ExpGrad(oper, out_grads);
case "Identity":
@@ -62,6 +64,8 @@ namespace Tensorflow
return nn_grad._SoftmaxGrad(oper, out_grads);
case "SoftmaxCrossEntropyWithLogits":
return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
+ case "SparseSoftmaxCrossEntropyWithLogits":
+ return nn_grad._SparseSoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
case "Transpose":
return array_grad._TransposeGrad(oper, out_grads);
case "TopK":
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index c84684e6..ce4b2fc6 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -70,6 +70,8 @@ namespace Tensorflow
public string _name_stack = "";
public string _graph_key;
+ public string _last_loss_reduction;
+
public Status Status { get; }
///
@@ -443,7 +445,7 @@ namespace Tensorflow
public void Dispose()
{
- c_api.TF_DeleteGraph(_handle);
+ // NOTE(review): the native TF_DeleteGraph call is disabled here, so Dispose() is a no-op
+ // and the handle is presumably never released - confirm this is intentional and track re-enabling it.
+ // c_api.TF_DeleteGraph(_handle);
}
///
diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
index c93c07c0..d733127f 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
@@ -106,17 +106,17 @@ namespace Tensorflow.Keras.Layers
param_shape,
dtype: param_dtype,
initializer: moving_mean_initializer,
- synchronization: VariableSynchronization.ON_READ,
+ synchronization: VariableSynchronization.OnRead,
trainable: false,
- aggregation: VariableAggregation.MEAN);
+ aggregation: VariableAggregation.Mean);
moving_variance = add_weight("moving_variance",
shape: param_shape,
dtype: param_dtype,
initializer: moving_variance_initializer,
- synchronization: VariableSynchronization.ON_READ,
+ synchronization: VariableSynchronization.OnRead,
trainable: false,
- aggregation: VariableAggregation.MEAN);
+ aggregation: VariableAggregation.Mean);
if (renorm)
throw new NotImplementedException("build when renorm is true");
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
index 4b8eebba..fe3dd36f 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
@@ -136,7 +136,7 @@ namespace Tensorflow.Keras.Layers
+ /// <summary>
+ /// Calls build() with the shape of the first input tensor.
+ /// </summary>
+ /// <param name="inputs">Layer inputs; only inputs[0] is inspected.</param>
protected void _maybe_build(Tensor[] inputs)
{
var input_list = inputs;
- build(input_list[0].getShape());
+ build(input_list[0].GetShape());
}
protected virtual void build(TensorShape input_shape)
diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs
index 4b2d9cf3..eea80e04 100644
--- a/src/TensorFlowNET.Core/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Layers/Layer.cs
@@ -77,12 +77,12 @@ namespace Tensorflow.Layers
TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null,
bool? trainable = null,
- VariableSynchronization synchronization = VariableSynchronization.AUTO,
- VariableAggregation aggregation = VariableAggregation.NONE)
+ VariableSynchronization synchronization = VariableSynchronization.Auto,
+ VariableAggregation aggregation = VariableAggregation.None)
{
var default_graph = ops.get_default_graph();
Graph init_graph = null;
- RefVariable[] existing_variables = null;
+ VariableV1[] existing_variables = null;
if (default_graph.building_function)
{
diff --git a/src/TensorFlowNET.Core/Operations/Losses/Reduction.cs b/src/TensorFlowNET.Core/Operations/Losses/Reduction.cs
new file mode 100644
index 00000000..f2421b32
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/Losses/Reduction.cs
@@ -0,0 +1,16 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ /// <summary>
+ /// String constants naming the reduction modes accepted by the loss helpers
+ /// (compared against in LossesImpl.compute_weighted_loss).
+ /// </summary>
+ public class Reduction
+ {
+ public const string NONE = "none";
+ public const string SUM = "weighted_sum";
+ public const string SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size";
+ public const string MEAN = "weighted_mean";
+ public const string SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights";
+ // Alias: carries the same string value as SUM_BY_NONZERO_WEIGHTS.
+ public const string SUM_OVER_NONZERO_WEIGHTS = SUM_BY_NONZERO_WEIGHTS;
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/Losses/Util.cs b/src/TensorFlowNET.Core/Operations/Losses/Util.cs
new file mode 100644
index 00000000..0bd390bd
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/Losses/Util.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations.Losses
+{
+ /// <summary>
+ /// Helper functions shared by the loss implementations.
+ /// </summary>
+ public class Util
+ {
+ /// <summary>
+ /// Adds `loss` to the named collection through ops.add_to_collection.
+ /// </summary>
+ /// <param name="loss">Loss tensor to register.</param>
+ /// <param name="loss_collection">Target collection name; when null or empty nothing is added.</param>
+ public static void add_loss(Tensor loss, string loss_collection = ops.GraphKeys.LOSSES)
+ {
+ if (!string.IsNullOrEmpty(loss_collection))
+ ops.add_to_collection(loss_collection, loss);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
index ced44b78..c07712a1 100644
--- a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
@@ -7,33 +7,134 @@ namespace Tensorflow
{
public class LossesImpl
{
+ /// <summary>
+ /// Multiplies `losses` by `weights`, applies the requested reduction and
+ /// registers the result through Operations.Losses.Util.add_loss.
+ /// </summary>
+ /// <param name="losses">Loss values; cast to float32 for the computation.</param>
+ /// <param name="weights">Weights multiplied into `losses`; when null a constant weight of 1.0 is used.</param>
+ /// <param name="scope">Name scope for the created ops (default name "weighted_loss").</param>
+ /// <param name="loss_collection">Collection the final loss is added to.</param>
+ /// <param name="reduction">One of the Reduction constants.</param>
+ /// <returns>The weighted loss, cast back to the dtype of `losses`.</returns>
+ public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null,
+ string loss_collection = ops.GraphKeys.LOSSES, string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS)
+ {
+ return with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate
+ {
+ // Save the `reduction` argument for loss normalization when distributing
+ // to multiple replicas. Used only for estimator + v1 optimizer flow.
+ ops.get_default_graph()._last_loss_reduction = reduction;
+
+ /*var dp = weights_broadcast_ops.assert_broadcastable(weights, losses);
+ with(ops.control_dependencies(dp), delegate
+ {
+
+ });*/
+
+ losses = ops.convert_to_tensor(losses);
+ var input_dtype = losses.dtype;
+ losses = math_ops.cast(losses, dtype: dtypes.float32);
+ // The parameter defaults to null; fall back to a unit weight instead of casting null.
+ if (weights is null)
+ weights = ops.convert_to_tensor(1.0f);
+ weights = math_ops.cast(weights, dtype: dtypes.float32);
+ var weighted_losses = math_ops.multiply(losses, weights);
+ Tensor loss = null;
+ if (reduction == Reduction.NONE)
+ loss = weighted_losses;
+ else
+ {
+ loss = math_ops.reduce_sum(weighted_losses);
+ if (reduction == Reduction.MEAN)
+ loss = _safe_mean(
+ loss, math_ops.reduce_sum(array_ops.ones_like(losses) * weights));
+ // Both constants hold the same string, so the second comparison is an alias check.
+ else if (reduction == Reduction.SUM_BY_NONZERO_WEIGHTS ||
+ reduction == Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss = _safe_mean(loss, _num_present(losses, weights));
+ else if (reduction == Reduction.SUM_OVER_BATCH_SIZE)
+ loss = _safe_mean(loss, _num_elements(losses));
+ }
+
+ // Convert the result back to the input type.
+ loss = math_ops.cast(loss, input_dtype);
+ Operations.Losses.Util.add_loss(loss, loss_collection);
+ return loss;
+ });
+ }
+
+ /// <summary>
+ /// Sums `losses` and divides the total by `num_present` with div_no_nan.
+ /// </summary>
+ /// <param name="losses">Tensor whose elements are summed.</param>
+ /// <param name="num_present">Divisor passed to math_ops.div_no_nan.</param>
+ /// <returns>The result of the div_no_nan op, which is named "value".</returns>
+ public Tensor _safe_mean(Tensor losses, Tensor num_present)
+ => math_ops.div_no_nan(math_ops.reduce_sum(losses), num_present, name: "value");
+
+ /// <summary>
+ /// Not implemented yet: always throws. Called by compute_weighted_loss for
+ /// Reduction.SUM_OVER_BATCH_SIZE, so that reduction mode is currently unusable.
+ /// </summary>
+ /// <param name="losses">Unused.</param>
+ /// <exception cref="NotImplementedException">Always thrown.</exception>
+ public Tensor _num_elements(Tensor losses)
+ {
+ throw new NotImplementedException("LossesImpl._num_elements");
+ }
+
+ /// <summary>
+ /// Builds a 0/1 mask from `weights` (0 where the weight equals 0, 1 elsewhere),
+ /// broadcasts it against `losses` and sums it.
+ /// </summary>
+ /// <param name="losses">Tensor the mask is broadcast against.</param>
+ /// <param name="weights">Weights; cast to float32 before the comparison with 0.</param>
+ /// <param name="per_batch">
+ /// When true the sum runs over axes 1..rank-1 with keepdims; otherwise over all elements.
+ /// </param>
+ /// <returns>The summed mask.</returns>
+ public Tensor _num_present(Tensor losses, Tensor weights, bool per_batch = false)
+ {
+ return with(ops.name_scope(null, default_name: "num_present", (losses, weights)), name_scope =>
+ {
+ string scope = name_scope;
+ weights = math_ops.cast(weights, dtype: dtypes.float32);
+ var present = array_ops.where(
+ math_ops.equal(weights, 0.0),
+ array_ops.zeros_like(weights),
+ array_ops.ones_like(weights));
+ present = weights_broadcast_ops.broadcast_weights(present, losses);
+
+ if (per_batch)
+ return math_ops.reduce_sum(
+ present,
+ axis: math_ops.range(1, array_ops.rank(present)),
+ keepdims: true,
+ name: scope);
+ return math_ops.reduce_sum(present, name: scope);
+ });
+ }
+
+ /// <summary>
+ /// Computes nn_ops.sparse_softmax_cross_entropy_with_logits for `labels` / `logits`
+ /// and reduces it with compute_weighted_loss.
+ /// </summary>
+ /// <param name="labels">Label tensor.</param>
+ /// <param name="logits">Logits tensor.</param>
+ /// <param name="weights">Scalar weight; converted to a tensor by _remove_squeezable_dimensions.</param>
+ /// <param name="scope">Name scope for the created ops (default name "sparse_softmax_cross_entropy_loss").</param>
+ /// <param name="loss_collection">Collection the loss is added to.</param>
+ /// <param name="reduction">One of the Reduction constants.</param>
+ /// <returns>The tensor returned by compute_weighted_loss.</returns>
public Tensor sparse_softmax_cross_entropy(Tensor labels,
Tensor logits,
float weights = 1.0f,
- string scope = "",
- string loss_collection= "losses")
+ string scope = null,
+ string loss_collection= ops.GraphKeys.LOSSES,
+ string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS)
{
- with(ops.name_scope(scope,
+ return with(ops.name_scope(scope,
"sparse_softmax_cross_entropy_loss",
(logits, labels, weights)),
- namescope =>
+ name_scope =>
{
- (labels, logits, weights) = _remove_squeezable_dimensions(
- labels, logits, weights, expected_rank_diff: 1);
+ scope = name_scope;
+ Tensor weights_tensor = null;
+ (labels, logits, weights_tensor) = _remove_squeezable_dimensions(
+ labels, logits, weights, expected_rank_diff: 1);
+ var losses = nn_ops.sparse_softmax_cross_entropy_with_logits(labels: labels,
+ logits: logits,
+ name: "xentropy");
+ return compute_weighted_loss(losses, weights_tensor, scope, loss_collection, reduction: reduction);
});
-
- throw new NotImplementedException("sparse_softmax_cross_entropy");
}
+ /// <summary>
+ /// Squeezes `labels` / `predictions` through confusion_matrix.remove_squeezable_dimensions and
+ /// converts `weights` to a tensor, squeezing its last dimension when its rank is
+ /// exactly one higher than the rank of `labels`.
+ /// </summary>
+ /// <param name="labels">Label tensor.</param>
+ /// <param name="predictions">Prediction tensor.</param>
+ /// <param name="weights">Scalar weight; only values greater than 0 are supported.</param>
+ /// <param name="expected_rank_diff">Expected result of rank(predictions) - rank(labels).</param>
+ /// <returns>The (labels, predictions, weights tensor) triple.</returns>
+ /// <exception cref="NotImplementedException">
+ /// Thrown when `weights` is not greater than 0 or when a rank is reported as negative
+ /// (the dynamic-rank path is not implemented).
+ /// </exception>
- public (Tensor, Tensor, float) _remove_squeezable_dimensions(Tensor labels,
+ public (Tensor, Tensor, Tensor) _remove_squeezable_dimensions(Tensor labels,
Tensor predictions,
float weights = 0,
int expected_rank_diff = 0)
{
- (labels, predictions, weights) = confusion_matrix.remove_squeezable_dimensions(
+ (labels, predictions) = confusion_matrix.remove_squeezable_dimensions(
labels, predictions, expected_rank_diff: expected_rank_diff);
+ if(weights > 0)
+ {
+ var weights_tensor = ops.convert_to_tensor(weights);
+ var labels_rank = labels.GetShape().NDim;
+ var weights_shape = weights_tensor.GetShape();
+ var weights_rank = weights_shape.NDim;
+
+ if (labels_rank > -1 && weights_rank > -1)
+ {
+ // Use static rank.
+ var rank_diff = weights_rank - labels_rank;
+ // Store the squeezed result in the tensor that is returned, not in the float parameter.
+ if (rank_diff == 1)
+ weights_tensor = array_ops.squeeze(weights_tensor, new int[] { -1 });
+ return (labels, predictions, weights_tensor);
+ }
+
+ // Use dynamic rank.
+ throw new NotImplementedException("_remove_squeezable_dimensions dynamic rank");
+ }
+
throw new NotImplementedException("_remove_squeezable_dimensions");
}
}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs b/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs
index d86b5cb6..8f77e0ea 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs
@@ -18,7 +18,7 @@ namespace Tensorflow.Operations
string data_format = null)
{
var dilation_rate_tensor = ops.convert_to_tensor(dilation_rate, TF_DataType.TF_INT32, name: "dilation_rate");
- var rate_shape = dilation_rate_tensor.getShape();
+ var rate_shape = dilation_rate_tensor.GetShape();
var num_spatial_dims = rate_shape.Dimensions[0];
int starting_spatial_dim = -1;
if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC"))
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index 9d0d49b1..9dfc882e 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -176,6 +176,42 @@ namespace Tensorflow.Operations
return (_op.outputs[0], _op.outputs[1]);
}
+ /// <summary>
+ /// Computes softmax cross entropy cost and gradients to backpropagate.
+ /// </summary>
+ /// <param name="features">
+ /// batch_size x num_classes matrix
+ /// </param>
+ /// <param name="labels">
+ /// batch_size vector with values in [0, num_classes).
+ /// This is the label for the given minibatch entry.
+ /// </param>
+ /// <param name="name">
+ /// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSoftmaxCrossEntropyWithLogits'.
+ /// </param>
+ /// <returns>
+ /// Returns a tuple with multiple values, as follows:
+ /// loss : Per example loss (batch_size vector).
+ /// backprop : backpropagated gradients (batch_size x num_classes matrix).
+ /// The Operation can be fetched from any of the Tensors returned in the tuple values, by fetching the Operation property.
+ /// </returns>
+ /// <remarks>
+ /// Unlike SoftmaxCrossEntropyWithLogits, this operation does not accept
+ /// a matrix of label probabilities, but rather a single label per row
+ /// of features. This label is considered to have probability 1.0 for the
+ /// given row.
+ ///
+ /// Inputs are the logits, not probabilities.
+ /// </remarks>
+ public static (Tensor loss, Tensor backprop) sparse_softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = "SparseSoftmaxCrossEntropyWithLogits")
+ {
+ var created = _op_def_lib._apply_op_helper("SparseSoftmaxCrossEntropyWithLogits", name: name, args: new { features, labels });
+ // Output 0 is the loss, output 1 the backprop tensor.
+ return (created.outputs[0], created.outputs[1]);
+ }
+
///
/// Computes rectified linear: `max(features, 0)`.
///
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index 8cde60cc..34adba6d 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -11,6 +11,9 @@ namespace Tensorflow
public static Tensor placeholder_with_default(T input, int[] shape, string name = null)
=> gen_array_ops.placeholder_with_default(input, shape, name);
+ /// <summary>
+ /// Thin wrapper that forwards to gen_array_ops.prevent_gradient.
+ /// </summary>
+ /// <param name="input">Tensor to pass through.</param>
+ /// <param name="message">Message forwarded to the generated op wrapper.</param>
+ /// <param name="name">Optional name for the operation.</param>
+ /// <returns>The tensor returned by gen_array_ops.prevent_gradient.</returns>
+ public static Tensor prevent_gradient(Tensor input, string message = "", string name = null)
+ {
+ return gen_array_ops.prevent_gradient(input, message: message, name: name);
+ }
+
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
dtype = dtype.as_base_dtype();
diff --git a/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs b/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs
index fa2163e4..0cb54647 100644
--- a/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs
+++ b/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs
@@ -1,17 +1,47 @@
using System;
using System.Collections.Generic;
using System.Text;
+using static Tensorflow.Python;
namespace Tensorflow
{
public class confusion_matrix
{
- public static (Tensor, Tensor, float) remove_squeezable_dimensions(Tensor labels,
+ /// <summary>
+ /// Squeeze last dim if ranks differ from expected by exactly 1.
+ /// </summary>
+ /// <param name="labels">Label values; converted to a tensor.</param>
+ /// <param name="predictions">Predicted values; converted to a tensor.</param>
+ /// <param name="expected_rank_diff">Expected result of rank(predictions) - rank(labels).</param>
+ /// <param name="name">Name scope for the created ops (default name "remove_squeezable_dimensions").</param>
+ /// <returns>The (labels, predictions) pair, each possibly squeezed on its last dimension.</returns>
+ public static (Tensor, Tensor) remove_squeezable_dimensions(Tensor labels,
Tensor predictions,
int expected_rank_diff = 0,
string name = null)
{
- throw new NotImplementedException("remove_squeezable_dimensions");
+ return with(ops.name_scope(name, default_name: "remove_squeezable_dimensions", (labels, predictions)), delegate
+ {
+ predictions = ops.convert_to_tensor(predictions);
+ labels = ops.convert_to_tensor(labels);
+ var predictions_shape = predictions.GetShape();
+ var predictions_rank = predictions_shape.NDim;
+ var labels_shape = labels.GetShape();
+ var labels_rank = labels_shape.NDim;
+ // NOTE(review): a negative NDim presumably marks an unknown static rank - confirm against TensorShape.
+ if(labels_rank > -1 && predictions_rank > -1)
+ {
+ // Use static rank.
+ var rank_diff = predictions_rank - labels_rank;
+ if (rank_diff == expected_rank_diff + 1)
+ predictions = array_ops.squeeze(predictions, new int[] { -1 });
+ else if (rank_diff == expected_rank_diff - 1)
+ labels = array_ops.squeeze(labels, new int[] { -1 });
+ return (labels, predictions);
+ }
+
+ // Use dynamic rank.
+ throw new NotImplementedException("remove_squeezable_dimensions dynamic rank");
+ });
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index a5cd13b7..a83c636e 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -100,6 +100,38 @@ namespace Tensorflow
return new Tensor(_op, 0, dtype);
}
+ /// <summary>
+ /// An identity op that triggers an error if a gradient is requested.
+ /// </summary>
+ /// <param name="input">
+ /// any tensor.
+ /// </param>
+ /// <param name="name">
+ /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PreventGradient'.
+ /// </param>
+ /// <param name="message">
+ /// Will be printed in the error when anyone tries to differentiate
+ /// this operation.
+ /// </param>
+ /// <returns>
+ /// the same input tensor.
+ /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result.
+ /// </returns>
+ /// <remarks>
+ /// When executed in a graph, this op outputs its input tensor as-is.
+ ///
+ /// When building ops to compute gradients, the TensorFlow gradient system
+ /// will return an error when trying to lookup the gradient of this op,
+ /// because no gradient must ever be registered for this function. This
+ /// op exists to prevent subtle bugs from silently returning unimplemented
+ /// gradients in some corner cases.
+ /// </remarks>
+ public static Tensor prevent_gradient(Tensor input, string message = "", string name = null)
+ {
+ return _op_def_lib._apply_op_helper("PreventGradient", name: name, args: new { input, message }).output;
+ }
+
///
/// Return a tensor with the same shape and contents as the input tensor or value.
///
diff --git a/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs b/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs
index 0af59d63..8dfbf8f0 100644
--- a/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs
@@ -18,6 +18,13 @@ namespace Tensorflow
return _op;
}
+ /// <summary>
+ /// Creates a "HistogramSummary" op for `values` under `tag`.
+ /// </summary>
+ /// <param name="tag">Tag passed to the op.</param>
+ /// <param name="values">Tensor passed to the op as `values`.</param>
+ /// <param name="name">Optional name for the operation.</param>
+ /// <returns>The output tensor of the created operation.</returns>
+ public static Tensor histogram_summary(string tag, Tensor values, string name = null)
+ {
+ var op = _op_def_lib._apply_op_helper("HistogramSummary", name: name, args: new { tag, values });
+ return op.output;
+ }
+
///
/// Outputs a Summary protocol buffer with scalar values.
///
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index 74056057..e5670dd0 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -37,7 +37,31 @@ namespace Tensorflow
///
public static Tensor arg_min(Tensor input, int dimension, TF_DataType output_type= TF_DataType.TF_INT64, string name= null)
=>_op_def_lib._apply_op_helper("ArgMin", name, args: new { input, dimension, output_type }).outputs[0];
-
+
+
+ /// <summary>
+ /// Returns 0 if the denominator is zero.
+ /// </summary>
+ /// <param name="x">
+ /// </param>
+ /// <param name="y">
+ /// </param>
+ /// <param name="name">
+ /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DivNoNan'.
+ /// </param>
+ /// <returns>
+ /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result.
+ /// </returns>
+ /// <remarks>
+ ///
+ /// *NOTE*: DivNoNan supports broadcasting. More about broadcasting
+ /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+ /// </remarks>
+ public static Tensor div_no_nan(Tensor x, Tensor y, string name = null)
+ {
+ return _op_def_lib._apply_op_helper("DivNoNan", name: name, args: new { x, y }).output;
+ }
///
/// Computes the mean of elements across dimensions of a tensor.
diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
new file mode 100644
index 00000000..b4173f28
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
@@ -0,0 +1,18 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ /// <summary>
+ /// Low-level wrappers for resource-variable ops.
+ /// </summary>
+ public static class gen_resource_variable_ops
+ {
+ public static OpDefLibrary _op_def_lib = new OpDefLibrary();
+
+ /// <summary>
+ /// Creates an "AssignVariableOp" operation for `resource` and `value`.
+ /// </summary>
+ /// <param name="resource">Tensor passed to the op as `resource`.</param>
+ /// <param name="value">Tensor passed to the op as `value`.</param>
+ /// <param name="name">Optional name for the operation.</param>
+ /// <returns>The created operation.</returns>
+ public static Operation assign_variable_op(Tensor resource, Tensor value, string name = null)
+ => _op_def_lib._apply_op_helper("AssignVariableOp", name, new { resource, value });
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index 8b94da2b..f37bd0dd 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -65,6 +65,39 @@ namespace Tensorflow
});
}
+ /// <summary>
+ /// Returns 0 if the denominator is zero.
+ /// </summary>
+ /// <param name="x">
+ /// </param>
+ /// <param name="y">
+ /// </param>
+ /// <param name="name">
+ /// If specified, the created operation in the graph will be this one, otherwise it will be named 'DivNoNan'.
+ /// </param>
+ /// <returns>
+ /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result.
+ /// </returns>
+ /// <remarks>
+ ///
+ /// *NOTE*: DivNoNan supports broadcasting. More about broadcasting
+ /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+ /// </remarks>
+ public static Tensor div_no_nan(Tensor x, Tensor y, string name = null)
+ {
+ return with(ops.name_scope(name, "div_no_nan", (x, y)), name_scope =>
+ {
+ name = name_scope;
+ x = ops.convert_to_tensor(x, name: "x");
+ // y is converted using the base dtype of x.
+ y = ops.convert_to_tensor(y, name: "y", dtype: x.dtype.as_base_dtype());
+ var x_dtype = x.dtype.as_base_dtype();
+ var y_dtype = y.dtype.as_base_dtype();
+ // Guard in case the conversion above left the two dtypes different.
+ if (x_dtype != y_dtype)
+ throw new TypeError($"x and y must have the same dtype, got {x_dtype} != {y_dtype}");
+ return gen_math_ops.div_no_nan(x, y, name: name);
+ });
+ }
+
public static Tensor equal(Tx x, Ty y, string name = null)
=> gen_math_ops.equal(x, y, name: name);
@@ -254,6 +287,13 @@ namespace Tensorflow
return _may_reduce_to_scalar(keepdims, axis, max);
}
+ /// <summary>
+ /// Reduces `input_tensor` with gen_math_ops._min over the dimensions resolved by _ReductionDims.
+ /// </summary>
+ /// <param name="input_tensor">Tensor to reduce.</param>
+ /// <param name="axis">Dimensions to reduce (null by default).</param>
+ /// <param name="keepdims">Forwarded to gen_math_ops._min and _may_reduce_to_scalar.</param>
+ /// <param name="name">Optional name for the operation.</param>
+ /// <returns>The result of _may_reduce_to_scalar applied to the reduced tensor.</returns>
+ public static Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ {
+ var reduction_dims = _ReductionDims(input_tensor, axis);
+ var reduced = gen_math_ops._min(input_tensor, reduction_dims, keepdims, name);
+ return _may_reduce_to_scalar(keepdims, axis, reduced);
+ }
+
///
/// Casts a tensor to type `int32`.
///
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs
index a5a2815c..bb6f643e 100644
--- a/src/TensorFlowNET.Core/Operations/nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;
+using static Tensorflow.Python;
namespace Tensorflow
{
@@ -59,6 +60,47 @@ namespace Tensorflow
throw new NotImplementedException("_softmax helper");
}
+ /// <summary>
+ /// Computes sparse softmax cross entropy between `logits` and `labels`.
+ /// </summary>
+ /// <param name="labels">Label tensor; converted with ops.convert_to_tensor.</param>
+ /// <param name="logits">Logits tensor; only the rank-2 case is implemented.</param>
+ /// <param name="name">Name scope / op name (default scope name "SparseSoftmaxCrossEntropyWithLogits").</param>
+ /// <returns>The loss output of the generated op, cast to float32 when `logits` is float16.</returns>
+ public static Tensor sparse_softmax_cross_entropy_with_logits(Tensor labels = null,
+ Tensor logits = null, string name = null)
+ {
+ // Reshape logits and labels to rank 2.
+ return with(ops.name_scope(name, default_name: "SparseSoftmaxCrossEntropyWithLogits", (labels, logits)), delegate
+ {
+ labels = ops.convert_to_tensor(labels);
+ logits = ops.convert_to_tensor(logits);
+ // Half-precision logits are cast to float32 before the op is created.
+ var precise_logits = logits.dtype == TF_DataType.TF_HALF ? math_ops.cast(logits, dtypes.float32) : logits;
+
+ // Store label shape for result later.
+ // NOTE(review): both label-shape locals are unused until the reshape path below is implemented.
+ var labels_static_shape = labels.GetShape();
+ var labels_shape = array_ops.shape(labels);
+ /*bool static_shapes_fully_defined = (
+ labels_static_shape.is_fully_defined() &&
+ logits.get_shape()[:-1].is_fully_defined());*/
+
+ // Check if no reshapes are required.
+ if(logits.GetShape().NDim == 2)
+ {
+ var (cost, _) = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
+ precise_logits, labels, name: name);
+ if (logits.dtype == dtypes.float16)
+ return math_ops.cast(cost, dtypes.float32);
+ else
+ return cost;
+ }
+
+ // Perform a check of the dynamic shapes if the static shapes are not fully
+ // defined.
+ throw new NotImplementedException("sparse_softmax_cross_entropy_with_logits");
+ });
+ }
+
public static Tensor softmax_cross_entropy_with_logits_v2_helper(Tensor labels,
Tensor logits,
int axis = -1,
@@ -68,7 +110,7 @@ namespace Tensorflow
{
var precise_logits = logits;
var input_rank = array_ops.rank(precise_logits);
- var shape = logits.getShape();
+ var shape = logits.GetShape();
if (axis != -1)
throw new NotImplementedException("softmax_cross_entropy_with_logits_v2_helper axis != -1");
diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
new file mode 100644
index 00000000..f8ecf7b9
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
@@ -0,0 +1,20 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ /// <summary>
+ /// tensorflow\python\ops\resource_variable_ops.py
+ /// </summary>
+ public class resource_variable_ops
+ {
+     /// <summary>
+     /// Converts <paramref name="value"/> to a tensor and passes it, with <paramref name="handle"/>, to
+     /// assign_variable_op. The <paramref name="shape"/> argument is not used by this implementation.
+     /// </summary>
+     public static ITensorOrOperation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null)
+         => gen_resource_variable_ops.assign_variable_op(
+             handle, ops.convert_to_tensor(value), name: name);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs b/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs
new file mode 100644
index 00000000..f0afa1fe
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs
@@ -0,0 +1,30 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Python;
+
+namespace Tensorflow
+{
+ public class weights_broadcast_ops
+ {
+     /// <summary>Returns `weights` unchanged when its static shape equals that of `values`; otherwise `weights * ones_like(values)`.</summary>
+     public static Tensor broadcast_weights(Tensor weights, Tensor values)
+     {
+         return with(ops.name_scope(null, "broadcast_weights", (weights, values)), scope =>
+         {
+             values = ops.convert_to_tensor(values, name: "values");
+             weights = ops.convert_to_tensor(weights, dtype: values.dtype.as_base_dtype(), name: "weights");
+
+             // Try static check for exact match: both shapes must be fully defined AND have identical
+             // dimensions. Fully defined but different shapes (e.g. scalar weights) still need broadcasting.
+             var weights_shape = weights.GetShape();
+             var values_shape = values.GetShape();
+             if (weights_shape.is_fully_defined() && values_shape.is_fully_defined() &&
+                 System.Linq.Enumerable.SequenceEqual(weights_shape.Dimensions, values_shape.Dimensions))
+                 return weights;
+
+             return math_ops.multiply(weights, array_ops.ones_like(values), name: scope);
+         });
+     }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Protobuf/Variable.cs b/src/TensorFlowNET.Core/Protobuf/Variable.cs
index 18cdd6d2..9f24c138 100644
--- a/src/TensorFlowNET.Core/Protobuf/Variable.cs
+++ b/src/TensorFlowNET.Core/Protobuf/Variable.cs
@@ -25,26 +25,92 @@ namespace Tensorflow {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"Cih0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3ZhcmlhYmxlLnByb3RvEgp0",
- "ZW5zb3JmbG93ItQBCgtWYXJpYWJsZURlZhIVCg12YXJpYWJsZV9uYW1lGAEg",
+ "ZW5zb3JmbG93IsgCCgtWYXJpYWJsZURlZhIVCg12YXJpYWJsZV9uYW1lGAEg",
"ASgJEhoKEmluaXRpYWxfdmFsdWVfbmFtZRgGIAEoCRIYChBpbml0aWFsaXpl",
"cl9uYW1lGAIgASgJEhUKDXNuYXBzaG90X25hbWUYAyABKAkSOQoTc2F2ZV9z",
"bGljZV9pbmZvX2RlZhgEIAEoCzIcLnRlbnNvcmZsb3cuU2F2ZVNsaWNlSW5m",
- "b0RlZhITCgtpc19yZXNvdXJjZRgFIAEoCBIRCgl0cmFpbmFibGUYByABKAgi",
- "YAoQU2F2ZVNsaWNlSW5mb0RlZhIRCglmdWxsX25hbWUYASABKAkSEgoKZnVs",
- "bF9zaGFwZRgCIAMoAxISCgp2YXJfb2Zmc2V0GAMgAygDEhEKCXZhcl9zaGFw",
- "ZRgEIAMoA0JuChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtCDlZhcmlhYmxl",
- "UHJvdG9zUAFaPWdpdGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3Rl",
- "bnNvcmZsb3cvZ28vY29yZS9mcmFtZXdvcmv4AQFiBnByb3RvMw=="));
+ "b0RlZhITCgtpc19yZXNvdXJjZRgFIAEoCBIRCgl0cmFpbmFibGUYByABKAgS",
+ "PAoPc3luY2hyb25pemF0aW9uGAggASgOMiMudGVuc29yZmxvdy5WYXJpYWJs",
+ "ZVN5bmNocm9uaXphdGlvbhI0CgthZ2dyZWdhdGlvbhgJIAEoDjIfLnRlbnNv",
+ "cmZsb3cuVmFyaWFibGVBZ2dyZWdhdGlvbiJgChBTYXZlU2xpY2VJbmZvRGVm",
+ "EhEKCWZ1bGxfbmFtZRgBIAEoCRISCgpmdWxsX3NoYXBlGAIgAygDEhIKCnZh",
+ "cl9vZmZzZXQYAyADKAMSEQoJdmFyX3NoYXBlGAQgAygDKqwBChdWYXJpYWJs",
+ "ZVN5bmNocm9uaXphdGlvbhIhCh1WQVJJQUJMRV9TWU5DSFJPTklaQVRJT05f",
+ "QVVUTxAAEiEKHVZBUklBQkxFX1NZTkNIUk9OSVpBVElPTl9OT05FEAESJQoh",
+ "VkFSSUFCTEVfU1lOQ0hST05JWkFUSU9OX09OX1dSSVRFEAISJAogVkFSSUFC",
+ "TEVfU1lOQ0hST05JWkFUSU9OX09OX1JFQUQQAyqeAQoTVmFyaWFibGVBZ2dy",
+ "ZWdhdGlvbhIdChlWQVJJQUJMRV9BR0dSRUdBVElPTl9OT05FEAASHAoYVkFS",
+ "SUFCTEVfQUdHUkVHQVRJT05fU1VNEAESHQoZVkFSSUFCTEVfQUdHUkVHQVRJ",
+ "T05fTUVBThACEisKJ1ZBUklBQkxFX0FHR1JFR0FUSU9OX09OTFlfRklSU1Rf",
+ "UkVQTElDQRADQi8KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IOVmFyaWFi",
+ "bGVQcm90b3NQAfgBAWIGcHJvdG8z"));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
- new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
- new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VariableDef), global::Tensorflow.VariableDef.Parser, new[]{ "VariableName", "InitialValueName", "InitializerName", "SnapshotName", "SaveSliceInfoDef", "IsResource", "Trainable" }, null, null, null),
+ new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.VariableSynchronization), typeof(global::Tensorflow.VariableAggregation), }, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VariableDef), global::Tensorflow.VariableDef.Parser, new[]{ "VariableName", "InitialValueName", "InitializerName", "SnapshotName", "SaveSliceInfoDef", "IsResource", "Trainable", "Synchronization", "Aggregation" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SaveSliceInfoDef), global::Tensorflow.SaveSliceInfoDef.Parser, new[]{ "FullName", "FullShape", "VarOffset", "VarShape" }, null, null, null)
}));
}
#endregion
}
+ #region Enums
+ /// <summary>
+ /// Indicates when a distributed variable will be synced.
+ /// </summary>
+ public enum VariableSynchronization {
+ /// <summary>
+ /// `AUTO`: Indicates that the synchronization will be determined by the
+ /// current `DistributionStrategy` (eg. With `MirroredStrategy` this would be
+ /// `ON_WRITE`).
+ /// </summary>
+ [pbr::OriginalName("VARIABLE_SYNCHRONIZATION_AUTO")] Auto = 0,
+ /// <summary>
+ /// `NONE`: Indicates that there will only be one copy of the variable, so
+ /// there is no need to sync.
+ /// </summary>
+ [pbr::OriginalName("VARIABLE_SYNCHRONIZATION_NONE")] None = 1,
+ /// <summary>
+ /// `ON_WRITE`: Indicates that the variable will be updated across devices
+ /// every time it is written.
+ /// </summary>
+ [pbr::OriginalName("VARIABLE_SYNCHRONIZATION_ON_WRITE")] OnWrite = 2,
+ /// <summary>
+ /// `ON_READ`: Indicates that the variable will be aggregated across devices
+ /// when it is read (eg. when checkpointing or when evaluating an op that uses
+ /// the variable).
+ /// </summary>
+ [pbr::OriginalName("VARIABLE_SYNCHRONIZATION_ON_READ")] OnRead = 3,
+ }
+
+ /// <summary>
+ /// Indicates how a distributed variable will be aggregated.
+ /// </summary>
+ public enum VariableAggregation {
+ /// <summary>
+ /// `NONE`: This is the default, giving an error if you use a
+ /// variable-update operation with multiple replicas.
+ /// </summary>
+ [pbr::OriginalName("VARIABLE_AGGREGATION_NONE")] None = 0,
+ /// <summary>
+ /// `SUM`: Add the updates across replicas.
+ /// </summary>
+ [pbr::OriginalName("VARIABLE_AGGREGATION_SUM")] Sum = 1,
+ /// <summary>
+ /// `MEAN`: Take the arithmetic mean ("average") of the updates across
+ /// replicas.
+ /// </summary>
+ [pbr::OriginalName("VARIABLE_AGGREGATION_MEAN")] Mean = 2,
+ /// <summary>
+ /// `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same
+ /// update, but we only want to perform the update once. Used, e.g., for the
+ /// global step counter.
+ /// </summary>
+ [pbr::OriginalName("VARIABLE_AGGREGATION_ONLY_FIRST_REPLICA")] OnlyFirstReplica = 3,
+ }
+
+ #endregion
+
#region Messages
///
/// Protocol buffer representing a Variable.
@@ -81,6 +147,8 @@ namespace Tensorflow {
saveSliceInfoDef_ = other.saveSliceInfoDef_ != null ? other.saveSliceInfoDef_.Clone() : null;
isResource_ = other.isResource_;
trainable_ = other.trainable_;
+ synchronization_ = other.synchronization_;
+ aggregation_ = other.aggregation_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}
@@ -187,6 +255,34 @@ namespace Tensorflow {
}
}
+ /// <summary>Field number for the "synchronization" field.</summary>
+ public const int SynchronizationFieldNumber = 8;
+ private global::Tensorflow.VariableSynchronization synchronization_ = 0;
+ /// <summary>
+ /// Indicates when a distributed variable will be synced.
+ /// </summary>
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::Tensorflow.VariableSynchronization Synchronization {
+ get { return synchronization_; }
+ set {
+ synchronization_ = value;
+ }
+ }
+
+ /// <summary>Field number for the "aggregation" field.</summary>
+ public const int AggregationFieldNumber = 9;
+ private global::Tensorflow.VariableAggregation aggregation_ = 0;
+ /// <summary>
+ /// Indicates how a distributed variable will be aggregated.
+ /// </summary>
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::Tensorflow.VariableAggregation Aggregation {
+ get { return aggregation_; }
+ set {
+ aggregation_ = value;
+ }
+ }
+
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as VariableDef);
@@ -207,6 +303,8 @@ namespace Tensorflow {
if (!object.Equals(SaveSliceInfoDef, other.SaveSliceInfoDef)) return false;
if (IsResource != other.IsResource) return false;
if (Trainable != other.Trainable) return false;
+ if (Synchronization != other.Synchronization) return false;
+ if (Aggregation != other.Aggregation) return false;
return Equals(_unknownFields, other._unknownFields);
}
@@ -220,6 +318,8 @@ namespace Tensorflow {
if (saveSliceInfoDef_ != null) hash ^= SaveSliceInfoDef.GetHashCode();
if (IsResource != false) hash ^= IsResource.GetHashCode();
if (Trainable != false) hash ^= Trainable.GetHashCode();
+ if (Synchronization != 0) hash ^= Synchronization.GetHashCode();
+ if (Aggregation != 0) hash ^= Aggregation.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
@@ -261,6 +361,14 @@ namespace Tensorflow {
output.WriteRawTag(56);
output.WriteBool(Trainable);
}
+ if (Synchronization != 0) {
+ output.WriteRawTag(64);
+ output.WriteEnum((int) Synchronization);
+ }
+ if (Aggregation != 0) {
+ output.WriteRawTag(72);
+ output.WriteEnum((int) Aggregation);
+ }
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
@@ -290,6 +398,12 @@ namespace Tensorflow {
if (Trainable != false) {
size += 1 + 1;
}
+ if (Synchronization != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Synchronization);
+ }
+ if (Aggregation != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Aggregation);
+ }
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
@@ -325,6 +439,12 @@ namespace Tensorflow {
if (other.Trainable != false) {
Trainable = other.Trainable;
}
+ if (other.Synchronization != 0) {
+ Synchronization = other.Synchronization;
+ }
+ if (other.Aggregation != 0) {
+ Aggregation = other.Aggregation;
+ }
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}
@@ -367,6 +487,14 @@ namespace Tensorflow {
Trainable = input.ReadBool();
break;
}
+ case 64: {
+ synchronization_ = (global::Tensorflow.VariableSynchronization) input.ReadEnum();
+ break;
+ }
+ case 72: {
+ aggregation_ = (global::Tensorflow.VariableAggregation) input.ReadEnum();
+ break;
+ }
}
}
}
diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs
index bd1b7d83..42d2063a 100644
--- a/src/TensorFlowNET.Core/Python.cs
+++ b/src/TensorFlowNET.Core/Python.cs
@@ -154,6 +154,18 @@ namespace Tensorflow
}
}
+ /// <summary>Lazily yields every entry of `values` as a (key, value) tuple, in the dictionary's own enumeration order.</summary>
+ public static IEnumerable<(TKey, TValue)> enumerate<TKey, TValue>(Dictionary<TKey, TValue> values)
+ {
+     foreach (var item in values) yield return (item.Key, item.Value);
+ }
+
+ /// <summary>Lazily yields every element of `values` as a (key, value) tuple, in array order.</summary>
+ public static IEnumerable<(TKey, TValue)> enumerate<TKey, TValue>(KeyValuePair<TKey, TValue>[] values)
+ {
+     foreach (var item in values) yield return (item.Key, item.Value);
+ }
+
public static IEnumerable<(int, T)> enumerate(IList values)
{
for (int i = 0; i < values.Count; i++)
diff --git a/src/TensorFlowNET.Core/Summaries/Summary.cs b/src/TensorFlowNET.Core/Summaries/Summary.cs
index 43828238..4fcc5666 100644
--- a/src/TensorFlowNET.Core/Summaries/Summary.cs
+++ b/src/TensorFlowNET.Core/Summaries/Summary.cs
@@ -15,6 +15,14 @@ namespace Tensorflow.Summaries
flush_secs: flush_secs, filename_suffix: filename_suffix,
session: session);
+ public Tensor histogram(string name, Tensor tensor, string[] collections = null, string family = null)
+ {
+     var (tag, scope) = summary_scope(name, family: family, values: new Tensor[] { tensor }, default_name: "HistogramSummary"); // tag and op name both come from summary_scope
+     var val = gen_logging_ops.histogram_summary(tag: tag, values: tensor, name: scope);
+     collect(val, collections?.ToList(), new List<string> { ops.GraphKeys.SUMMARIES }); // null `collections` is forwarded as null; presumably collect then uses the SUMMARIES list
+     return val; // the histogram summary op output
+ }
+
public Tensor merge_all(string key = ops.GraphKeys.SUMMARIES, string scope= null, string name= null)
{
var summary_ops = ops.get_collection(key, scope: scope);
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index e5b9e786..5e3d611b 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -84,12 +84,12 @@ namespace Tensorflow
return shape.Select(x => (int)x).ToArray();
}
- public TensorShape getShape()
+ public TensorShape GetShape()
{
return tensor_util.to_shape(shape);
}
- public void setShape(Shape shape)
+ public void SetShape(Shape shape)
{
this.shape = shape.Dimensions;
}
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 84db6717..9498cead 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -12,6 +12,13 @@ namespace Tensorflow
///
public class TensorShape : Shape
{
+ /// <summary>Builds a shape from a proto; when the proto has unknown rank, ReShape is never called and the default-constructed state is kept.</summary>
+ public TensorShape(TensorShapeProto proto)
+ {
+     if (!proto.UnknownRank)
+         ReShape(proto.Dim.Select(dim => (int)dim.Size).ToArray());
+ }
+
public TensorShape(params int[] dims) : base(dims)
{
@@ -25,5 +32,10 @@ namespace Tensorflow
{
return Dimensions != null && Dimensions.Count(x => x < 1) == 0;
}
+
+ /// <summary>
+ /// Placeholder: no compatibility logic exists yet, so every call throws NotImplementedException.
+ /// </summary>
+ public bool is_compatible_with(TensorShape shape2) => throw new NotImplementedException("TensorShape is_compatible_with");
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs
index cd5a46c7..444f384d 100644
--- a/src/TensorFlowNET.Core/Tensors/dtypes.cs
+++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs
@@ -8,6 +8,7 @@ namespace Tensorflow
{
public static TF_DataType int8 = TF_DataType.TF_INT8;
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
+ public static TF_DataType float16 = TF_DataType.TF_HALF;
public static Type as_numpy_datatype(this TF_DataType type)
{
diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs
index 541ba8e7..f068f910 100644
--- a/src/TensorFlowNET.Core/Train/Optimizer.cs
+++ b/src/TensorFlowNET.Core/Train/Optimizer.cs
@@ -246,6 +246,9 @@ namespace Tensorflow
case List values:
var_list = values;
break;
+ case List values:
+ var_list = values.Select(x => x as RefVariable).ToList();
+ break;
}
var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
index e366b796..ea52a6de 100644
--- a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
@@ -58,7 +58,7 @@ namespace Tensorflow
return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray());
}
- public virtual SaverDef _build_internal(RefVariable[] names_to_saveables,
+ public virtual SaverDef _build_internal(VariableV1[] names_to_saveables,
bool reshape = false,
bool sharded = false,
int max_to_keep = 5,
@@ -111,6 +111,12 @@ namespace Tensorflow
var cols = graph.get_collection(collection_type);
switch (cols)
{
+ case List values:
+ foreach (var element in values) ;
+ break;
+ case List values:
+ foreach (var element in values) ;
+ break;
case List values:
foreach (var element in values) ;
break;
@@ -166,10 +172,14 @@ namespace Tensorflow
string name = "restore_all")
{
var all_tensors = bulk_restore(filename_tensor, saveables, preferred_shard, restore_sequentially);
- var assign_ops = new List();
+ var assign_ops = new List();
int idx = 0;
- foreach(var saveable in saveables)
+ // Load and optionally reshape on the CPU, as string tensors are not
+ // available on the GPU.
+ // TODO(touts): Re-enable restore on GPU when we can support annotating
+ // string tensors as "HostMemory" inputs.
+ foreach (var saveable in saveables)
{
List shapes = null;
if (reshape)
@@ -179,7 +189,8 @@ namespace Tensorflow
var saveable_tensors = all_tensors.Skip(idx).Take(saveable.specs.Length);
idx += saveable.specs.Length;
- assign_ops.Add(saveable.restore(saveable_tensors.ToArray(), shapes == null ? null : shapes.ToArray()));
+ var restored = saveable.restore(saveable_tensors.ToArray(), shapes == null ? null : shapes.ToArray());
+ assign_ops.Add(restored as Operation);
}
return control_flow_ops.group(assign_ops.ToArray(), name: name);
diff --git a/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs
index ef25f81a..d16d2a21 100644
--- a/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/ISaverBuilder.cs
@@ -10,7 +10,7 @@ namespace Tensorflow
Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially);
- SaverDef _build_internal(RefVariable[] names_to_saveables,
+ SaverDef _build_internal(VariableV1[] names_to_saveables,
bool reshape = false,
bool sharded = false,
int max_to_keep = 5,
diff --git a/src/TensorFlowNET.Core/Train/Saving/ResourceVariableSaveable.cs b/src/TensorFlowNET.Core/Train/Saving/ResourceVariableSaveable.cs
new file mode 100644
index 00000000..a504acc6
--- /dev/null
+++ b/src/TensorFlowNET.Core/Train/Saving/ResourceVariableSaveable.cs
@@ -0,0 +1,34 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ /// <summary>
+ /// Saveable wrapper for a resource variable: the tensor passed in is what gets saved, and
+ /// restore() assigns through the handle taken from the first input of that tensor's op.
+ /// </summary>
+ public class ResourceVariableSaveable : SaveableObject
+ {
+     string _var_device;
+     int[] _var_shape;
+     Tensor handle_op;
+
+     public ResourceVariableSaveable(Tensor var, string slice_spec, string name)
+     {
+         _var_device = var.Device;
+         _var_shape = var.shape;
+         handle_op = var.op.inputs[0];
+         op = var;
+         specs = new SaveSpec[] { new SaveSpec(var, slice_spec, name, dtype: var.dtype) };
+         this.name = name;
+     }
+
+     // Passes the first restored tensor through identity, then assigns it via the handle; restored_shapes is not used.
+     public override ITensorOrOperation restore(Tensor[] restored_tensors, TensorShape[] restored_shapes = null)
+     {
+         var copy = array_ops.identity(restored_tensors[0]);
+         return resource_variable_ops.shape_safe_assign_variable_handle(handle_op, _var_shape, copy);
+     }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs
index e381cf14..90a4fff7 100644
--- a/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/SaveableObject.cs
@@ -28,7 +28,7 @@ namespace Tensorflow
this.name = name;
}
- public virtual Tensor restore(Tensor[] restored_tensors, TensorShape[] restored_shapes = null)
+ public virtual ITensorOrOperation restore(Tensor[] restored_tensors, TensorShape[] restored_shapes = null)
{
var restored_tensor = restored_tensors[0];
return gen_state_ops.assign(op,
diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs
index c223ca97..a6ced8fa 100644
--- a/src/TensorFlowNET.Core/Train/Saving/Saver.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs
@@ -11,7 +11,7 @@ namespace Tensorflow
///
public class Saver
{
- private RefVariable[] _var_list;
+ private VariableV1[] _var_list;
private bool _reshape;
private bool _sharded;
private int _max_to_keep;
@@ -32,7 +32,7 @@ namespace Tensorflow
private Dictionary _last_checkpoints;
private Dictionary _checkpoints_to_be_deleted;
- public Saver(RefVariable[] var_list = null,
+ public Saver(VariableV1[] var_list = null,
bool reshape = false,
bool sharded = false,
int max_to_keep = 5,
diff --git a/src/TensorFlowNET.Core/Train/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Train/Saving/saveable_object_util.py.cs
index 81d5e590..83031922 100644
--- a/src/TensorFlowNET.Core/Train/Saving/saveable_object_util.py.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/saveable_object_util.py.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
+using static Tensorflow.Python;
namespace Tensorflow
{
@@ -12,15 +13,15 @@ namespace Tensorflow
///
///
///
- public static SaveableObject[] validate_and_slice_inputs(RefVariable[] names_to_saveables)
+ public static SaveableObject[] validate_and_slice_inputs(VariableV1[] names_to_saveables)
{
var names_to_saveables_dict = op_list_to_dict(names_to_saveables);
var saveables = new List();
var seen_ops = new List();
- foreach (var item in names_to_saveables_dict)
+ foreach (var (name, op) in enumerate(names_to_saveables_dict))
{
- foreach (var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key))
+ foreach (var converted_saveable_object in saveable_objects_for_op(op, name))
_add_saveable(saveables, seen_ops, converted_saveable_object);
}
return saveables.ToArray();
@@ -51,25 +52,31 @@ namespace Tensorflow
{
ops.init_scope();
var variable = ops.internal_convert_to_tensor(op, as_ref: true);
- if (variable.op.type == "VariableV2")
+ if (variable.op.type == "Variable" ||
+ variable.op.type == "VariableV2" ||
+ variable.op.type == "AutoReloadVariable")
yield return new ReferenceVariableSaveable(variable, "", name);
+ else
+ yield return new ResourceVariableSaveable(variable, "", name);
}
}
- public static Dictionary op_list_to_dict(RefVariable[] op_list, bool convert_variable_to_tensor = true)
+ public static Dictionary op_list_to_dict(VariableV1[] op_list, bool convert_variable_to_tensor = true)
{
op_list = op_list.OrderBy(x => x.name).ToArray();
var names_to_saveables = new Dictionary();
foreach(var var in op_list)
{
+ bool resource_or_ref_variable = var is RefVariable || var is ResourceVariable;
if (false)
{
throw new NotImplementedException("op_list_to_dict");
}
else
{
- if(false) // eager
+ // Variables (reference and resource) have an _in_graph_mode property
+ if (false) // eager
{
}
@@ -80,11 +87,14 @@ namespace Tensorflow
if (convert_variable_to_tensor)
{
- tensor = ops.internal_convert_to_tensor(var, as_ref: true);
+ if (var is ResourceVariable)
+ tensor = var.graph_element;
+ else
+ tensor = ops.internal_convert_to_tensor(var, as_ref: true);
}
- if (var.op.type == "ReadVariableOp")
- name = var.op.inputs[0].op.name;
+ if (tensor.op.type == "ReadVariableOp")
+ name = tensor.op.inputs[0].op.name;
else
name = var.op.name;
diff --git a/src/TensorFlowNET.Core/Train/Saving/saver.py.cs b/src/TensorFlowNET.Core/Train/Saving/saver.py.cs
index d5d1ff47..303f41a4 100644
--- a/src/TensorFlowNET.Core/Train/Saving/saver.py.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/saver.py.cs
@@ -37,7 +37,7 @@ namespace Tensorflow
///
public static Saver _create_saver_from_imported_meta_graph(MetaGraphDef meta_graph_def,
string import_scope,
- Dictionary imported_vars)
+ Dictionary imported_vars)
{
if(meta_graph_def.SaverDef != null)
{
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs
index 376e4a9f..9ad86b09 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs
@@ -19,13 +19,13 @@ namespace Tensorflow
public bool _save_slice_info;
private Operation _initializer_op;
- public Operation initializer => _initializer_op;
- public Operation op => _variable.op;
+ public override Operation initializer => _initializer_op;
+ public override Operation op => _variable.op;
public Graph graph => _variable.graph;
public TF_DataType dtype => _variable.dtype;
public TensorShape shape => tensor_util.to_shape(_variable.shape);
- public string name => _variable.name;
+ public override string name => _variable.name;
public RefVariable(object initial_value = null,
bool trainable = true,
@@ -153,7 +153,7 @@ namespace Tensorflow
// Manually overrides the variable's shape with the initial value's.
if (validate_shape)
{
- var initial_value_shape = _initial_value.getShape();
+ var initial_value_shape = _initial_value.GetShape();
if (!initial_value_shape.is_fully_defined())
throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
}
@@ -176,7 +176,7 @@ namespace Tensorflow
_snapshot = gen_array_ops.identity(_variable, name = "read");
}
- ops.add_to_collections(collections, this);
+ ops.add_to_collections(collections, this as VariableV1);
});
}
diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs
new file mode 100644
index 00000000..57222ba8
--- /dev/null
+++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs
@@ -0,0 +1,118 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ /// <summary>
+ /// Variable based on resource handles.
+ /// </summary>
+ public class ResourceVariable : VariableV1
+ {
+ bool _in_graph_mode;
+ Tensor _handle; // handle tensor looked up in the default graph by _init_from_proto
+ TensorShape _shape;
+ public TensorShape shape => _shape;
+ string _handle_name;
+ string _unique_id;
+ Operation _initializer_op;
+ public override Operation initializer => _initializer_op;
+ Tensor _initial_value;
+ bool _trainable;
+ public bool tranable => _trainable; // NOTE(review): name looks like a typo of "trainable"; renaming would change the public API
+ Tensor _cached_value; // set only when the snapshot tensor is not itself produced by a ReadVariableOp
+ Tensor _graph_element; // the ReadVariableOp output reached by walking back from the snapshot
+ public override Tensor graph_element => _graph_element;
+ TF_DataType _dtype;
+ public TF_DataType dtype => _dtype;
+ public override string name => _handle.name;
+ public string Device => _handle.Device;
+ public Graph Graph => _handle.graph;
+ public override Operation op => _handle.op;
+
+ public ResourceVariable(object initial_value = null,
+ bool trainable = true,
+ List collections = null,
+ bool validate_shape = true,
+ string caching_device = "",
+ string name = null,
+ VariableDef variable_def = null,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ string import_scope = "") : base(initial_value,
+ trainable,
+ collections,
+ validate_shape,
+ caching_device,
+ name,
+ dtype)
+ {
+ if (variable_def != null)
+ {
+ if (initial_value != null)
+ throw new ValueError("variable_def and initial_value are mutually exclusive.");
+ _init_from_proto(variable_def, import_scope: import_scope);
+ }
+ else
+ {
+ throw new NotImplementedException("ResourceVariable _init_from_args"); // only construction from a VariableDef is supported so far
+ //_init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype);
+ }
+ }
+
+ private void _init_from_proto(VariableDef variable_def, string import_scope = null) // fills the fields from graph elements named in variable_def
+ {
+ _in_graph_mode = true;
+ if (!variable_def.IsResource)
+ throw new ValueError("Trying to restore Variable as ResourceVariable.");
+
+ // Create from variable_def.
+ var g = ops.get_default_graph();
+ var prepend_name_scope = ops.prepend_name_scope(variable_def.VariableName, import_scope: import_scope);
+ _handle = g.as_graph_element(prepend_name_scope) as Tensor; // NOTE(review): `as` yields null for a non-Tensor element; the next line would then throw
+ _shape = new TensorShape(_handle.op.get_attr("shape") as TensorShapeProto);
+ _handle_name = _handle.name;
+ _unique_id = _handle_name;
+ prepend_name_scope = ops.prepend_name_scope(variable_def.InitializerName, import_scope: import_scope);
+ _initializer_op = g.as_graph_element(prepend_name_scope) as Operation;
+ if (!string.IsNullOrEmpty(variable_def.InitialValueName)) // the initial value is optional in the proto
+ {
+ prepend_name_scope = ops.prepend_name_scope(variable_def.InitialValueName, import_scope: import_scope);
+ _initial_value = g.as_graph_element(prepend_name_scope) as Tensor;
+ }
+
+ _trainable = variable_def.Trainable;
+ /*var (synchronization, aggregation, trainable) =
+ variables.validate_synchronization_aggregation_trainable(
+ variable_def.Synchronization,
+ variable_def.Aggregation,
+ variable_def.Trainable,
+ variable_def.VariableName);*/
+ if (!string.IsNullOrEmpty(variable_def.SnapshotName))
+ {
+ prepend_name_scope = ops.prepend_name_scope(variable_def.SnapshotName, import_scope: import_scope);
+ var snapshot = g.as_graph_element(prepend_name_scope) as Tensor;
+ if (snapshot.op.type != "ReadVariableOp")
+ _cached_value = snapshot;
+ while (snapshot.op.type != "ReadVariableOp") // follow first inputs back until a ReadVariableOp output is reached
+ snapshot = snapshot.op.inputs[0];
+ _graph_element = snapshot;
+ }
+ else
+ {
+ throw new NotImplementedException("SnapshotName _init_from_proto"); // a proto without a snapshot name is not handled yet
+ }
+
+ if (variable_def.SaveSliceInfoDef != null)
+ {
+ throw new NotImplementedException("SaveSliceInfoDef _init_from_proto"); // slice info is not handled yet
+ }
+
+ _dtype = dtypes.as_tf_dtype((DataType)_handle.op.get_attr("dtype")); // dtype is read from the handle op's "dtype" attribute
+ }
+
+ public override string ToString()
+ {
+ return $"tf.ResourceVariable '{name}' shape={shape} dtype={dtype}";
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Variables/VariableAggregation.cs b/src/TensorFlowNET.Core/Variables/VariableAggregation.cs
deleted file mode 100644
index 3f7e4ff0..00000000
--- a/src/TensorFlowNET.Core/Variables/VariableAggregation.cs
+++ /dev/null
@@ -1,14 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Text;
-
-namespace Tensorflow
-{
- public enum VariableAggregation
- {
- NONE = 0,
- SUM = 1,
- MEAN = 2,
- ONLY_FIRST_REPLICA = 3 // ONLY_FIRST_TOWER
- }
-}
diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs
index e20622cd..9e97e373 100644
--- a/src/TensorFlowNET.Core/Variables/VariableScope.cs
+++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs
@@ -36,8 +36,8 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.DtInvalid,
object initializer = null, // IInitializer or Tensor
bool? trainable = null,
- VariableSynchronization synchronization = VariableSynchronization.AUTO,
- VariableAggregation aggregation= VariableAggregation.NONE)
+ VariableSynchronization synchronization = VariableSynchronization.Auto,
+ VariableAggregation aggregation= VariableAggregation.None)
{
string full_name = !string.IsNullOrEmpty(this._name) ? this._name + "/" + name : name;
return with(ops.name_scope(null), scope =>
diff --git a/src/TensorFlowNET.Core/Variables/VariableSynchronization.cs b/src/TensorFlowNET.Core/Variables/VariableSynchronization.cs
deleted file mode 100644
index 8a16f285..00000000
--- a/src/TensorFlowNET.Core/Variables/VariableSynchronization.cs
+++ /dev/null
@@ -1,17 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Text;
-
-namespace Tensorflow
-{
- ///
- /// Indicates when a distributed variable will be synced.
- ///
- public enum VariableSynchronization
- {
- AUTO = 0,
- NONE = 1,
- ON_WRITE = 2,
- ON_READ = 3
- }
-}
diff --git a/src/TensorFlowNET.Core/Variables/VariableV1.cs b/src/TensorFlowNET.Core/Variables/VariableV1.cs
index 8d5a5fc8..1e9aed72 100644
--- a/src/TensorFlowNET.Core/Variables/VariableV1.cs
+++ b/src/TensorFlowNET.Core/Variables/VariableV1.cs
@@ -17,6 +17,11 @@ namespace Tensorflow
///
public class VariableV1
{
+ public virtual string name { get; }
+ public virtual Tensor graph_element { get; }
+ public virtual Operation op { get; }
+ public virtual Operation initializer { get; }
+
public VariableV1(object initial_value = null,
bool trainable = true,
List collections = null,
diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs
index 33365cbd..97c6b912 100644
--- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs
+++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs
@@ -27,8 +27,8 @@ namespace Tensorflow
bool? reuse = null,
bool? trainable = null,
bool validate_shape = true,
- VariableSynchronization synchronization = VariableSynchronization.AUTO,
- VariableAggregation aggregation = VariableAggregation.NONE)
+ VariableSynchronization synchronization = VariableSynchronization.Auto,
+ VariableAggregation aggregation = VariableAggregation.None)
{
dtype = dtype.as_base_dtype();
trainable = variable_scope._get_trainable_value(synchronization, trainable);
@@ -49,8 +49,8 @@ namespace Tensorflow
object initializer = null,
bool? trainable = null,
bool validate_shape = true,
- VariableSynchronization synchronization = VariableSynchronization.AUTO,
- VariableAggregation aggregation = VariableAggregation.NONE)
+ VariableSynchronization synchronization = VariableSynchronization.Auto,
+ VariableAggregation aggregation = VariableAggregation.None)
{
bool is_scalar = !(shape is null) && shape.NDim == 0;
@@ -98,8 +98,8 @@ namespace Tensorflow
bool? trainable = null,
bool validate_shape = false,
bool? use_resource = null,
- VariableSynchronization synchronization = VariableSynchronization.AUTO,
- VariableAggregation aggregation = VariableAggregation.NONE)
+ VariableSynchronization synchronization = VariableSynchronization.Auto,
+ VariableAggregation aggregation = VariableAggregation.None)
{
bool initializing_from_value = false;
if (use_resource == null)
@@ -161,8 +161,8 @@ namespace Tensorflow
bool? trainable = null,
bool validate_shape = false,
bool? use_resource = null,
- VariableSynchronization synchronization = VariableSynchronization.AUTO,
- VariableAggregation aggregation = VariableAggregation.NONE)
+ VariableSynchronization synchronization = VariableSynchronization.Auto,
+ VariableAggregation aggregation = VariableAggregation.None)
{
if (use_resource == null)
use_resource = false;
diff --git a/src/TensorFlowNET.Core/Variables/tf.variable.cs b/src/TensorFlowNET.Core/Variables/tf.variable.cs
index aac58a9a..d4f71b74 100644
--- a/src/TensorFlowNET.Core/Variables/tf.variable.cs
+++ b/src/TensorFlowNET.Core/Variables/tf.variable.cs
@@ -17,8 +17,8 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.DtInvalid,
object initializer = null, // IInitializer or Tensor
bool? trainable = null,
- VariableSynchronization synchronization = VariableSynchronization.AUTO,
- VariableAggregation aggregation = VariableAggregation.NONE)
+ VariableSynchronization synchronization = VariableSynchronization.Auto,
+ VariableAggregation aggregation = VariableAggregation.None)
{
var scope = Tensorflow.variable_scope.get_variable_scope();
var store = Tensorflow.variable_scope._get_default_variable_store();
diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
index bfdfbd24..ba508e4f 100644
--- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
+++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
@@ -136,8 +136,8 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.DtInvalid,
bool validate_shape = false,
bool ? use_resource = null,
- VariableSynchronization synchronization = VariableSynchronization.AUTO,
- VariableAggregation aggregation = VariableAggregation.NONE)
+ VariableSynchronization synchronization = VariableSynchronization.Auto,
+ VariableAggregation aggregation = VariableAggregation.None)
{
trainable = _get_trainable_value(synchronization, trainable);
if (!use_resource.HasValue)
@@ -208,7 +208,7 @@ namespace Tensorflow
public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = true)
{
- if (synchronization == VariableSynchronization.ON_READ)
+ if (synchronization == VariableSynchronization.OnRead)
{
if (trainable.Value)
throw new ValueError("Synchronization value can be set to " +
diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs
index bd38fc77..26066d3b 100644
--- a/src/TensorFlowNET.Core/Variables/variables.py.cs
+++ b/src/TensorFlowNET.Core/Variables/variables.py.cs
@@ -21,17 +21,17 @@ namespace Tensorflow
///
///
///
- public static RefVariable[] _all_saveable_objects(string scope = "")
+ public static VariableV1[] _all_saveable_objects(string scope = "")
{
- var all = new List();
+ var all = new List();
var collection = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope);
if(collection != null)
- all.AddRange(collection as List);
+ all.AddRange(collection as List);
collection = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope);
if (collection != null)
- all.AddRange(collection as List);
+ all.AddRange(collection as List);
return all.ToArray();
}
@@ -47,11 +47,11 @@ namespace Tensorflow
/// special tokens filters by prefix.
///
/// A list of `Variable` objects.
- public static List global_variables(string scope = null)
+ public static List global_variables(string scope = null)
{
var result = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope);
- return result == null ? new List() : result as List;
+ return result == null ? new List() : result as List;
}
///
@@ -60,7 +60,7 @@ namespace Tensorflow
/// List of `Variable` objects to initialize.
/// Optional name for the returned operation.
/// An Op that run the initializers of all the specified variables.
- public static Operation variables_initializer(RefVariable[] var_list, string name = "init")
+ public static Operation variables_initializer(VariableV1[] var_list, string name = "init")
{
if (var_list.Length > 0)
return control_flow_ops.group(var_list.Select(x => x.initializer).ToArray(), name);
diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs
index 00763dc2..bb3b7188 100644
--- a/src/TensorFlowNET.Core/ops.GraphKeys.cs
+++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs
@@ -25,7 +25,7 @@ namespace Tensorflow
///
/// Key to collect losses
///
- public static string LOSSES = "losses";
+ public const string LOSSES = "losses";
///
/// Key to collect Variable objects that are global (shared across machines).
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index 08a7dc97..53f03375 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -471,6 +471,8 @@ namespace Tensorflow
return array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name);
case RefVariable varVal:
return varVal._TensorConversionFunction(as_ref: as_ref);
+ case ResourceVariable varVal:
+ return null;
case object[] objects:
return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name);
default:
diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs
index b4b97c90..32b045e6 100644
--- a/src/TensorFlowNET.Core/tf.cs
+++ b/src/TensorFlowNET.Core/tf.cs
@@ -11,6 +11,7 @@ namespace Tensorflow
public static TF_DataType bytes = TF_DataType.TF_STRING;
public static TF_DataType int16 = TF_DataType.TF_INT16;
public static TF_DataType int32 = TF_DataType.TF_INT32;
+ public static TF_DataType int64 = TF_DataType.TF_INT64;
public static TF_DataType float16 = TF_DataType.TF_HALF;
public static TF_DataType float32 = TF_DataType.TF_FLOAT;
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
diff --git a/test/KerasNET.Test/BaseTests.cs b/test/KerasNET.Test/BaseTests.cs
index 6a716c5f..6ab72276 100644
--- a/test/KerasNET.Test/BaseTests.cs
+++ b/test/KerasNET.Test/BaseTests.cs
@@ -15,8 +15,8 @@ namespace Keras.Test
{
var dense_1 = new Dense(1, name: "dense_1", activation: tf.nn.relu());
var input = new Tensor(np.array(new int[] { 3 }));
- dense_1.__build__(input.getShape());
- var outputShape = dense_1.output_shape(input.getShape());
+ dense_1.__build__(input.GetShape());
+ var outputShape = dense_1.output_shape(input.GetShape());
var a = (int[])(outputShape.Dimensions);
var b = (int[])(new int[] { 1 });
var _a = np.array(a);
diff --git a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
index 61ba33b3..8f4a6ecf 100644
--- a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
+++ b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
@@ -36,6 +36,7 @@ namespace TensorFlowNET.Examples.ImageProcess
string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3";
float testing_percentage = 0.1f;
float validation_percentage = 0.1f;
+ float learning_rate = 0.01f;
Tensor resized_image_tensor;
Dictionary> image_lists;
int how_many_training_steps = 200;
@@ -43,21 +44,38 @@ namespace TensorFlowNET.Examples.ImageProcess
int train_batch_size = 100;
int validation_batch_size = 100;
int intermediate_store_frequency = 0;
+ int class_count = 0;
const int MAX_NUM_IMAGES_PER_CLASS = 134217727;
+ Operation train_step;
+ Tensor final_tensor;
+ Tensor bottleneck_input;
+ Tensor cross_entropy;
+ Tensor ground_truth_input;
public bool Run()
{
PrepareData();
- var graph = tf.Graph().as_default();
- tf.train.import_meta_graph("graph/InceptionV3.meta");
- Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");
+ // Set up the pre-trained graph.
+ var (graph, bottleneck_tensor, resized_image_tensor, wants_quantization) =
+ create_module_graph();
+
+ // Add the new layer that we'll be training.
+ with(graph.as_default(), delegate
+ {
+ (train_step, cross_entropy, bottleneck_input,
+ ground_truth_input, final_tensor) = add_final_retrain_ops(
+ class_count, "final_result", bottleneck_tensor,
+ wants_quantization, is_training: true);
+ });
+
+ /*Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");
Tensor resized_image_tensor = graph.OperationByName("Placeholder");
Tensor final_tensor = graph.OperationByName("final_result");
Tensor ground_truth_input = graph.OperationByName("input/GroundTruthInput");
- Operation train_step = graph.OperationByName("train/GradientDescent");
+ train_step = graph.OperationByName("train/GradientDescent");
Tensor bottleneck_input = graph.OperationByName("input/BottleneckInputPlaceholder");
- Tensor cross_entropy = graph.OperationByName("cross_entropy/sparse_softmax_cross_entropy_loss/value");
+ Tensor cross_entropy = graph.OperationByName("cross_entropy/sparse_softmax_cross_entropy_loss/value");*/
var sw = new Stopwatch();
@@ -87,7 +105,7 @@ namespace TensorFlowNET.Examples.ImageProcess
// Create a train saver that is used to restore values into an eval graph
// when exporting models.
- var train_saver = tf.train.Saver();
+ // var train_saver = tf.train.Saver();
for (int i = 0; i < how_many_training_steps; i++)
{
@@ -147,12 +165,180 @@ namespace TensorFlowNET.Examples.ImageProcess
}
// After training is complete, force one last save of the train checkpoint.
- train_saver.save(sess, CHECKPOINT_NAME);
+ // train_saver.save(sess, CHECKPOINT_NAME);
+
+ // We've completed all our training, so run a final test evaluation on
+ // some new images we haven't used before.
+ run_final_eval(sess, null, class_count, image_lists,
+ jpeg_data_tensor, decoded_image_tensor, resized_image_tensor,
+ bottleneck_tensor);
});
return false;
}
+ /// <summary>
+ /// Runs a final evaluation on an eval graph using the test data set. Not implemented yet: the body below is commented out, so every argument is currently ignored.
+ /// </summary>
+ /// <param name="train_session">Session handed in by the caller (Run passes its `sess`).</param>
+ /// <param name="module_spec">Untyped module description; Run currently passes null.</param>
+ /// <param name="class_count">Number of image classes (Run passes the class_count field).</param>
+ /// <param name="image_lists">Image lists produced by create_image_lists and stored in the image_lists field.</param>
+ /// <param name="jpeg_data_tensor">Tensor forwarded from Run; not read by this stub.</param>
+ /// <param name="decoded_image_tensor">Tensor forwarded from Run; not read by this stub.</param>
+ /// <param name="resized_image_tensor">Tensor forwarded from Run; not read by this stub.</param>
+ /// <param name="bottleneck_tensor">Tensor forwarded from Run; not read by this stub.</param>
+ private void run_final_eval(Session train_session, object module_spec, int class_count,
+ Dictionary> image_lists, // NOTE(review): generic type arguments look stripped by extraction on this line -- confirm against the repository
+ Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
+ Tensor resized_image_tensor, Tensor bottleneck_tensor)
+ {
+ /*var (eval_session, _, bottleneck_input, ground_truth_input, evaluation_step,
+ prediction) = build_eval_session(module_spec, class_count);*/
+ }
+
+ private void build_eval_session(int class_count) // Stub: builds a fresh module graph and opens a session on it; class_count is not read and nothing is returned yet.
+ {
+ // If quantized, we need to create the correct eval graph for exporting.
+ var (eval_graph, bottleneck_tensor, resized_input_tensor, wants_quantization) = create_module_graph();
+ var eval_sess = tf.Session(graph: eval_graph); // NOTE(review): eval_sess is never used, returned or disposed in this stub
+
+ with(eval_graph.as_default(), graph => // body intentionally empty for now: no ops are added to the eval graph
+ {
+
+
+ });
+ }
+
+ /// <summary>
+ /// Adds a new softmax and fully-connected layer for training and eval.
+ ///
+ /// We need to retrain the top layer to identify our new classes, so this function
+ /// adds the right operations to the graph, along with some variables to hold the
+ /// weights, and then sets up all the gradients for the backward pass.
+ ///
+ /// The setup for the softmax and fully-connected layers is based on:
+ /// https://www.tensorflow.org/tutorials/mnist/beginners/index.html
+ /// </summary>
+ /// <param name="class_count">Number of output classes; it is the second dimension of final_weights and the size passed to tf.zeros for final_biases.</param>
+ /// <param name="final_tensor_name">Name given to the softmax output op.</param>
+ /// <param name="bottleneck_tensor">Tensor whose first two shape dimensions are read as batch size and bottleneck size; it is also the default value of the bottleneck placeholder.</param>
+ /// <param name="quantize_layer">Quantization is not implemented: passing true throws NotImplementedException.</param>
+ /// <param name="is_training">When false, the method returns before any loss or optimizer ops are added.</param>
+ /// <returns>(train_step, cross_entropy_mean, bottleneck_input, ground_truth_input, final_tensor); the first two are null when is_training is false. train_step, bottleneck_input, ground_truth_input and final_tensor are also written to the same-named class fields as a side effect.</returns>
+ private (Operation, Tensor, Tensor, Tensor, Tensor) add_final_retrain_ops(int class_count, string final_tensor_name,
+ Tensor bottleneck_tensor, bool quantize_layer, bool is_training)
+ {
+ var (batch_size, bottleneck_tensor_size) = (bottleneck_tensor.GetShape().Dimensions[0], bottleneck_tensor.GetShape().Dimensions[1]); // first two shape dimensions of the bottleneck tensor
+ with(tf.name_scope("input"), scope =>
+ {
+ bottleneck_input = tf.placeholder_with_default( // assigns the class-level field, not a local
+ bottleneck_tensor,
+ shape: bottleneck_tensor.GetShape().Dimensions,
+ name: "BottleneckInputPlaceholder");
+
+ ground_truth_input = tf.placeholder(tf.int64, new TensorShape(batch_size), name: "GroundTruthInput");
+ });
+
+ // Organizing the following ops so they are easier to see in TensorBoard.
+ string layer_name = "final_retrain_ops";
+ Tensor logits = null;
+ with(tf.name_scope(layer_name), scope =>
+ {
+ RefVariable layer_weights = null;
+ with(tf.name_scope("weights"), delegate
+ {
+ var initial_value = tf.truncated_normal(new int[] { bottleneck_tensor_size, class_count }, stddev: 0.001f);
+ layer_weights = tf.Variable(initial_value, name: "final_weights");
+ variable_summaries(layer_weights);
+ });
+
+ RefVariable layer_biases = null;
+ with(tf.name_scope("biases"), delegate
+ {
+ layer_biases = tf.Variable(tf.zeros((class_count)), name: "final_biases");
+ variable_summaries(layer_biases);
+ });
+
+ with(tf.name_scope("Wx_plus_b"), delegate
+ {
+ logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases;
+ tf.summary.histogram("pre_activations", logits);
+ });
+ });
+
+ final_tensor = tf.nn.softmax(logits, name: final_tensor_name);
+
+ // The tf.contrib.quantize functions rewrite the graph in place for
+ // quantization. The imported model graph has already been rewritten, so upon
+ // calling these rewrites, only the newly added final layer will be
+ // transformed.
+ if (quantize_layer)
+ {
+ throw new NotImplementedException("quantize_layer");
+ /*if (is_training)
+ tf.contrib.quantize.create_training_graph();
+ else
+ tf.contrib.quantize.create_eval_graph();*/
+ }
+
+ tf.summary.histogram("activations", final_tensor);
+
+ // If this is an eval graph, we don't need to add loss ops or an optimizer.
+ if (!is_training)
+ return (null, null, bottleneck_input, ground_truth_input, final_tensor);
+
+ Tensor cross_entropy_mean = null;
+ with(tf.name_scope("cross_entropy"), delegate
+ {
+ cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
+ labels: ground_truth_input, logits: logits);
+ });
+
+ tf.summary.scalar("cross_entropy", cross_entropy_mean);
+
+ with(tf.name_scope("train"), delegate
+ {
+ var optimizer = tf.train.GradientDescentOptimizer(learning_rate); // learning_rate is the class-level field (initialised to 0.01f)
+ train_step = optimizer.minimize(cross_entropy_mean);
+ });
+
+ return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
+ final_tensor);
+ }
+
+ private void variable_summaries(RefVariable var) // Records mean, stddev, max and min scalar summaries plus a histogram for var, all inside the "summaries" name scope.
+ {
+ with(tf.name_scope("summaries"), delegate
+ {
+ var mean = tf.reduce_mean(var);
+ tf.summary.scalar("mean", mean);
+ Tensor stddev = null; // declared outside the nested scope so the delegate below can assign it
+ with(tf.name_scope("stddev"), delegate {
+ stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))); // square root of the mean squared deviation from the mean
+ });
+ tf.summary.scalar("stddev", stddev);
+ tf.summary.scalar("max", tf.reduce_max(var));
+ tf.summary.scalar("min", tf.reduce_min(var));
+ tf.summary.histogram("histogram", var);
+ });
+ }
+
+ private (Graph, Tensor, Tensor, bool) create_module_graph() // Imports graph/InceptionV3.meta into a new default graph and returns (graph, bottleneck tensor, resized-input tensor, wants_quantization).
+ {
+ var (height, width) = (299, 299); // only referenced by the commented-out tf.placeholder call below
+
+ return with(tf.Graph().as_default(), graph =>
+ {
+ tf.train.import_meta_graph("graph/InceptionV3.meta");
+ Tensor resized_input_tensor = graph.OperationByName("Placeholder"); //tf.placeholder(tf.float32, new TensorShape(-1, height, width, 3));
+ // var m = hub.Module(module_spec);
+ Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");// m(resized_input_tensor);
+ var wants_quantization = false; // hard-coded: this method never requests quantization
+ return (graph, bottleneck_tensor, resized_input_tensor, wants_quantization);
+ });
+ }
+
private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary> image_lists,
int how_many, string category, string bottleneck_dir, string image_dir,
Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
@@ -161,7 +347,7 @@ namespace TensorFlowNET.Examples.ImageProcess
var bottlenecks = new List();
var ground_truths = new List();
var filenames = new List();
- int class_count = image_lists.Keys.Count;
+ class_count = image_lists.Keys.Count;
foreach (var unused_i in range(how_many))
{
int label_index = new Random().Next(class_count);
@@ -353,7 +539,7 @@ namespace TensorFlowNET.Examples.ImageProcess
// Look at the folder structure, and create lists of all the images.
image_lists = create_image_lists();
- var class_count = len(image_lists);
+ class_count = len(image_lists);
if (class_count == 0)
print($"No valid folders of images found at {image_dir}");
if (class_count == 1)
diff --git a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
index 50cceddb..4b4623dc 100644
--- a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
+++ b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
@@ -33,7 +33,7 @@ namespace TensorFlowNET.UnitTest.nn_test
var y_np = this._ZeroFraction(x_np);
var x_tf = constant_op.constant(x_np);
- x_tf.setShape(x_shape);
+ x_tf.SetShape(x_shape);
var y_tf = nn_impl.zero_fraction(x_tf);
var y_tf_np = self.evaluate(y_tf);