diff --git a/fastNLP/automl/__init__.py b/fastNLP/automl/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/fastNLP/models/enas_controller.py b/fastNLP/automl/enas_controller.py
similarity index 98%
rename from fastNLP/models/enas_controller.py
rename to fastNLP/automl/enas_controller.py
index ae9bcfd2..6ddbb211 100644
--- a/fastNLP/models/enas_controller.py
+++ b/fastNLP/automl/enas_controller.py
@@ -5,9 +5,9 @@ import os
 
 import torch
 import torch.nn.functional as F
-import fastNLP
-import fastNLP.models.enas_utils as utils
-from fastNLP.models.enas_utils import Node
+
+import fastNLP.automl.enas_utils as utils
+from fastNLP.automl.enas_utils import Node
 
 
 def _construct_dags(prev_nodes, activations, func_names, num_blocks):
diff --git a/fastNLP/models/enas_model.py b/fastNLP/automl/enas_model.py
similarity index 99%
rename from fastNLP/models/enas_model.py
rename to fastNLP/automl/enas_model.py
index cc91e675..4f9fb449 100644
--- a/fastNLP/models/enas_model.py
+++ b/fastNLP/automl/enas_model.py
@@ -1,17 +1,17 @@
 # Code Modified from https://github.com/carpedm20/ENAS-pytorch
 """Module containing the shared RNN model."""
 
-import numpy as np
 import collections
+import numpy as np
 
 import torch
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
 from torch.autograd import Variable
 
-import fastNLP.models.enas_utils as utils
+import fastNLP.automl.enas_utils as utils
 from fastNLP.models.base_model import BaseModel
-import fastNLP.modules.encoder as encoder
+
 
 def _get_dropped_weights(w_raw, dropout_p, is_training):
     """Drops out weights to implement DropConnect.
diff --git a/fastNLP/models/enas_trainer.py b/fastNLP/automl/enas_trainer.py
similarity index 98%
rename from fastNLP/models/enas_trainer.py
rename to fastNLP/automl/enas_trainer.py
index 22e323ce..7c0da752 100644
--- a/fastNLP/models/enas_trainer.py
+++ b/fastNLP/automl/enas_trainer.py
@@ -1,14 +1,12 @@
 # Code Modified from https://github.com/carpedm20/ENAS-pytorch
 
-import os
+import math
 import time
 from datetime import datetime
 from datetime import timedelta
 
 import numpy as np
 import torch
-import math
-from torch import nn
 
 try:
     from tqdm.autonotebook import tqdm
@@ -16,12 +14,11 @@ except:
     from fastNLP.core.utils import pseudo_tqdm as tqdm
 
 from fastNLP.core.batch import Batch
-from fastNLP.core.callback import CallbackManager, CallbackException
+from fastNLP.core.callback import CallbackException
 from fastNLP.core.dataset import DataSet
-from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _move_dict_value_to_device
 import fastNLP
-import fastNLP.models.enas_utils as utils
+import fastNLP.automl.enas_utils as utils
 from fastNLP.core.utils import _build_args
 
 from torch.optim import Adam
diff --git a/fastNLP/models/enas_utils.py b/fastNLP/automl/enas_utils.py
similarity index 96%
rename from fastNLP/models/enas_utils.py
rename to fastNLP/automl/enas_utils.py
index e5027d81..7a53dd12 100644
--- a/fastNLP/models/enas_utils.py
+++ b/fastNLP/automl/enas_utils.py
@@ -2,17 +2,14 @@
 
 from __future__ import print_function
 
-from collections import defaultdict
 import collections
-from datetime import datetime
-import os
-import json
+from collections import defaultdict
 
 import numpy as np
-
 import torch
 from torch.autograd import Variable
 
+
 def detach(h):
     if type(h) == Variable:
         return Variable(h.data)
diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index d941c235..e3b4f36e 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -17,6 +17,38 @@ class Callback(object):
         super(Callback, self).__init__()
         self.trainer = None  # reassigned inside Trainer
 
+        # read-only attributes for callbacks
+        self._n_epochs = None
+        self._n_steps = None
+        self._batch_size = None
+        self._model = None
+        self._pbar = None
+        self._optimizer = None
+
+    @property
+    def n_epochs(self):
+        return self._n_epochs
+
+    @property
+    def n_steps(self):
+        return self._n_steps
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    @property
+    def model(self):
+        return self._model
+
+    @property
+    def pbar(self):
+        return self._pbar
+
+    @property
+    def optimizer(self):
+        return self._optimizer
+
     def on_train_begin(self):
         # before the main training loop
         pass
@@ -101,8 +133,6 @@ def transfer(func):
 
     def wrapper(manager, *arg):
         returns = []
         for callback in manager.callbacks:
-            for env_name, env_value in manager.env.items():
-                setattr(callback, env_name, env_value)
             returns.append(getattr(callback, func.__name__)(*arg))
         return returns
@@ -115,15 +145,15 @@ class CallbackManager(Callback):
 
     """
 
-    def __init__(self, env, callbacks=None):
+    def __init__(self, env, attr, callbacks=None):
         """
 
         :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself.
+        :param dict attr: read-only attributes for all callbacks
        :param Callback callbacks:
         """
         super(CallbackManager, self).__init__()
         # set attribute of trainer environment
-        self.env = env
 
         self.callbacks = []
         if callbacks is not None:
@@ -136,6 +166,23 @@ class CallbackManager(Callback):
         else:
             raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
 
+        for env_name, env_val in env.items():
+            for callback in self.callbacks:
+                setattr(callback, env_name, env_val)  # Callback.trainer
+
+        self.set_property(**attr)
+
+    def set_property(self, **kwargs):
+        """Set the read-only attributes of all callbacks.
+
+        :param kwargs:
+        :return:
+        """
+        for callback in self.callbacks:
+            for k, v in kwargs.items():
+                setattr(callback, "_" + k, v)
+
+
     @transfer
     def on_train_begin(self):
         pass
diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index 54fde815..9d581798 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -157,7 +157,7 @@ class MetricBase(object):
         fast_param = {}
         if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
             fast_param['pred'] = list(pred_dict.values())[0]
-            fast_param['target'] = list(pred_dict.values())[0]
+            fast_param['target'] = list(target_dict.values())[0]
             return fast_param
         return fast_param
 
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 8880291d..ca2ff93b 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -121,7 +121,6 @@ class Trainer(object):
         self.best_dev_perf = None
         self.sampler = sampler if sampler is not None else RandomSampler()
         self.prefetch = prefetch
-        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
 
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
@@ -144,6 +143,12 @@
         self.step = 0
         self.start_time = None  # start timestamp
 
+        self.callback_manager = CallbackManager(env={"trainer": self},
+                                                attr={"n_epochs": self.n_epochs, "n_steps": self.step,
+                                                      "batch_size": self.batch_size, "model": self.model,
+                                                      "optimizer": self.optimizer},
+                                                callbacks=callbacks)
+
     def train(self, load_best_model=True):
         """
 
@@ -236,6 +241,7 @@
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                   prefetch=self.prefetch)
+            self.callback_manager.set_property(pbar=pbar)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
@@ -361,6 +367,8 @@
         """
         if self.save_path is not None:
             model_path = os.path.join(self.save_path, model_name)
+            if not os.path.exists(self.save_path):
+                os.makedirs(self.save_path, exist_ok=True)
             if only_param:
                 state_dict = model.state_dict()
                 for key in state_dict:
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 987a3527..a1c8e678 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -196,3 +196,9 @@ class Vocabulary(object):
         """
         self.__dict__.update(state)
         self.build_reverse_vocab()
+
+    def __repr__(self):
+        return "Vocabulary({}...)".format(list(self.word_count.keys())[:5])
+
+    def __iter__(self):
+        return iter(list(self.word_count.keys()))
diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py
index 5287bca4..4ae15b18 100644
--- a/fastNLP/modules/utils.py
+++ b/fastNLP/modules/utils.py
@@ -60,7 +60,8 @@ def initial_parameter(net, initial_method=None):
                     init_method(w.data)  # weight
                 else:
                     init.normal_(w.data)  # bias
-        elif hasattr(m, 'weight') and m.weight.requires_grad:
+        elif m is not None and hasattr(m, 'weight') and \
+                hasattr(m.weight, "requires_grad"):
             init_method(m.weight.data)
         else:
             for w in m.parameters():
diff --git a/test/models/test_enas.py b/test/automl/test_enas.py
similarity index 94%
rename from test/models/test_enas.py
rename to test/automl/test_enas.py
index 07a43205..d2d3af05 100644
--- a/test/models/test_enas.py
+++ b/test/automl/test_enas.py
@@ -69,13 +69,12 @@ class TestENAS(unittest.TestCase):
             print("batch_y has: ", batch_y)
             break
 
-        from fastNLP.models.enas_model import ENASModel
-        from fastNLP.models.enas_controller import Controller
+        from fastNLP.automl.enas_model import ENASModel
+        from fastNLP.automl.enas_controller import Controller
         model = ENASModel(embed_num=len(vocab), num_classes=5)
         controller = Controller()
 
-        from fastNLP.models.enas_trainer import ENASTrainer
-        from copy import deepcopy
+        from fastNLP.automl.enas_trainer import ENASTrainer
 
         # rename the corresponding DataSet fields to match the parameter names of the model's forward method
         train_data.rename_field('words', 'word_seq')  # input field, consistent with the forward parameters
diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py
index 74ce4876..7d66620c 100644
--- a/test/core/test_callbacks.py
+++ b/test/core/test_callbacks.py
@@ -136,3 +136,28 @@ class TestCallback(unittest.TestCase):
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           callbacks=[TensorboardCallback("loss", "metric")])
         trainer.train()
+
+    def test_readonly_property(self):
+        from fastNLP.core.callback import Callback
+        class MyCallback(Callback):
+            def __init__(self):
+                super(MyCallback, self).__init__()
+
+            def on_epoch_begin(self, cur_epoch, total_epoch):
+                print(self.n_epochs, self.n_steps, self.batch_size)
+                print(self.model)
+                print(self.optimizer)
+
+        data_set, model = prepare_env()
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=5,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"),
+                          callbacks=[MyCallback()])
+        trainer.train()
diff --git a/test/core/test_vocabulary.py b/test/core/test_vocabulary.py
index af2c493b..2f9cd3b1 100644
--- a/test/core/test_vocabulary.py
+++ b/test/core/test_vocabulary.py
@@ -60,6 +60,15 @@ class TestIndexing(unittest.TestCase):
         vocab.update(text)
         self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])
 
+    def test_iteration(self):
+        vocab = Vocabulary()
+        text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
+                "works", "well", "in", "most", "cases", "scales", "well"]
+        vocab.update(text)
+        text = set(text)
+        for word in vocab:
+            self.assertTrue(word in text)
+
 
 class TestOther(unittest.TestCase):
     def test_additional_update(self):
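
Usage note (reviewer's sketch, not part of the patch): the snippet below exercises the two user-facing additions in this diff — the read-only properties that CallbackManager.set_property fills in on each Callback, and the new __iter__/__repr__ on Vocabulary. It assumes this branch of fastNLP is importable; ProgressCallback is a hypothetical name, and the properties stay None until the callback is handed to a Trainer (as test_readonly_property above does).

from fastNLP.core.callback import Callback
from fastNLP.core.vocabulary import Vocabulary


class ProgressCallback(Callback):
    """Hypothetical callback reading the read-only attributes set by CallbackManager."""

    def on_epoch_begin(self, cur_epoch, total_epoch):
        # n_epochs, batch_size and optimizer are plain properties backed by the
        # _n_epochs/_batch_size/_optimizer fields that set_property writes via
        # setattr(callback, "_" + k, v); outside a Trainer run they are None.
        print("epoch {}/{}, batch size {}".format(cur_epoch, self.n_epochs, self.batch_size))


# Passing callbacks=[ProgressCallback()] to Trainer(...) makes CallbackManager
# populate the properties before training starts.

vocab = Vocabulary()
vocab.update(["FastNLP", "works", "well", "works"])
for word in vocab:  # __iter__ yields the keys of vocab.word_count
    print(word)
print(vocab)  # __repr__ shows at most the first five words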