@@ -337,3 +337,5 @@ test/TensorFlowNET.Examples/mnist | |||
# training model resources | |||
.resources | |||
/redist | |||
*.xml | |||
*.xsd |
@@ -18,7 +18,7 @@ namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
public Tensor reshape<T>(T tensor, | |||
public Tensor reshape(Tensor tensor, | |||
TensorShape shape, | |||
string name = null) => gen_array_ops.reshape(tensor, shape, name); | |||
@@ -1,4 +1,5 @@ | |||
using System; | |||
using NumSharp; | |||
using System; | |||
using System.Linq; | |||
using static Tensorflow.Binding; | |||
@@ -42,8 +43,8 @@ namespace Tensorflow.Framework | |||
var values_shape = values.TensorShape.with_rank(1); | |||
var dense_shape_shape = dense_shape.TensorShape.with_rank(1); | |||
indices_shape[0].merge_with(values_shape.dims[0]); | |||
indices_shape[1].merge_with(dense_shape_shape.dims[0]); | |||
indices_shape["0"].merge_with(values_shape[0]); | |||
indices_shape["1"].merge_with(dense_shape_shape[0]); | |||
_shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray()); | |||
} | |||
@@ -6,5 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class ModelArgs : LayerArgs | |||
{ | |||
public Tensor[] Inputs { get; set; } | |||
public Tensor[] Outputs { get; set; } | |||
} | |||
} |
@@ -12,6 +12,6 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
public int[] NodeIndices { get; set; } | |||
public int[] TensorIndices { get; set; } | |||
public Tensor InputTensors { get; set; } | |||
public Tensor Outputs { get; set; } | |||
public Tensors Outputs { get; set; } | |||
} | |||
} |
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine | |||
_channels_first = args.DataFormat == "channels_first"; | |||
} | |||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
if (_channels_first) | |||
{ | |||
@@ -0,0 +1,29 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
/// <summary> | |||
/// Tracks the Layer call that created a Tensor, for Keras Graph Networks. | |||
/// </summary> | |||
public class KerasHistory | |||
{ | |||
Layer layer; | |||
int node_index; | |||
int tensor_index; | |||
public KerasHistory(Layer layer, int node_index, int tensor_index) | |||
{ | |||
this.layer = layer; | |||
this.node_index = node_index; | |||
this.tensor_index = tensor_index; | |||
} | |||
public static implicit operator Layer(KerasHistory history) | |||
=> history.layer; | |||
public static implicit operator (Layer, int, int)(KerasHistory history) | |||
=> (history.layer, history.node_index, history.tensor_index); | |||
} | |||
} |
@@ -119,11 +119,12 @@ namespace Tensorflow.Keras.Engine | |||
/// Wraps `call`, applying pre- and post-processing steps. | |||
/// </summary> | |||
/// <param name="input"></param> | |||
/// <param name="state"></param> | |||
/// <param name="is_training"></param> | |||
/// <returns></returns> | |||
public Tensor Apply(Tensor inputs, bool is_training = false) | |||
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
Tensor outputs = null; | |||
Tensors outputs = null; | |||
callContext = callContext ?? new ThreadLocal<CallContext>() | |||
{ | |||
@@ -148,7 +149,7 @@ namespace Tensorflow.Keras.Engine | |||
if (!built) | |||
MaybeBuild(inputs); | |||
outputs = call(inputs, is_training: is_training); | |||
outputs = call(inputs, state: state, is_training: is_training); | |||
outputs = _set_connectivity_metadata_(inputs, outputs); | |||
_handle_activity_regularization(inputs, outputs); | |||
@@ -161,36 +162,7 @@ namespace Tensorflow.Keras.Engine | |||
return outputs; | |||
} | |||
public Tensor[] Apply(Tensor[] inputs, Tensor state, bool is_training = false) | |||
{ | |||
Tensor[] outputs = null; | |||
callContext = callContext ?? new ThreadLocal<CallContext>() | |||
{ | |||
Value = new CallContext() | |||
}; | |||
var eager = tf.executing_eagerly(); | |||
using var ctxManager = CallContext.enter(); | |||
string nameScope = ""; | |||
if (eager) | |||
nameScope = name; | |||
else | |||
nameScope = _name_scope(); | |||
tf_with(ops.name_scope(nameScope), scope => | |||
{ | |||
if (!built) | |||
MaybeBuild(inputs[0]); | |||
outputs = call(inputs, is_training: is_training, state: state); | |||
}); | |||
return outputs; | |||
} | |||
private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs) | |||
private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs) | |||
{ | |||
/*var returnOutputs = new List<Tensor>(); | |||
foreach(var x in outputs) | |||
@@ -211,7 +183,7 @@ namespace Tensorflow.Keras.Engine | |||
return outputs; | |||
} | |||
private void _handle_activity_regularization(Tensor inputs, Tensor outputs) | |||
private void _handle_activity_regularization(Tensors inputs, Tensors outputs) | |||
{ | |||
//if(_activity_regularizer != null) | |||
{ | |||
@@ -219,7 +191,7 @@ namespace Tensorflow.Keras.Engine | |||
} | |||
} | |||
private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask) | |||
private void _set_mask_metadata(Tensors inputs, Tensors outputs, Tensors previous_mask) | |||
{ | |||
} | |||
@@ -229,12 +201,7 @@ namespace Tensorflow.Keras.Engine | |||
return null; | |||
} | |||
protected virtual Tensor call(Tensor inputs, bool is_training = false) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
protected virtual Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false) | |||
protected virtual Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
@@ -244,7 +211,7 @@ namespace Tensorflow.Keras.Engine | |||
return Name; | |||
} | |||
protected void MaybeBuild(Tensor inputs) | |||
protected void MaybeBuild(Tensors inputs) | |||
{ | |||
// Check input assumptions set before layer building, e.g. input rank. | |||
if (built) | |||
@@ -252,7 +219,7 @@ namespace Tensorflow.Keras.Engine | |||
if (DType == TF_DataType.DtInvalid) | |||
args.DType = inputs.dtype; | |||
var input_shapes = inputs.TensorShape; | |||
var input_shapes = inputs.shape; | |||
build(input_shapes); | |||
built = true; | |||
} | |||
@@ -27,7 +27,11 @@ namespace Tensorflow.Keras.Engine | |||
public Model(ModelArgs args) | |||
: base(args) | |||
{ | |||
// Build _output_layers | |||
/*foreach(var x in args.Outputs) | |||
{ | |||
var layer = x.KerasHistory; | |||
}*/ | |||
} | |||
public void compile(string optimizerName, string lossName) | |||
@@ -35,8 +35,8 @@ namespace Tensorflow.Keras.Engine | |||
public int[] node_indices; | |||
public int[] tensor_indices; | |||
public Tensor input_tensors; | |||
public Tensor Outputs => args.Outputs; | |||
public Tensors input_tensors; | |||
public Tensors Outputs => args.Outputs; | |||
public TensorShape[] input_shapes; | |||
public TensorShape[] output_shapes; | |||
List<Layer> kerasInputs; | |||
@@ -57,7 +57,8 @@ namespace Tensorflow.Keras.Engine | |||
// Set metadata on outputs. | |||
var node_index = layer.InboundNodes.Count - 1; | |||
args.Outputs.KerasHistory.Add(layer); | |||
foreach (var (i, tensor) in enumerate(Outputs)) | |||
tensor.KerasHistory = new KerasHistory(layer, node_index, i); | |||
} | |||
} | |||
} |
@@ -60,7 +60,7 @@ namespace Tensorflow.Keras.Engine | |||
public void add(Tensor tensor) | |||
{ | |||
var layer = tensor.KerasHistory[0]; | |||
Layer layer = tensor.KerasHistory; | |||
add(layer); | |||
} | |||
@@ -129,7 +129,7 @@ namespace Tensorflow.Keras.Engine | |||
void _map_graph_network(Tensor inputs, Tensor outputs) | |||
{ | |||
layers.add(outputs.KerasHistory[0]); | |||
layers.add(outputs.KerasHistory); | |||
} | |||
} | |||
} |
@@ -30,6 +30,19 @@ namespace Tensorflow | |||
Name = name | |||
}); | |||
/// <summary> | |||
/// `Model` groups layers into an object with training and inference features. | |||
/// </summary> | |||
/// <param name="input"></param> | |||
/// <param name="output"></param> | |||
/// <returns></returns> | |||
public Model Model(Tensor input, Tensor output) | |||
=> new Model(new ModelArgs | |||
{ | |||
Inputs = new[] { input }, | |||
Outputs = new[] { output } | |||
}); | |||
/// <summary> | |||
/// Instantiate a Keras tensor. | |||
/// </summary> | |||
@@ -143,7 +143,7 @@ namespace Tensorflow.Keras.Layers | |||
built = true; | |||
} | |||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
Tensor outputs = null; | |||
@@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers | |||
built = true; | |||
} | |||
protected override Tensor call(Tensor inputs, bool training = false) | |||
protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false) | |||
{ | |||
var outputs = _convolution_op.__call__(inputs, kernel); | |||
if (use_bias) | |||
@@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers | |||
built = true; | |||
} | |||
protected override Tensor call(Tensor inputs, bool training = false) | |||
protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false) | |||
{ | |||
Tensor outputs = null; | |||
var rank = inputs.rank; | |||
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers | |||
this.args = args; | |||
} | |||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
var output = tf_utils.smart_cond(is_training, | |||
() => tf.nn.dropout(inputs, | |||
@@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers | |||
built = true; | |||
} | |||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
var dtype = inputs.dtype; | |||
if (dtype != tf.int32 && dtype != tf.int64) | |||
@@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers | |||
.ToArray(); | |||
} | |||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
return base.call(inputs, is_training); | |||
return base.call(inputs, state: state, is_training: is_training); | |||
} | |||
} | |||
} |
@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers | |||
input_spec = new InputSpec(ndim: 4); | |||
} | |||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
int[] pool_shape; | |||
int[] strides; | |||
@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers | |||
this.args = args; | |||
} | |||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
scale = math_ops.cast(args.Scale, args.DType); | |||
offset = math_ops.cast(args.Offset, args.DType); | |||
@@ -61,44 +61,7 @@ namespace Tensorflow.Layers | |||
return (results[0], results[1]); | |||
} | |||
public Tensor __call__(Tensor inputs, | |||
Tensor training = null, | |||
VariableScope scope = null) | |||
{ | |||
_set_scope(scope); | |||
_graph = ops._get_graph_from_inputs(new Tensor[] { inputs }, graph: _graph); | |||
variable_scope scope_context_manager = null; | |||
if (built) | |||
{ | |||
scope_context_manager = tf.variable_scope(_scope, | |||
reuse: true, | |||
auxiliary_name_scope: false); | |||
} | |||
else | |||
{ | |||
scope_context_manager = tf.variable_scope(_scope, | |||
reuse: _reuse, | |||
auxiliary_name_scope: false); | |||
} | |||
Tensor outputs = null; | |||
tf_with(scope_context_manager, scope2 => | |||
{ | |||
_current_scope = scope2; | |||
// Actually call layer | |||
outputs = base.Apply(inputs[0], | |||
is_training: training == null ? false : false); | |||
}); | |||
// Update global default collections. | |||
_add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); | |||
return outputs; | |||
} | |||
public Tensor[] __call__(Tensor[] inputs, | |||
public Tensors __call__(Tensors inputs, | |||
Tensor state = null, | |||
Tensor training = null, | |||
VariableScope scope = null) | |||
@@ -120,13 +83,13 @@ namespace Tensorflow.Layers | |||
auxiliary_name_scope: false); | |||
} | |||
Tensor[] outputs = null; | |||
Tensors outputs = null; | |||
tf_with(scope_context_manager, scope2 => | |||
{ | |||
_current_scope = scope2; | |||
// Actually call layer | |||
outputs = base.Apply(inputs, | |||
state, | |||
state: state, | |||
is_training: training == null ? false : false); | |||
}); | |||
@@ -74,7 +74,7 @@ namespace Tensorflow | |||
/// <param name="training"></param> | |||
/// <param name="state"></param> | |||
/// <returns></returns> | |||
protected override Tensor call(Tensor inputs, bool is_training = false) | |||
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
var one = constant_op.constant(1, dtype: dtypes.int32); | |||
// Parameters of gates are concatenated into one multiply for efficiency. | |||
@@ -87,7 +87,7 @@ namespace Tensorflow | |||
// array_ops.split(value: state, num_or_size_splits: 2, axis: one); | |||
throw new NotImplementedException("BasicLstmCell call"); | |||
} | |||
var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel.AsTensor()); | |||
var gate_inputs = math_ops.matmul(array_ops.concat(new[] { (Tensor)inputs, h }, 1), _kernel.AsTensor()); | |||
gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); | |||
// i = input_gate, j = new_input, f = forget_gate, o = output_gate | |||
@@ -67,14 +67,14 @@ namespace Tensorflow | |||
built = true; | |||
} | |||
protected override Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false) | |||
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
// Most basic RNN: output = new_state = act(W * input + U * state + B). | |||
var concat = array_ops.concat(new[] { inputs[0], state }, 1); | |||
var concat = array_ops.concat(new Tensor[] { inputs, state }, 1); | |||
var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor()); | |||
gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); | |||
var output = _activation(gate_inputs, null); | |||
return new[] { output, output }; | |||
return new Tensors(output, output); | |||
} | |||
} | |||
} |
@@ -127,7 +127,7 @@ namespace Tensorflow.Operations | |||
{ | |||
input_shape = flat_input.TensorShape.with_rank_at_least(2); | |||
batch_size = tensor_shape.dimension_at_index(input_shape, 0); | |||
var input_size = input_shape[1]; | |||
var input_size = input_shape[new Slice(1)]; | |||
fixed_batch_size.merge_with(batch_size); | |||
foreach (var (i, size) in enumerate(input_size.dims)) | |||
{ | |||
@@ -364,7 +364,7 @@ namespace Tensorflow.Operations | |||
if (sequence_length != null) | |||
throw new NotImplementedException("sequence_length != null"); | |||
else | |||
outputs = cell.__call__(new[] { input_t_t }, state: state1); | |||
outputs = cell.__call__(input_t_t, state: state1); | |||
var (output, new_state) = (outputs[0], outputs[1]); | |||
// Keras cells always wrap state as list, even if it's a single tensor. | |||
@@ -157,7 +157,7 @@ namespace Tensorflow | |||
leading_size, | |||
shape(tensor_tensor)[$"{axis + ndims_mask}:"] | |||
}, 0); | |||
tensor_tensor = reshape(tensor, shape1); | |||
tensor_tensor = reshape(tensor_tensor, shape1); | |||
var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First(); | |||
var s1 = tensor_shape.as_shape(shape_tensor.dims.Take(axis).ToArray()); | |||
var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray()); | |||
@@ -353,7 +353,7 @@ namespace Tensorflow | |||
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | |||
=> ones_like_impl(tensor, dtype, name, optimize); | |||
public static Tensor reshape<T1, T2>(T1 tensor, T2 shape, string name = null) | |||
public static Tensor reshape<T2>(Tensor tensor, T2 shape, string name = null) | |||
=> gen_array_ops.reshape(tensor, shape, null); | |||
private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | |||
@@ -292,7 +292,7 @@ namespace Tensorflow | |||
return _op.output; | |||
} | |||
public static Tensor reshape<T1, T2>(T1 tensor, T2 shape, string name = null) | |||
public static Tensor reshape<T>(Tensor tensor, T shape, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
@@ -144,7 +144,7 @@ namespace Tensorflow | |||
/// <summary> | |||
/// Keras History: (Layer, (node_index, tensor_index)) | |||
/// </summary> | |||
public List<Layer> KerasHistory = new List<Layer>(); | |||
public KerasHistory KerasHistory { get; set; } | |||
/// <summary> | |||
/// Updates the shape of this tensor. | |||
@@ -132,6 +132,8 @@ namespace Tensorflow | |||
} | |||
} | |||
public int this[int index] => dims[index]; | |||
/// <summary> | |||
/// Returns True iff `self` is fully defined in every dimension. | |||
/// </summary> | |||
@@ -0,0 +1,70 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Gradients; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Tensors is used to represent a Tensor or a array of Tensor. | |||
/// It will simplify the API interface, it converts Tensor | |||
/// and Tensor[] to Tensors implicitily. And parse back to Tensor | |||
/// and Tensor[] from Tensors implicitily. | |||
/// It works for tuple and scalar as well. | |||
/// </summary> | |||
public class Tensors : IEnumerable<Tensor> | |||
{ | |||
Tensor[] items; | |||
public TF_DataType dtype => items.First().dtype; | |||
public TensorShape shape => items.First().TensorShape; | |||
public int rank => items.First().rank; | |||
public bool IsEagerTensor => items.First().IsEagerTensor; | |||
public Tensor this[int index] => items[index]; | |||
public Tensors(params Tensor[] tensors) | |||
{ | |||
items = tensors; | |||
} | |||
public Tensors(NDArray nd) | |||
{ | |||
items = new[] { ops.convert_to_tensor(nd) }; | |||
} | |||
public IEnumerator<Tensor> GetEnumerator() | |||
{ | |||
foreach (var tensor in items) | |||
yield return tensor; | |||
} | |||
IEnumerator IEnumerable.GetEnumerator() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public static implicit operator Tensors(Tensor tensor) | |||
=> new Tensors(tensor); | |||
public static implicit operator Tensors(NDArray nd) | |||
=> new Tensors(nd); | |||
public static implicit operator Tensors(Tensor[] tensors) | |||
=> new Tensors(tensors); | |||
public static implicit operator Tensor(Tensors tensors) | |||
=> tensors.FirstOrDefault(); | |||
public static implicit operator Tensor[](Tensors tensors) | |||
=> tensors.items; | |||
public override string ToString() | |||
=> items.Length == 1 | |||
? items.First().ToString() | |||
: items.Length + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name)); | |||
} | |||
} |
@@ -155,6 +155,8 @@ namespace Tensorflow | |||
return val; | |||
case NDArray val: | |||
return new EagerTensor(val, ctx.DeviceName); | |||
//case TensorShape val: | |||
//return new EagerTensor(val.dims, ctx.DeviceName); | |||
case string val: | |||
return new EagerTensor(val, ctx.DeviceName); | |||
case string[] val: | |||
@@ -16,6 +16,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
namespace Tensorflow | |||
@@ -280,6 +281,7 @@ namespace Tensorflow | |||
return scope._scope; | |||
} | |||
[DebuggerHidden] | |||
public void __exit__() | |||
{ | |||
_cached_pure_variable_scope.__exit__(); | |||
@@ -287,6 +289,7 @@ namespace Tensorflow | |||
_current_name_scope.__exit__(); | |||
} | |||
[DebuggerHidden] | |||
public void Dispose() | |||
{ | |||
if (_current_name_scope != null) | |||
@@ -76,10 +76,10 @@ namespace Tensorflow | |||
return get_default_graph().get_collection_ref<T>(key); | |||
} | |||
public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) | |||
public static Graph _get_graph_from_inputs(Tensors op_input_list) | |||
=> _get_graph_from_inputs(op_input_list: op_input_list, graph: null); | |||
public static Graph _get_graph_from_inputs(Tensor[] op_input_list, Graph graph = null) | |||
public static Graph _get_graph_from_inputs(Tensors op_input_list, Graph graph = null) | |||
{ | |||
foreach(var op_input in op_input_list) | |||
{ | |||
@@ -0,0 +1,37 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Layers; | |||
using NumSharp; | |||
using Tensorflow.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.Keras | |||
{ | |||
/// <summary> | |||
/// https://www.tensorflow.org/guide/keras/save_and_serialize | |||
/// </summary> | |||
[TestClass] | |||
public class ModelSaveTest : EagerModeTestBase | |||
{ | |||
[TestMethod] | |||
public void SaveAndLoadTest() | |||
{ | |||
var model = GetModel(); | |||
} | |||
Model GetModel() | |||
{ | |||
var keras = tf.keras; | |||
// Create a simple model. | |||
var inputs = keras.Input(shape: 32); | |||
var outputs = keras.layers.Dense(1).Apply(inputs); | |||
var model = keras.Model(inputs, outputs); | |||
model.compile("adam", "mean_squared_error"); | |||
return model; | |||
} | |||
} | |||
} |
@@ -12,6 +12,40 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
[TestClass] | |||
public class FunctionApiTest : TFNetApiTest | |||
{ | |||
Tensor Min(Tensor a, Tensor b) | |||
{ | |||
return tf.cond(a < b, () => a, () => b); | |||
} | |||
[TestMethod] | |||
public void MulInAutoGraph() | |||
{ | |||
var a = tf.constant(1); | |||
var b = tf.constant(2); | |||
// For first time running, tf.net will record the operations in graph mode. | |||
// And register to tensorflow op library. | |||
var output = Mul(a, b); | |||
Assert.AreEqual(2, (int)output); | |||
var c = tf.constant(3); | |||
// for the following invoke, Mul will be intercepted and run it in eager mode. | |||
output = Mul(b, c); | |||
Assert.AreEqual(6, (int)output); | |||
} | |||
/// <summary> | |||
/// Method with AutoGraph attribute will be converted to FuncGraph | |||
/// when it's invoked for the first time. | |||
/// </summary> | |||
/// <param name="a"></param> | |||
/// <param name="b"></param> | |||
/// <returns></returns> | |||
[AutoGraph] | |||
Tensor Mul(Tensor a, Tensor b) | |||
{ | |||
return a * b; | |||
} | |||
[TestMethod] | |||
public void TwoInputs_OneOutput() | |||
{ | |||
@@ -0,0 +1,35 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using System.Linq; | |||
using Tensorflow; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.NativeAPI | |||
{ | |||
[TestClass] | |||
public class GraphBuildTest : CApiTest | |||
{ | |||
[TestMethod, Ignore("Waiting to merge https://github.com/tensorflow/tensorflow/pull/43383")] | |||
public void UpdateEdge() | |||
{ | |||
using var graph = new Graph().as_default(); | |||
var one = tf.constant(1, name: "one"); | |||
var two = tf.constant(2, name: "two"); | |||
var add = tf.add(one, two, name: "add"); | |||
var neg = tf.negative(add, name: "neg"); | |||
Assert.AreEqual(1, one.consumers().Length); | |||
Assert.AreEqual("add", neg.op.node_def.Input[0]); | |||
// update edge | |||
neg.op._update_input(0, one); | |||
// c_api.TF_UpdateEdge(graph, new TF_Output(c1.op, 0), new TF_Input(neg.op, 0), tf.Status.Handle); | |||
Assert.AreEqual(2, one.consumers().Length); | |||
Assert.AreEqual("one:0", neg.op.node_def.Input[0]); | |||
} | |||
} | |||
} |
@@ -1,59 +0,0 @@ | |||
using System; | |||
using FluentAssertions; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using NumSharp; | |||
using Tensorflow; | |||
using Tensorflow.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.layers_test | |||
{ | |||
[TestClass] | |||
public class flatten : GraphModeTestBase | |||
{ | |||
[TestMethod] | |||
public void Case1() | |||
{ | |||
var sess = tf.Session().as_default(); | |||
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, 3, 1, 2)); | |||
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); | |||
} | |||
[TestMethod] | |||
public void Case2() | |||
{ | |||
var sess = tf.Session().as_default(); | |||
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6)); | |||
sess.run(tf.layers.flatten(input), (input, np.arange(6))).Should().BeShaped(6, 1); | |||
} | |||
[TestMethod] | |||
public void Case3() | |||
{ | |||
var sess = tf.Session().as_default(); | |||
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape()); | |||
new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw<ValueError>(); | |||
} | |||
[TestMethod] | |||
public void Case4() | |||
{ | |||
var sess = tf.Session().as_default(); | |||
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2)); | |||
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); | |||
} | |||
[TestMethod] | |||
public void Case5() | |||
{ | |||
var sess = tf.Session().as_default(); | |||
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2)); | |||
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); | |||
} | |||
} | |||
} |