Browse Source

[update] collect_fn keeps track of fields

tags/v0.5.5
yunfan 5 years ago
parent
commit
54dcc23ad7
3 changed files with 18 additions and 8 deletions
  1. +6
    -0
      fastNLP/core/collect_fn.py
  2. +4
    -0
      fastNLP/core/dataset.py
  3. +8
    -8
      test/core/test_callbacks.py

+ 6
- 0
fastNLP/core/collect_fn.py View File

@@ -118,6 +118,12 @@ class Collector:
def outputs(self): def outputs(self):
return self.output2fn.keys() return self.output2fn.keys()


def copy_from(self, col):
    """Replace this collector's registries with copies taken from ``col``.

    ``fns``, ``input2fn`` and ``output2fn`` are each duplicated through
    the source attribute's own ``.copy()`` method, so the containers are
    not shared with ``col`` afterwards (presumably shallow copies --
    confirm against the container types used by ``Collector``).
    ``_clear_fn2io()`` is called once all three have been replaced.

    :param Collector col: the collector whose registries are copied.
    """
    assert isinstance(col, Collector)
    # Same order as the attributes are declared: fns, then the two maps.
    for registry in ("fns", "input2fn", "output2fn"):
        setattr(self, registry, getattr(col, registry).copy())
    self._clear_fn2io()


class CollectFn: class CollectFn:
def __init__(self): def __init__(self):


+ 4
- 0
fastNLP/core/dataset.py View File

@@ -569,6 +569,7 @@ class DataSet(object):
:param str field_name: 需要删除的field的名称. :param str field_name: 需要删除的field的名称.
""" """
self.field_arrays.pop(field_name) self.field_arrays.pop(field_name)
self.collector.drop_field(field_name)
return self return self


def copy_field(self, field_name, new_field_name): def copy_field(self, field_name, new_field_name):
@@ -641,6 +642,7 @@ class DataSet(object):
if field_name in self.field_arrays: if field_name in self.field_arrays:
self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) self.field_arrays[new_field_name] = self.field_arrays.pop(field_name)
self.field_arrays[new_field_name].name = new_field_name self.field_arrays[new_field_name].name = new_field_name
self.collector.rename_field(field_name, new_field_name)
else: else:
raise KeyError("DataSet has no field named {}.".format(field_name)) raise KeyError("DataSet has no field named {}.".format(field_name))
return self return self
@@ -933,6 +935,8 @@ class DataSet(object):
train_set.field_arrays[field_name].to(self.field_arrays[field_name]) train_set.field_arrays[field_name].to(self.field_arrays[field_name])
dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) dev_set.field_arrays[field_name].to(self.field_arrays[field_name])


train_set.collector.copy_from(self.collector)
dev_set.collector.copy_from(self.collector)
return train_set, dev_set return train_set, dev_set


def save(self, path): def save(self, path):


+ 8
- 8
test/core/test_callbacks.py View File

@@ -226,36 +226,36 @@ class TestCallback(unittest.TestCase):
callbacks=EarlyStopCallback(1), check_code_level=2) callbacks=EarlyStopCallback(1), check_code_level=2)
trainer.train() trainer.train()


@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_control_C(): def test_control_C():
# 用于测试 ControlC , 再两次训练时用 Control+C 进行退出,如果最后不显示 "Test failed!" 则通过测试 # 用于测试 ControlC , 再两次训练时用 Control+C 进行退出,如果最后不显示 "Test failed!" 则通过测试
from fastNLP import ControlC, Callback from fastNLP import ControlC, Callback
import time import time
line1 = "\n\n\n\n\n*************************" line1 = "\n\n\n\n\n*************************"
line2 = "*************************\n\n\n\n\n" line2 = "*************************\n\n\n\n\n"
class Wait(Callback): class Wait(Callback):
def on_epoch_end(self): def on_epoch_end(self):
time.sleep(5) time.sleep(5)
data_set, model = prepare_env() data_set, model = prepare_env()
print(line1 + "Test starts!" + line2) print(line1 + "Test starts!" + line2)
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set, batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Wait(), ControlC(False)], check_code_level=2) callbacks=[Wait(), ControlC(False)], check_code_level=2)
trainer.train() trainer.train()
print(line1 + "Program goes on ..." + line2) print(line1 + "Program goes on ..." + line2)
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set, batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Wait(), ControlC(True)], check_code_level=2) callbacks=[Wait(), ControlC(True)], check_code_level=2)
trainer.train() trainer.train()
print(line1 + "Test failed!" + line2) print(line1 + "Test failed!" + line2)






Loading…
Cancel
Save