diff --git a/fastNLP/core/collect_fn.py b/fastNLP/core/collect_fn.py index 14add06f..29f19e2c 100644 --- a/fastNLP/core/collect_fn.py +++ b/fastNLP/core/collect_fn.py @@ -118,6 +118,12 @@ class Collector: def outputs(self): return self.output2fn.keys() + def copy_from(self, col): + assert isinstance(col, Collector) + self.fns = col.fns.copy() + self.input2fn = col.input2fn.copy() + self.output2fn = col.output2fn.copy() + self._clear_fn2io() class CollectFn: def __init__(self): diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index b13eab76..c5210169 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -569,6 +569,7 @@ class DataSet(object): :param str field_name: 需要删除的field的名称. """ self.field_arrays.pop(field_name) + self.collector.drop_field(field_name) return self def copy_field(self, field_name, new_field_name): @@ -641,6 +642,7 @@ class DataSet(object): if field_name in self.field_arrays: self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) self.field_arrays[new_field_name].name = new_field_name + self.collector.rename_field(field_name, new_field_name) else: raise KeyError("DataSet has no field named {}.".format(field_name)) return self @@ -933,6 +935,8 @@ class DataSet(object): train_set.field_arrays[field_name].to(self.field_arrays[field_name]) dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) + train_set.collector.copy_from(self.collector) + dev_set.collector.copy_from(self.collector) return train_set, dev_set def save(self, path): diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index 11cf6704..e756040c 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -226,36 +226,36 @@ class TestCallback(unittest.TestCase): callbacks=EarlyStopCallback(1), check_code_level=2) trainer.train() - +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") def test_control_C(): # 用于测试 ControlC , 再两次训练时用 Control+C 进行退出,如果最后不显示 "Test failed!" 则通过测试 from fastNLP import ControlC, Callback import time - + line1 = "\n\n\n\n\n*************************" line2 = "*************************\n\n\n\n\n" - + class Wait(Callback): def on_epoch_end(self): time.sleep(5) - + data_set, model = prepare_env() - + print(line1 + "Test starts!" + line2) trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), batch_size=32, n_epochs=20, dev_data=data_set, metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, callbacks=[Wait(), ControlC(False)], check_code_level=2) trainer.train() - + print(line1 + "Program goes on ..." + line2) - + trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), batch_size=32, n_epochs=20, dev_data=data_set, metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, callbacks=[Wait(), ControlC(True)], check_code_level=2) trainer.train() - + print(line1 + "Test failed!" + line2)