diff --git a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
index 516d0163..ef422968 100644
--- a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
+++ b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
@@ -11,10 +11,26 @@ namespace Tensorflow.Framework
{
Tensor _values;
public Tensor values => _values;
+ Tensor _indices;
+ public Tensor indices => _indices;
+ Tensor _dense_shape;
+ public Tensor dense_shape => _dense_shape;
+
+ public string name => _values.name;
+
+ public string device => _values.Device;
+
+ public Operation op => _values.op;
+
+ public TF_DataType dtype => _values.dtype;
+
+ public Graph graph => _values.graph;
public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null)
{
-
+ _values = values;
+ _indices = indices;
+ _dense_shape = dense_shape;
}
public static implicit operator Tensor(IndexedSlices indexedSlices)
diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs
index aa74f7f1..58dd7e4a 100644
--- a/src/TensorFlowNET.Core/Gradients/array_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs
@@ -83,13 +83,13 @@ namespace Tensorflow.Gradients
new Tensor[] { non_neg_concat_dim, tf.constant(0) },
new Tensor[] { tf.constant(1), tf.constant(-1) });
var squeeze_sizes = array_ops.squeeze(slice);
- out_grads = gen_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList();
+ out_grads = gen_array_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList();
}
else
{
- var offset = gen_ops.concat_offset(non_neg_concat_dim, sizes);
+ var offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes);
foreach (var (begin, size) in zip(offset, sizes))
- out_grads.Add(gen_ops.slice(grad, begin, size));
+ out_grads.Add(gen_array_ops.slice(grad, begin, size));
}
return (end_value_index <= dim_index ?
@@ -129,7 +129,7 @@ namespace Tensorflow.Gradients
if (fully_known)
return sizes;
else
- return gen_ops.shape_n(inputs);
+ return gen_array_ops.shape_n(inputs);
}
///
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
index c68fdfa3..12a50479 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
@@ -93,10 +93,7 @@ namespace Tensorflow
{
// generate gradient subgraph for op.
var op = queue.Dequeue();
- if(op.name == "embedding/ExpandDims")
- {
- }
_maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
//if (loop_state != null)
//loop_state.EnterGradWhileContext(op, before: true);
@@ -311,9 +308,10 @@ namespace Tensorflow
// Aggregate multiple gradients, and convert [] to None.
if (out_grad.Count > 0)
{
+ string used = "";
if (out_grad.Count < 2)
{
- string used = "nop";
+ used = "nop";
if (out_grad.Count == 0)
{
throw new ValueError("_AggregatedGrads out_grad.Length == 0");
@@ -321,6 +319,11 @@ namespace Tensorflow
return_grads[i] = out_grad[0];
}
+ else
+ {
+ used = "add_n";
+ out_grads[i] = new List { _MultiDeviceAddN(out_grad.ToArray(), gradient_uid) };
+ }
}
else
{
@@ -331,6 +334,38 @@ namespace Tensorflow
return return_grads;
}
+ ///
+ /// Adds tensors from potentially multiple devices.
+ ///
+ ///
+ ///
+ ///
+ private static Tensor _MultiDeviceAddN(Tensor[] tensor_list, string gradient_uid)
+ {
+ // Basic function structure comes from control_flow_ops.group().
+ // Sort tensors according to their devices.
+ var tensors_on_device = new Dictionary>();
+
+ foreach (var tensor in tensor_list)
+ {
+ if (!tensors_on_device.ContainsKey(tensor.Device))
+ tensors_on_device[tensor.Device] = new List();
+
+ tensors_on_device[tensor.Device].Add(tensor);
+ }
+
+ // For each device, add the tensors on that device first.
+ var summands = new List();
+ foreach(var dev in tensors_on_device.Keys)
+ {
+ var tensors = tensors_on_device[dev];
+ ops._colocate_with_for_gradient(tensors[0].op, gradient_uid, ignore_existing: true);
+ summands.Add(math_ops.add_n(tensors.ToArray()));
+ }
+
+ return math_ops.add_n(summands.ToArray());
+ }
+
///
/// The set of ops that terminate the gradient computation.
///
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index 4a6ee82d..c997f179 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -276,6 +276,9 @@ namespace Tensorflow
});
}
+ public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null)
+ => gen_array_ops.unique(x, out_idx: out_idx, name: name);
+
public static Tensor where(Tensor condition, object x = null, object y = null, string name = null)
{
if( x == null && y == null)
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index fb980259..087a2430 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -26,6 +26,13 @@ namespace Tensorflow
return _op.outputs[0];
}
+ public static Tensor[] concat_offset(Tensor concat_dim, Tensor[] shape, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("ConcatOffset", name: name, args: new { concat_dim, shape });
+
+ return _op.outputs;
+ }
+
///
/// Returns a diagonal tensor with a given diagonal values.
///
@@ -205,6 +212,21 @@ namespace Tensorflow
return _op.outputs[0];
}
+ ///
+ /// Finds unique elements in a 1-D tensor.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("Unique", name, new { x, out_idx });
+ // TODO
+ throw new NotImplementedException("_result = _UniqueOutput._make(_result)");
+ // return _op.outputs[0];
+ }
+
public static Tensor where()
{
throw new NotImplementedException("where");
@@ -271,6 +293,26 @@ namespace Tensorflow
return _op.outputs[0];
}
+ ///
+ /// Return a slice from 'input'
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size });
+ return _op.outputs[0];
+ }
+
+ public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split });
+ return _op.outputs;
+ }
+
public static Tensor tile(Tensor input, Tensor multiples, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples });
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index e5670dd0..763a4bd8 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -16,6 +16,19 @@ namespace Tensorflow
return _op.outputs[0];
}
+ ///
+ /// Add all input tensors element wise.
+ ///
+ ///
+ ///
+ ///
+ public static Tensor add_n(Tensor[] inputs, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("AddN", name, args: new { inputs });
+
+ return _op.outputs[0];
+ }
+
///
/// Returns the index with the largest value across dimensions of a tensor.
///
@@ -198,6 +211,20 @@ namespace Tensorflow
return _op.outputs[0];
}
+ ///
+ /// Computes the sum along segments of a tensor.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor unsorted_segment_sum(Tensor data, Tensor segment_ids, Tensor num_segments, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("UnsortedSegmentSum", name, new { data, segment_ids, num_segments });
+ return _op.outputs[0];
+ }
+
public static Tensor tan(Tensor x, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Tan", name, args: new { x });
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index 580ae33c..29e9d671 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -44,8 +44,8 @@ namespace Tensorflow
return array_ops.identity(values, name: name);
return values;
}
- throw new NotImplementedException("math_ops add_n n > 1");
- // return gen_math_ops.add_n(inputs, name: name);
+
+ return gen_math_ops.add_n(inputs, name: name);
}
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
@@ -126,6 +126,9 @@ namespace Tensorflow
public static Tensor equal(Tx x, Ty y, string name = null)
=> gen_math_ops.equal(x, y, name: name);
+ public static Tensor sqrt(Tensor x, string name = null)
+ => gen_math_ops.sqrt(x, name: name);
+
public static Tensor multiply(Tx x, Ty y, string name = null)
=> gen_math_ops.mul(x, y, name: name);
@@ -319,6 +322,17 @@ namespace Tensorflow
return _may_reduce_to_scalar(keepdims, axis, min);
}
+ ///
+ /// Computes the sum along segments of a tensor.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor unsorted_segment_sum(Tensor data, Tensor segment_ids, Tensor num_segments, string name = null)
+ => gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments, name: name);
+
///
/// Casts a tensor to type `int32`.
///
diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index cecdbd38..63f440c1 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -5,10 +5,10 @@
TensorFlow.NET
Tensorflow
1.14.0
- 0.8.1
+ 0.8.2
Haiping Chen
SciSharp STACK
- true
+ false
Apache 2.0
https://github.com/SciSharp/TensorFlow.NET
git
@@ -17,14 +17,15 @@
TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#
Google's TensorFlow full binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io
- 0.8.1.0
+ 0.8.2.0
Changes since v0.8:
1. Remove global static graph instance.
2. Provide custom gradient function.
-3. Add gradient function for Conv2D.
+3. Add gradient function for Conv2D.
+4. Fix bug for Transfer Learning example.
7.2
- 0.8.1.0
+ 0.8.2.0
@@ -42,6 +43,10 @@ Docs: https://tensorflownet.readthedocs.io
+
+
+
+
diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
index b6063234..56e69881 100644
--- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
+++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Framework;
+using static Tensorflow.Python;
namespace Tensorflow.Train
{
@@ -10,9 +12,10 @@ namespace Tensorflow.Train
///
public class AdamOptimizer : Optimizer
{
- private float _beta1;
- private float _beta2;
- private float _epsilon;
+ float _beta1;
+ float _beta2;
+ float _epsilon;
+ Tensor _lr_t, _beta1_t, _beta2_t, _epsilon_t;
public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam")
: base(learning_rate, use_locking, name)
@@ -21,5 +24,51 @@ namespace Tensorflow.Train
_beta2 = beta2;
_epsilon = epsilon;
}
+
+ public override Operation _apply_sparse(IndexedSlices grad, RefVariable var)
+ {
+ return _apply_sparse_shared(grad.values, var, grad.indices, (x, i, v) =>
+ {
+ return state_ops.scatter_add(x, i, v, use_locking: _use_locking);
+ });
+ }
+
+ private Operation _apply_sparse_shared(Tensor grad, RefVariable var, Tensor indices, Func scatter_add)
+ {
+ var (beta1_power_v, beta2_power_v) = _get_beta_accumulators();
+ Tensor beta1_power = math_ops.cast(beta1_power_v, var.dtype.as_base_dtype());
+ Tensor beta2_power = math_ops.cast(beta2_power_v, var.dtype.as_base_dtype());
+ var lr_t = math_ops.cast(_lr_t, var.dtype.as_base_dtype());
+ var beta1_t = math_ops.cast(_beta1_t, var.dtype.as_base_dtype());
+ var beta2_t = math_ops.cast(_beta2_t, var.dtype.as_base_dtype());
+ var epsilon_t = math_ops.cast(_epsilon_t, var.dtype.as_base_dtype());
+ var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power));
+ var m = get_slot(var, "m");
+ var m_scaled_g_values = grad * (1 - beta1_t);
+ var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking);
+ with(ops.control_dependencies(new[] { m_t }), delegate
+ {
+ m_t = scatter_add(m, indices, m_scaled_g_values);
+ });
+
+ var v = get_slot(var, "v");
+ var v_scaled_g_values = (grad * grad) * (1 - beta2_t);
+ var v_t = state_ops.assign(v, v * beta2_t, use_locking: _use_locking);
+ with(ops.control_dependencies(new[] { v_t }), delegate
+ {
+ v_t = scatter_add(v, indices, v_scaled_g_values);
+ });
+ var v_sqrt = math_ops.sqrt(v_t);
+ var var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking: _use_locking);
+ return control_flow_ops.group(new[] { var_update, m_t, v_t });
+ }
+
+ private (RefVariable, RefVariable) _get_beta_accumulators()
+ {
+ ops.init_scope();
+ var graph = ops.get_default_graph();
+ return (_get_non_slot_variable("beta1_power", graph: graph),
+ _get_non_slot_variable("beta2_power", graph: graph));
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs
index 3a14390d..f5474c23 100644
--- a/src/TensorFlowNET.Core/Train/Optimizer.cs
+++ b/src/TensorFlowNET.Core/Train/Optimizer.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
+using Tensorflow.Framework;
using static Tensorflow.Python;
namespace Tensorflow
@@ -23,8 +24,8 @@ namespace Tensorflow
public float LearningRate { get; set; }
public Tensor LearningRateTensor { get; set; }
public bool _use_locking;
- public Dictionary _slots;
- public Dictionary _non_slot_dict;
+ public Dictionary> _slots;
+ public Dictionary _non_slot_dict;
public Dictionary _deferred_slot_restorations;
public Optimizer(float learning_rate, bool use_locking, string name = null)
@@ -36,8 +37,8 @@ namespace Tensorflow
_use_locking = use_locking;
LearningRate = learning_rate;
// Dictionary of slots.
- _slots = new Dictionary();
- _non_slot_dict = new Dictionary();
+ _slots = new Dictionary>();
+ _non_slot_dict = new Dictionary();
_deferred_slot_restorations = new Dictionary();
}
@@ -110,7 +111,7 @@ namespace Tensorflow
public Operation apply_gradients(Tuple[] grads_and_vars, RefVariable global_step = null, string name = null)
{
// No DistributionStrategy case.
- var converted_grads_and_vars = new List>();
+ var converted_grads_and_vars = new List<(Tensor, RefVariable, _OptimizableVariable)>();
foreach (var (g, v) in grads_and_vars)
{
if(g != null)
@@ -118,7 +119,7 @@ namespace Tensorflow
// Convert the grad to Tensor or IndexedSlices if necessary.
var gR = ops.convert_to_tensor_or_indexed_slices(g);
var p = _get_processor(v);
- converted_grads_and_vars.Add(new Tuple(gR, v, p));
+ converted_grads_and_vars.Add((gR, v, p));
}
}
@@ -143,7 +144,8 @@ namespace Tensorflow
var scope_name = var.op.name;
with(ops.name_scope("update_" + scope_name), scope2 =>
{
- update_ops.Add(processor.update_op(this, grad));
+ var op = processor.update_op(this, grad);
+ update_ops.Add(op);
});
}
@@ -201,11 +203,69 @@ namespace Tensorflow
return gen_training_ops.apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op;
}
+ ///
+ /// Add ops to apply sparse gradients to `var`, with repeated sparse indices.
+ ///
+ ///
+ ///
+ ///
+ public virtual Operation _apply_sparse_duplicate_indices(IndexedSlices grad, RefVariable var)
+ {
+ var (summed_values, unique_indices) = _deduplicate_indexed_slices(values: grad.values, indices: grad.indices);
+ var gradient_no_duplicate_indices = new IndexedSlices(
+ indices: unique_indices,
+ values: summed_values,
+ dense_shape: grad.dense_shape);
+ return _apply_sparse(gradient_no_duplicate_indices, var);
+ }
+
+ public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var)
+ {
+ throw new NotImplementedException("_apply_sparse");
+ }
+
+ public virtual (Tensor, Tensor) _deduplicate_indexed_slices(Tensor values, Tensor indices)
+ {
+ var (unique_indices, new_index_positions) = array_ops.unique(indices);
+ var summed_values = math_ops.unsorted_segment_sum(
+ values, new_index_positions,
+ array_ops.shape(unique_indices)[0]);
+ return (summed_values, unique_indices);
+ }
+
public virtual void _prepare()
{
}
+ ///
+ /// Return a slot named `name` created for `var` by the Optimizer.
+ ///
+ ///
+ ///
+ ///
+ protected RefVariable get_slot(RefVariable var, string name)
+ {
+ var named_slots = _slots.ContainsKey(name) ? _slots[name] : null;
+ if (named_slots == null)
+ return null;
+
+ return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null;
+ }
+
+ private string _var_key(RefVariable var)
+ {
+ return $"{var.op.graph.graph_key}.{var.op.name}";
+ }
+
+ protected RefVariable _get_non_slot_variable(string name, Graph graph = null)
+ {
+ var key = $"{graph.graph_key}.{name}";
+ var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null;
+
+ return non_slot;
+ }
+
private _OptimizableVariable _get_processor(RefVariable v)
{
if(v is RefVariable)
diff --git a/src/TensorFlowNET.Core/Train/optimizer.py.cs b/src/TensorFlowNET.Core/Train/optimizer.py.cs
index 3a376e97..fbf32876 100644
--- a/src/TensorFlowNET.Core/Train/optimizer.py.cs
+++ b/src/TensorFlowNET.Core/Train/optimizer.py.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Framework;
namespace Tensorflow
{
@@ -32,5 +33,12 @@ namespace Tensorflow
return update_op;
}
+
+ public Operation update_op(Optimizer optimizer, IndexedSlices g)
+ {
+ var update_op = optimizer._apply_dense(g, _v);
+
+ return update_op;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
index 4b4237a0..a5a4ab69 100644
--- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
+++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
@@ -97,6 +97,20 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking });
return _op.outputs[0];
}
-
+
+ ///
+ /// Adds sparse updates to a variable reference.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking });
+ return _op.outputs[0];
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs
index aaa27e85..4022e1dc 100644
--- a/src/TensorFlowNET.Core/Variables/state_ops.cs
+++ b/src/TensorFlowNET.Core/Variables/state_ops.cs
@@ -72,5 +72,13 @@ namespace Tensorflow
Tensor value,
bool use_locking = false,
string name = null) => gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name);
+
+ public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null)
+ {
+ if (@ref.dtype.is_ref_dtype())
+ return gen_state_ops.scatter_add(@ref, indices, updates, use_locking: use_locking, name: name);
+
+ throw new NotImplementedException("scatter_add");
+ }
}
}
diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
index 60b6d050..465b08b2 100644
--- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
@@ -8,10 +8,7 @@ using System.Text;
using Newtonsoft.Json;
using NumSharp;
using Tensorflow;
-using Tensorflow.Keras.Engine;
using Tensorflow.Sessions;
-using TensorFlowNET.Examples.Text.cnn_models;
-using TensorFlowNET.Examples.TextClassification;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;
diff --git a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
index 4b4623dc..744e52c3 100644
--- a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
+++ b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
@@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.nn_test
public void testZeroFraction()
{
var x_shape = new Shape(5, 17);
- var x_np = new NumPyRandom().randint(0, 2, x_shape);
+ var x_np = np.random.randint(0, 2, x_shape);
x_np.astype(np.float32);
var y_np = this._ZeroFraction(x_np);