@@ -3,7 +3,7 @@ using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class GRUCellArgs : AutoSerializeLayerArgs | public class GRUCellArgs : AutoSerializeLayerArgs | ||||
{ | { | ||||
@@ -1,4 +1,4 @@ | |||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class LSTMArgs : RNNArgs | public class LSTMArgs : RNNArgs | ||||
{ | { | ||||
@@ -1,7 +1,7 @@ | |||||
using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
// TODO: complete the implementation | // TODO: complete the implementation | ||||
public class LSTMCellArgs : AutoSerializeLayerArgs | public class LSTMCellArgs : AutoSerializeLayerArgs | ||||
@@ -1,8 +1,8 @@ | |||||
using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras.Layers.Rnn; | |||||
using Tensorflow.Keras.Layers; | |||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
// TODO(Rinne): add regularizers. | // TODO(Rinne): add regularizers. | ||||
public class RNNArgs : AutoSerializeLayerArgs | public class RNNArgs : AutoSerializeLayerArgs | ||||
@@ -23,16 +23,22 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
public int? InputDim { get; set; } | public int? InputDim { get; set; } | ||||
public int? InputLength { get; set; } | public int? InputLength { get; set; } | ||||
// TODO: Add `num_constants` and `zero_output_for_mask`. | // TODO: Add `num_constants` and `zero_output_for_mask`. | ||||
[JsonProperty("units")] | |||||
public int Units { get; set; } | public int Units { get; set; } | ||||
[JsonProperty("activation")] | |||||
public Activation Activation { get; set; } | public Activation Activation { get; set; } | ||||
[JsonProperty("recurrent_activation")] | |||||
public Activation RecurrentActivation { get; set; } | public Activation RecurrentActivation { get; set; } | ||||
[JsonProperty("use_bias")] | |||||
public bool UseBias { get; set; } = true; | public bool UseBias { get; set; } = true; | ||||
public IInitializer KernelInitializer { get; set; } | public IInitializer KernelInitializer { get; set; } | ||||
public IInitializer RecurrentInitializer { get; set; } | public IInitializer RecurrentInitializer { get; set; } | ||||
public IInitializer BiasInitializer { get; set; } | public IInitializer BiasInitializer { get; set; } | ||||
[JsonProperty("dropout")] | |||||
public float Dropout { get; set; } = .0f; | public float Dropout { get; set; } = .0f; | ||||
[JsonProperty("zero_output_for_mask")] | |||||
public bool ZeroOutputForMask { get; set; } = false; | public bool ZeroOutputForMask { get; set; } = false; | ||||
[JsonProperty("recurrent_dropout")] | |||||
public float RecurrentDropout { get; set; } = .0f; | public float RecurrentDropout { get; set; } = .0f; | ||||
} | } | ||||
} | } |
@@ -3,7 +3,7 @@ using System.Collections.Generic; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class RnnOptionalArgs: IOptionalArgs | public class RnnOptionalArgs: IOptionalArgs | ||||
{ | { | ||||
@@ -1,4 +1,4 @@ | |||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class SimpleRNNArgs : RNNArgs | public class SimpleRNNArgs : RNNArgs | ||||
{ | { | ||||
@@ -1,6 +1,6 @@ | |||||
using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class SimpleRNNCellArgs: AutoSerializeLayerArgs | public class SimpleRNNCellArgs: AutoSerializeLayerArgs | ||||
{ | { | ||||
@@ -1,7 +1,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras.Layers.Rnn; | |||||
using Tensorflow.Keras.Layers; | |||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class StackedRNNCellsArgs : LayerArgs | public class StackedRNNCellsArgs : LayerArgs | ||||
{ | { | ||||
@@ -1,7 +1,7 @@ | |||||
using System; | using System; | ||||
using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Layers.Rnn; | |||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | ||||
@@ -3,7 +3,7 @@ using System.Collections.Generic; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
public interface IRnnCell: ILayer | public interface IRnnCell: ILayer | ||||
{ | { | ||||
@@ -2,7 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
public interface IStackedRnnCells : IRnnCell | public interface IStackedRnnCells : IRnnCell | ||||
{ | { | ||||
@@ -19,9 +19,8 @@ using System.Collections.Generic; | |||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Keras; | using Tensorflow.Keras; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Layers.Rnn; | |||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
@@ -571,7 +571,9 @@ namespace Tensorflow | |||||
if (tf.Context.executing_eagerly()) | if (tf.Context.executing_eagerly()) | ||||
return true; | return true; | ||||
else | else | ||||
throw new NotImplementedException(""); | |||||
// TODO(Wanglongzhi2001), implement the false case | |||||
return true; | |||||
//throw new NotImplementedException(""); | |||||
} | } | ||||
public static bool inside_function() | public static bool inside_function() | ||||
@@ -2,9 +2,8 @@ | |||||
using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.ArgsDefinition.Core; | using Tensorflow.Keras.ArgsDefinition.Core; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Layers.Rnn; | |||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
@@ -6,7 +6,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
public abstract class DropoutRNNCellMixin: Layer, IRnnCell | public abstract class DropoutRNNCellMixin: Layer, IRnnCell | ||||
{ | { | ||||
@@ -3,12 +3,11 @@ using System.Collections.Generic; | |||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Common.Extensions; | using Tensorflow.Common.Extensions; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Cell class for the GRU layer. | /// Cell class for the GRU layer. | ||||
@@ -1,10 +1,10 @@ | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Common.Extensions; | using Tensorflow.Common.Extensions; | ||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Long Short-Term Memory layer - Hochreiter 1997. | /// Long Short-Term Memory layer - Hochreiter 1997. | ||||
@@ -3,12 +3,12 @@ using Serilog.Core; | |||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using Tensorflow.Common.Extensions; | using Tensorflow.Common.Extensions; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Cell class for the LSTM layer. | /// Cell class for the LSTM layer. | ||||
@@ -3,7 +3,6 @@ using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Reflection; | using System.Reflection; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
@@ -14,7 +13,7 @@ using Tensorflow.Common.Types; | |||||
using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
// from tensorflow.python.distribute import distribution_strategy_context as ds_context; | // from tensorflow.python.distribute import distribution_strategy_context as ds_context; | ||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Base class for recurrent layers. | /// Base class for recurrent layers. | ||||
@@ -185,6 +184,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
public override void build(KerasShapesWrapper input_shape) | public override void build(KerasShapesWrapper input_shape) | ||||
{ | { | ||||
_buildInputShape = input_shape; | |||||
input_shape = new KerasShapesWrapper(input_shape.Shapes[0]); | input_shape = new KerasShapesWrapper(input_shape.Shapes[0]); | ||||
InputSpec get_input_spec(Shape shape) | InputSpec get_input_spec(Shape shape) | ||||
@@ -4,7 +4,7 @@ using System.Text; | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
public abstract class RnnBase: Layer | public abstract class RnnBase: Layer | ||||
{ | { | ||||
@@ -1,11 +1,11 @@ | |||||
using System.Data; | using System.Data; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using Tensorflow.Operations.Activation; | using Tensorflow.Operations.Activation; | ||||
using static HDF.PInvoke.H5Z; | using static HDF.PInvoke.H5Z; | ||||
using static Tensorflow.ApiDef.Types; | using static Tensorflow.ApiDef.Types; | ||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
public class SimpleRNN : RNN | public class SimpleRNN : RNN | ||||
{ | { | ||||
@@ -1,7 +1,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
@@ -9,7 +9,7 @@ using Tensorflow.Common.Extensions; | |||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Cell class for SimpleRNN. | /// Cell class for SimpleRNN. | ||||
@@ -3,12 +3,12 @@ using System.ComponentModel; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Common.Extensions; | using Tensorflow.Common.Extensions; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
public class StackedRNNCells : Layer, IRnnCell | public class StackedRNNCells : Layer, IRnnCell | ||||
{ | { | ||||
@@ -13,7 +13,6 @@ using Tensorflow.Framework.Models; | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
using Tensorflow.Keras.Layers.Rnn; | |||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using Tensorflow.Keras.Metrics; | using Tensorflow.Keras.Metrics; | ||||
using Tensorflow.Keras.Saving.SavedModel; | using Tensorflow.Keras.Saving.SavedModel; | ||||
@@ -3,7 +3,7 @@ using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Layers.Rnn; | |||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.Keras.Metrics; | using Tensorflow.Keras.Metrics; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
@@ -3,7 +3,7 @@ using System.Collections.Generic; | |||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Keras.Layers.Rnn; | |||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.Common.Extensions; | using Tensorflow.Common.Extensions; | ||||
namespace Tensorflow.Keras.Utils | namespace Tensorflow.Keras.Utils | ||||
@@ -6,7 +6,7 @@ using System.Text; | |||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Layers.Rnn; | |||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
@@ -1,5 +1,7 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using Microsoft.VisualStudio.TestPlatform.Utilities; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
using Tensorflow.Keras.UnitTest.Helpers; | using Tensorflow.Keras.UnitTest.Helpers; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
@@ -79,6 +81,31 @@ public class ModelLoadTest | |||||
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | ||||
} | } | ||||
[TestMethod] | |||||
public void LSTMLoad() | |||||
{ | |||||
var inputs = np.random.randn(10, 5, 3); | |||||
var outputs = np.random.randn(10, 1); | |||||
var model = keras.Sequential(); | |||||
model.add(keras.Input(shape: (5, 3))); | |||||
var lstm = keras.layers.LSTM(32); | |||||
model.add(lstm); | |||||
model.add(keras.layers.Dense(1, keras.activations.Sigmoid)); | |||||
model.compile(optimizer: keras.optimizers.Adam(), | |||||
loss: keras.losses.MeanSquaredError(), | |||||
new[] { "accuracy" }); | |||||
var result = model.fit(inputs.numpy(), outputs.numpy(), batch_size: 10, epochs: 3, workers: 16, use_multiprocessing: true); | |||||
model.save("LSTM_Random"); | |||||
var model_loaded = keras.models.load_model("LSTM_Random"); | |||||
model_loaded.summary(); | |||||
} | |||||
[Ignore] | [Ignore] | ||||
[TestMethod] | [TestMethod] | ||||
public void VGG19() | public void VGG19() | ||||