diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 437b647a..86e06a7e 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -29,7 +29,7 @@ class Callback(object): @property def n_steps(self): """total number of steps for training""" - return self.n_steps + return self._trainer.n_steps @property def batch_size(self): @@ -124,6 +124,21 @@ class Callback(object): pass +def transfer(func): + """装饰器,将对CallbackManager的调用转发到各个Callback子类. + :param func: + :return: + """ + + def wrapper(manager, *arg): + returns = [] + for callback in manager.callbacks: + returns.append(getattr(callback, func.__name__)(*arg)) + return returns + + return wrapper + + class CallbackManager(Callback): """A manager for all callbacks passed into Trainer. It collects resources inside Trainer and raise callbacks. @@ -150,42 +165,59 @@ class CallbackManager(Callback): else: raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") + for env_name, env_val in env.items(): + for callback in self.callbacks: + setattr(callback, '_'+env_name, env_val) # Callback.trainer + + @transfer def on_train_begin(self): pass + @transfer def on_epoch_begin(self): pass + @transfer def on_batch_begin(self, batch_x, batch_y, indices): pass + @transfer def on_loss_begin(self, batch_y, predict_y): pass + @transfer def on_backward_begin(self, loss): pass + @transfer def on_backward_end(self): pass + @transfer def on_step_end(self): pass + @transfer def on_batch_end(self): pass + @transfer def on_valid_begin(self): pass + @transfer def on_valid_end(self, eval_result, metric_key): pass + @transfer def on_epoch_end(self): pass + @transfer def on_train_end(self): pass + @transfer def on_exception(self, exception): pass diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index 7d66620c..3329e7a1 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -139,11 +139,14 @@ class TestCallback(unittest.TestCase): def test_readonly_property(self): from fastNLP.core.callback import Callback + passed_epochs = [] + total_epochs = 5 class MyCallback(Callback): def __init__(self): super(MyCallback, self).__init__() - def on_epoch_begin(self, cur_epoch, total_epoch): + def on_epoch_begin(self): + passed_epochs.append(self.epoch) print(self.n_epochs, self.n_steps, self.batch_size) print(self.model) print(self.optimizer) @@ -151,7 +154,7 @@ class TestCallback(unittest.TestCase): data_set, model = prepare_env() trainer = Trainer(data_set, model, loss=BCELoss(pred="predict", target="y"), - n_epochs=5, + n_epochs=total_epochs, batch_size=32, print_every=50, optimizer=SGD(lr=0.1), @@ -161,3 +164,4 @@ class TestCallback(unittest.TestCase): metrics=AccuracyMetric(pred="predict", target="y"), callbacks=[MyCallback()]) trainer.train() + assert passed_epochs == list(range(1, total_epochs+1)) diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index 607f9a13..356b157a 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -217,6 +217,7 @@ class TestDataSetMethods(unittest.TestCase): self.assertTrue(len(ds) > 0) def test_add_null(self): + # TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError' ds = DataSet() ds.add_field('test', []) ds.set_target('test') diff --git a/test/models/test_enas.py b/test/models/test_enas.py deleted file mode 100644 index 07a43205..00000000 --- a/test/models/test_enas.py +++ /dev/null @@ -1,112 +0,0 @@ -import unittest - -from fastNLP import DataSet -from fastNLP import Instance -from fastNLP import Vocabulary -from fastNLP.core.losses import CrossEntropyLoss -from fastNLP.core.metrics import AccuracyMetric - - -class TestENAS(unittest.TestCase): - def testENAS(self): - # 从csv读取数据到DataSet - sample_path = "tutorials/sample_data/tutorial_sample_dataset.csv" - dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), - sep='\t') - print(len(dataset)) - print(dataset[0]) - print(dataset[-3]) - - dataset.append(Instance(raw_sentence='fake data', label='0')) - # 将所有数字转为小写 - dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') - # label转int - dataset.apply(lambda x: int(x['label']), new_field_name='label') - - # 使用空格分割句子 - def split_sent(ins): - return ins['raw_sentence'].split() - - dataset.apply(split_sent, new_field_name='words') - - # 增加长度信息 - dataset.apply(lambda x: len(x['words']), new_field_name='seq_len') - print(len(dataset)) - print(dataset[0]) - - # DataSet.drop(func)筛除数据 - dataset.drop(lambda x: x['seq_len'] <= 3) - print(len(dataset)) - - # 设置DataSet中,哪些field要转为tensor - # set target,loss或evaluate中的golden,计算loss,模型评估时使用 - dataset.set_target("label") - # set input,模型forward时使用 - dataset.set_input("words", "seq_len") - - # 分出测试集、训练集 - test_data, train_data = dataset.split(0.5) - print(len(test_data)) - print(len(train_data)) - - # 构建词表, Vocabulary.add(word) - vocab = Vocabulary(min_freq=2) - train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) - vocab.build_vocab() - - # index句子, Vocabulary.to_index(word) - train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') - test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') - print(test_data[0]) - - # 如果你们需要做强化学习或者GAN之类的项目,你们也可以使用这些数据预处理的工具 - from fastNLP.core.batch import Batch - from fastNLP.core.sampler import RandomSampler - - batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler()) - for batch_x, batch_y in batch_iterator: - print("batch_x has: ", batch_x) - print("batch_y has: ", batch_y) - break - - from fastNLP.models.enas_model import ENASModel - from fastNLP.models.enas_controller import Controller - model = ENASModel(embed_num=len(vocab), num_classes=5) - controller = Controller() - - from fastNLP.models.enas_trainer import ENASTrainer - from copy import deepcopy - - # 更改DataSet中对应field的名称,要以模型的forward等参数名一致 - train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 - train_data.rename_field('label', 'label_seq') - test_data.rename_field('words', 'word_seq') - test_data.rename_field('label', 'label_seq') - - loss = CrossEntropyLoss(pred="output", target="label_seq") - metric = AccuracyMetric(pred="predict", target="label_seq") - - trainer = ENASTrainer(model=model, controller=controller, train_data=train_data, dev_data=test_data, - loss=CrossEntropyLoss(pred="output", target="label_seq"), - metrics=AccuracyMetric(pred="predict", target="label_seq"), - check_code_level=-1, - save_path=None, - batch_size=32, - print_every=1, - n_epochs=3, - final_epochs=1) - trainer.train() - print('Train finished!') - - # 调用Tester在test_data上评价效果 - from fastNLP import Tester - - tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), - batch_size=4) - - acc = tester.test() - print(acc) - - -if __name__ == '__main__': - unittest.main() \ No newline at end of file