Browse Source

Add Tensors class to adapt Tensor and Tensor[].

tags/v0.30
Oceania2018 5 years ago
parent
commit
0abf166437
36 changed files with 282 additions and 176 deletions
  1. +2
    -0
      .gitignore
  2. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.reshape.cs
  3. +4
    -3
      src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs
  4. +2
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Flatten.cs
  7. +29
    -0
      src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs
  8. +10
    -43
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  9. +5
    -1
      src/TensorFlowNET.Core/Keras/Engine/Model.cs
  10. +4
    -3
      src/TensorFlowNET.Core/Keras/Engine/Node.cs
  11. +2
    -2
      src/TensorFlowNET.Core/Keras/Engine/Sequential.cs
  12. +13
    -0
      src/TensorFlowNET.Core/Keras/KerasApi.cs
  13. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  14. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Conv.cs
  15. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Dense.cs
  16. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Dropout.cs
  17. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
  18. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/LSTM.cs
  19. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
  20. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs
  21. +3
    -40
      src/TensorFlowNET.Core/Layers/Layer.cs
  22. +2
    -2
      src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
  23. +3
    -3
      src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
  24. +2
    -2
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  25. +2
    -2
      src/TensorFlowNET.Core/Operations/array_ops.cs
  26. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  27. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  28. +2
    -0
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  29. +70
    -0
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  30. +2
    -0
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  31. +3
    -0
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs
  32. +2
    -2
      src/TensorFlowNET.Core/ops.cs
  33. +37
    -0
      test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs
  34. +34
    -0
      test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs
  35. +35
    -0
      test/TensorFlowNET.UnitTest/NativeAPI/GraphBuildTest.cs
  36. +0
    -59
      test/TensorFlowNET.UnitTest/layers_test/flatten.cs

+ 2
- 0
.gitignore View File

@@ -337,3 +337,5 @@ test/TensorFlowNET.Examples/mnist
# training model resources
.resources
/redist
*.xml
*.xsd

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.reshape.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow
{
public partial class tensorflow
{
public Tensor reshape<T>(T tensor,
public Tensor reshape(Tensor tensor,
TensorShape shape,
string name = null) => gen_array_ops.reshape(tensor, shape, name);



+ 4
- 3
src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs View File

@@ -1,4 +1,5 @@
using System;
using NumSharp;
using System;
using System.Linq;
using static Tensorflow.Binding;

@@ -42,8 +43,8 @@ namespace Tensorflow.Framework
var values_shape = values.TensorShape.with_rank(1);
var dense_shape_shape = dense_shape.TensorShape.with_rank(1);

indices_shape[0].merge_with(values_shape.dims[0]);
indices_shape[1].merge_with(dense_shape_shape.dims[0]);
indices_shape["0"].merge_with(values_shape[0]);
indices_shape["1"].merge_with(dense_shape_shape[0]);

_shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray());
}


+ 2
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs View File

@@ -6,5 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class ModelArgs : LayerArgs
{
public Tensor[] Inputs { get; set; }
public Tensor[] Outputs { get; set; }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs View File

@@ -12,6 +12,6 @@ namespace Tensorflow.Keras.ArgsDefinition
public int[] NodeIndices { get; set; }
public int[] TensorIndices { get; set; }
public Tensor InputTensors { get; set; }
public Tensor Outputs { get; set; }
public Tensors Outputs { get; set; }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Flatten.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
_channels_first = args.DataFormat == "channels_first";
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
if (_channels_first)
{


+ 29
- 0
src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs View File

@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
/// <summary>
/// Tracks the Layer call that created a Tensor, for Keras Graph Networks.
/// </summary>
public class KerasHistory
{
Layer layer;
int node_index;
int tensor_index;

public KerasHistory(Layer layer, int node_index, int tensor_index)
{
this.layer = layer;
this.node_index = node_index;
this.tensor_index = tensor_index;
}

public static implicit operator Layer(KerasHistory history)
=> history.layer;

public static implicit operator (Layer, int, int)(KerasHistory history)
=> (history.layer, history.node_index, history.tensor_index);
}
}

+ 10
- 43
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -119,11 +119,12 @@ namespace Tensorflow.Keras.Engine
/// Wraps `call`, applying pre- and post-processing steps.
/// </summary>
/// <param name="input"></param>
/// <param name="state"></param>
/// <param name="is_training"></param>
/// <returns></returns>
public Tensor Apply(Tensor inputs, bool is_training = false)
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
{
Tensor outputs = null;
Tensors outputs = null;

callContext = callContext ?? new ThreadLocal<CallContext>()
{
@@ -148,7 +149,7 @@ namespace Tensorflow.Keras.Engine
if (!built)
MaybeBuild(inputs);

outputs = call(inputs, is_training: is_training);
outputs = call(inputs, state: state, is_training: is_training);

outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);
@@ -161,36 +162,7 @@ namespace Tensorflow.Keras.Engine
return outputs;
}

public Tensor[] Apply(Tensor[] inputs, Tensor state, bool is_training = false)
{
Tensor[] outputs = null;

callContext = callContext ?? new ThreadLocal<CallContext>()
{
Value = new CallContext()
};

var eager = tf.executing_eagerly();
using var ctxManager = CallContext.enter();

string nameScope = "";
if (eager)
nameScope = name;
else
nameScope = _name_scope();

tf_with(ops.name_scope(nameScope), scope =>
{
if (!built)
MaybeBuild(inputs[0]);

outputs = call(inputs, is_training: is_training, state: state);
});

return outputs;
}

private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs)
{
/*var returnOutputs = new List<Tensor>();
foreach(var x in outputs)
@@ -211,7 +183,7 @@ namespace Tensorflow.Keras.Engine
return outputs;
}

private void _handle_activity_regularization(Tensor inputs, Tensor outputs)
private void _handle_activity_regularization(Tensors inputs, Tensors outputs)
{
//if(_activity_regularizer != null)
{
@@ -219,7 +191,7 @@ namespace Tensorflow.Keras.Engine
}
}

private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask)
private void _set_mask_metadata(Tensors inputs, Tensors outputs, Tensors previous_mask)
{

}
@@ -229,12 +201,7 @@ namespace Tensorflow.Keras.Engine
return null;
}

protected virtual Tensor call(Tensor inputs, bool is_training = false)
{
throw new NotImplementedException("");
}

protected virtual Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false)
protected virtual Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
throw new NotImplementedException("");
}
@@ -244,7 +211,7 @@ namespace Tensorflow.Keras.Engine
return Name;
}

protected void MaybeBuild(Tensor inputs)
protected void MaybeBuild(Tensors inputs)
{
// Check input assumptions set before layer building, e.g. input rank.
if (built)
@@ -252,7 +219,7 @@ namespace Tensorflow.Keras.Engine
if (DType == TF_DataType.DtInvalid)
args.DType = inputs.dtype;

var input_shapes = inputs.TensorShape;
var input_shapes = inputs.shape;
build(input_shapes);
built = true;
}


+ 5
- 1
src/TensorFlowNET.Core/Keras/Engine/Model.cs View File

@@ -27,7 +27,11 @@ namespace Tensorflow.Keras.Engine
public Model(ModelArgs args)
: base(args)
{

// Build _output_layers
/*foreach(var x in args.Outputs)
{
var layer = x.KerasHistory;
}*/
}

public void compile(string optimizerName, string lossName)


+ 4
- 3
src/TensorFlowNET.Core/Keras/Engine/Node.cs View File

@@ -35,8 +35,8 @@ namespace Tensorflow.Keras.Engine

public int[] node_indices;
public int[] tensor_indices;
public Tensor input_tensors;
public Tensor Outputs => args.Outputs;
public Tensors input_tensors;
public Tensors Outputs => args.Outputs;
public TensorShape[] input_shapes;
public TensorShape[] output_shapes;
List<Layer> kerasInputs;
@@ -57,7 +57,8 @@ namespace Tensorflow.Keras.Engine

// Set metadata on outputs.
var node_index = layer.InboundNodes.Count - 1;
args.Outputs.KerasHistory.Add(layer);
foreach (var (i, tensor) in enumerate(Outputs))
tensor.KerasHistory = new KerasHistory(layer, node_index, i);
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Keras/Engine/Sequential.cs View File

@@ -60,7 +60,7 @@ namespace Tensorflow.Keras.Engine

public void add(Tensor tensor)
{
var layer = tensor.KerasHistory[0];
Layer layer = tensor.KerasHistory;
add(layer);
}

@@ -129,7 +129,7 @@ namespace Tensorflow.Keras.Engine

void _map_graph_network(Tensor inputs, Tensor outputs)
{
layers.add(outputs.KerasHistory[0]);
layers.add(outputs.KerasHistory);
}
}
}

+ 13
- 0
src/TensorFlowNET.Core/Keras/KerasApi.cs View File

@@ -30,6 +30,19 @@ namespace Tensorflow
Name = name
});

/// <summary>
/// `Model` groups layers into an object with training and inference features.
/// </summary>
/// <param name="input"></param>
/// <param name="output"></param>
/// <returns></returns>
public Model Model(Tensor input, Tensor output)
=> new Model(new ModelArgs
{
Inputs = new[] { input },
Outputs = new[] { output }
});

/// <summary>
/// Instantiate a Keras tensor.
/// </summary>


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -143,7 +143,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
Tensor outputs = null;



+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Conv.cs View File

@@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensor call(Tensor inputs, bool training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false)
{
var outputs = _convolution_op.__call__(inputs, kernel);
if (use_bias)


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Dense.cs View File

@@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensor call(Tensor inputs, bool training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false)
{
Tensor outputs = null;
var rank = inputs.rank;


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Dropout.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
var output = tf_utils.smart_cond(is_training,
() => tf.nn.dropout(inputs,


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Embedding.cs View File

@@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64)


+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/LSTM.cs View File

@@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers
.ToArray();
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
return base.call(inputs, is_training);
return base.call(inputs, state: state, is_training: is_training);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs View File

@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers
input_spec = new InputSpec(ndim: 4);
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
int[] pool_shape;
int[] strides;


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs View File

@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
scale = math_ops.cast(args.Scale, args.DType);
offset = math_ops.cast(args.Offset, args.DType);


+ 3
- 40
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -61,44 +61,7 @@ namespace Tensorflow.Layers
return (results[0], results[1]);
}

public Tensor __call__(Tensor inputs,
Tensor training = null,
VariableScope scope = null)
{
_set_scope(scope);
_graph = ops._get_graph_from_inputs(new Tensor[] { inputs }, graph: _graph);

variable_scope scope_context_manager = null;
if (built)
{
scope_context_manager = tf.variable_scope(_scope,
reuse: true,
auxiliary_name_scope: false);
}
else
{
scope_context_manager = tf.variable_scope(_scope,
reuse: _reuse,
auxiliary_name_scope: false);
}

Tensor outputs = null;
tf_with(scope_context_manager, scope2 =>
{
_current_scope = scope2;
// Actually call layer
outputs = base.Apply(inputs[0],
is_training: training == null ? false : false);
});


// Update global default collections.
_add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS });

return outputs;
}

public Tensor[] __call__(Tensor[] inputs,
public Tensors __call__(Tensors inputs,
Tensor state = null,
Tensor training = null,
VariableScope scope = null)
@@ -120,13 +83,13 @@ namespace Tensorflow.Layers
auxiliary_name_scope: false);
}

Tensor[] outputs = null;
Tensors outputs = null;
tf_with(scope_context_manager, scope2 =>
{
_current_scope = scope2;
// Actually call layer
outputs = base.Apply(inputs,
state,
state: state,
is_training: training == null ? false : false);
});



+ 2
- 2
src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs View File

@@ -74,7 +74,7 @@ namespace Tensorflow
/// <param name="training"></param>
/// <param name="state"></param>
/// <returns></returns>
protected override Tensor call(Tensor inputs, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
var one = constant_op.constant(1, dtype: dtypes.int32);
// Parameters of gates are concatenated into one multiply for efficiency.
@@ -87,7 +87,7 @@ namespace Tensorflow
// array_ops.split(value: state, num_or_size_splits: 2, axis: one);
throw new NotImplementedException("BasicLstmCell call");
}
var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel.AsTensor());
var gate_inputs = math_ops.matmul(array_ops.concat(new[] { (Tensor)inputs, h }, 1), _kernel.AsTensor());
gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor());

// i = input_gate, j = new_input, f = forget_gate, o = output_gate


+ 3
- 3
src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs View File

@@ -67,14 +67,14 @@ namespace Tensorflow
built = true;
}

protected override Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false)
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
// Most basic RNN: output = new_state = act(W * input + U * state + B).
var concat = array_ops.concat(new[] { inputs[0], state }, 1);
var concat = array_ops.concat(new Tensor[] { inputs, state }, 1);
var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor());
gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor());
var output = _activation(gate_inputs, null);
return new[] { output, output };
return new Tensors(output, output);
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -127,7 +127,7 @@ namespace Tensorflow.Operations
{
input_shape = flat_input.TensorShape.with_rank_at_least(2);
batch_size = tensor_shape.dimension_at_index(input_shape, 0);
var input_size = input_shape[1];
var input_size = input_shape[new Slice(1)];
fixed_batch_size.merge_with(batch_size);
foreach (var (i, size) in enumerate(input_size.dims))
{
@@ -364,7 +364,7 @@ namespace Tensorflow.Operations
if (sequence_length != null)
throw new NotImplementedException("sequence_length != null");
else
outputs = cell.__call__(new[] { input_t_t }, state: state1);
outputs = cell.__call__(input_t_t, state: state1);

var (output, new_state) = (outputs[0], outputs[1]);
// Keras cells always wrap state as list, even if it's a single tensor.


+ 2
- 2
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -157,7 +157,7 @@ namespace Tensorflow
leading_size,
shape(tensor_tensor)[$"{axis + ndims_mask}:"]
}, 0);
tensor_tensor = reshape(tensor, shape1);
tensor_tensor = reshape(tensor_tensor, shape1);
var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First();
var s1 = tensor_shape.as_shape(shape_tensor.dims.Take(axis).ToArray());
var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray());
@@ -353,7 +353,7 @@ namespace Tensorflow
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> ones_like_impl(tensor, dtype, name, optimize);

public static Tensor reshape<T1, T2>(T1 tensor, T2 shape, string name = null)
public static Tensor reshape<T2>(Tensor tensor, T2 shape, string name = null)
=> gen_array_ops.reshape(tensor, shape, null);

private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true)


+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -292,7 +292,7 @@ namespace Tensorflow
return _op.output;
}

public static Tensor reshape<T1, T2>(T1 tensor, T2 shape, string name = null)
public static Tensor reshape<T>(Tensor tensor, T shape, string name = null)
{
if (tf.Context.executing_eagerly())
{


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -144,7 +144,7 @@ namespace Tensorflow
/// <summary>
/// Keras History: (Layer, (node_index, tensor_index))
/// </summary>
public List<Layer> KerasHistory = new List<Layer>();
public KerasHistory KerasHistory { get; set; }

/// <summary>
/// Updates the shape of this tensor.


+ 2
- 0
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -132,6 +132,8 @@ namespace Tensorflow
}
}

public int this[int index] => dims[index];

/// <summary>
/// Returns True iff `self` is fully defined in every dimension.
/// </summary>


+ 70
- 0
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -0,0 +1,70 @@
using NumSharp;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Gradients;

namespace Tensorflow
{
/// <summary>
/// Tensors is used to represent a Tensor or a array of Tensor.
/// It will simplify the API interface, it converts Tensor
/// and Tensor[] to Tensors implicitily. And parse back to Tensor
/// and Tensor[] from Tensors implicitily.
/// It works for tuple and scalar as well.
/// </summary>
public class Tensors : IEnumerable<Tensor>
{
Tensor[] items;

public TF_DataType dtype => items.First().dtype;
public TensorShape shape => items.First().TensorShape;
public int rank => items.First().rank;
public bool IsEagerTensor => items.First().IsEagerTensor;

public Tensor this[int index] => items[index];

public Tensors(params Tensor[] tensors)
{
items = tensors;
}

public Tensors(NDArray nd)
{
items = new[] { ops.convert_to_tensor(nd) };
}

public IEnumerator<Tensor> GetEnumerator()
{
foreach (var tensor in items)
yield return tensor;
}

IEnumerator IEnumerable.GetEnumerator()
{
throw new NotImplementedException();
}

public static implicit operator Tensors(Tensor tensor)
=> new Tensors(tensor);

public static implicit operator Tensors(NDArray nd)
=> new Tensors(nd);

public static implicit operator Tensors(Tensor[] tensors)
=> new Tensors(tensors);

public static implicit operator Tensor(Tensors tensors)
=> tensors.FirstOrDefault();

public static implicit operator Tensor[](Tensors tensors)
=> tensors.items;

public override string ToString()
=> items.Length == 1
? items.First().ToString()
: items.Length + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name));
}
}

+ 2
- 0
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -155,6 +155,8 @@ namespace Tensorflow
return val;
case NDArray val:
return new EagerTensor(val, ctx.DeviceName);
//case TensorShape val:
//return new EagerTensor(val.dims, ctx.DeviceName);
case string val:
return new EagerTensor(val, ctx.DeviceName);
case string[] val:


+ 3
- 0
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -16,6 +16,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

namespace Tensorflow
@@ -280,6 +281,7 @@ namespace Tensorflow
return scope._scope;
}

[DebuggerHidden]
public void __exit__()
{
_cached_pure_variable_scope.__exit__();
@@ -287,6 +289,7 @@ namespace Tensorflow
_current_name_scope.__exit__();
}

[DebuggerHidden]
public void Dispose()
{
if (_current_name_scope != null)


+ 2
- 2
src/TensorFlowNET.Core/ops.cs View File

@@ -76,10 +76,10 @@ namespace Tensorflow
return get_default_graph().get_collection_ref<T>(key);
}

public static Graph _get_graph_from_inputs(params Tensor[] op_input_list)
public static Graph _get_graph_from_inputs(Tensors op_input_list)
=> _get_graph_from_inputs(op_input_list: op_input_list, graph: null);

public static Graph _get_graph_from_inputs(Tensor[] op_input_list, Graph graph = null)
public static Graph _get_graph_from_inputs(Tensors op_input_list, Graph graph = null)
{
foreach(var op_input in op_input_list)
{


+ 37
- 0
test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs View File

@@ -0,0 +1,37 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using NumSharp;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Keras
{
/// <summary>
/// https://www.tensorflow.org/guide/keras/save_and_serialize
/// </summary>
[TestClass]
public class ModelSaveTest : EagerModeTestBase
{
[TestMethod]
public void SaveAndLoadTest()
{
var model = GetModel();
}

Model GetModel()
{
var keras = tf.keras;

// Create a simple model.
var inputs = keras.Input(shape: 32);
var outputs = keras.layers.Dense(1).Apply(inputs);
var model = keras.Model(inputs, outputs);
model.compile("adam", "mean_squared_error");
return model;
}
}
}

+ 34
- 0
test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs View File

@@ -12,6 +12,40 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
[TestClass]
public class FunctionApiTest : TFNetApiTest
{
Tensor Min(Tensor a, Tensor b)
{
return tf.cond(a < b, () => a, () => b);
}

[TestMethod]
public void MulInAutoGraph()
{
var a = tf.constant(1);
var b = tf.constant(2);
// For first time running, tf.net will record the operations in graph mode.
// And register to tensorflow op library.
var output = Mul(a, b);
Assert.AreEqual(2, (int)output);

var c = tf.constant(3);
// for the following invoke, Mul will be intercepted and run it in eager mode.
output = Mul(b, c);
Assert.AreEqual(6, (int)output);
}

/// <summary>
/// Method with AutoGraph attribute will be converted to FuncGraph
/// when it's invoked for the first time.
/// </summary>
/// <param name="a"></param>
/// <param name="b"></param>
/// <returns></returns>
[AutoGraph]
Tensor Mul(Tensor a, Tensor b)
{
return a * b;
}

[TestMethod]
public void TwoInputs_OneOutput()
{


+ 35
- 0
test/TensorFlowNET.UnitTest/NativeAPI/GraphBuildTest.cs View File

@@ -0,0 +1,35 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NativeAPI
{
[TestClass]
public class GraphBuildTest : CApiTest
{
[TestMethod, Ignore("Waiting to merge https://github.com/tensorflow/tensorflow/pull/43383")]
public void UpdateEdge()
{
using var graph = new Graph().as_default();

var one = tf.constant(1, name: "one");
var two = tf.constant(2, name: "two");
var add = tf.add(one, two, name: "add");
var neg = tf.negative(add, name: "neg");

Assert.AreEqual(1, one.consumers().Length);
Assert.AreEqual("add", neg.op.node_def.Input[0]);
// update edge
neg.op._update_input(0, one);
// c_api.TF_UpdateEdge(graph, new TF_Output(c1.op, 0), new TF_Input(neg.op, 0), tf.Status.Handle);

Assert.AreEqual(2, one.consumers().Length);
Assert.AreEqual("one:0", neg.op.node_def.Input[0]);
}
}
}

+ 0
- 59
test/TensorFlowNET.UnitTest/layers_test/flatten.cs View File

@@ -1,59 +0,0 @@
using System;
using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.layers_test
{
[TestClass]
public class flatten : GraphModeTestBase
{
[TestMethod]
public void Case1()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, 3, 1, 2));
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
}

[TestMethod]
public void Case2()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6));
sess.run(tf.layers.flatten(input), (input, np.arange(6))).Should().BeShaped(6, 1);
}

[TestMethod]
public void Case3()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape());
new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw<ValueError>();
}

[TestMethod]
public void Case4()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2));
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
}

[TestMethod]
public void Case5()
{
var sess = tf.Session().as_default();

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2));
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
}
}
}

Loading…
Cancel
Save