@@ -350,8 +350,15 @@ class FieldArray: | |||||
:return: Counter, key是label,value是出现次数 | :return: Counter, key是label,value是出现次数 | ||||
""" | """ | ||||
count = Counter() | count = Counter() | ||||
def cum(cell): | |||||
if _is_iterable(cell) and not isinstance(cell, str): | |||||
for cell_ in cell: | |||||
cum(cell_) | |||||
else: | |||||
count[cell] += 1 | |||||
for cell in self.content: | for cell in self.content: | ||||
count[cell] += 1 | |||||
cum(cell) | |||||
return count | return count | ||||
def _after_process(self, new_contents, inplace): | def _after_process(self, new_contents, inplace): | ||||
@@ -326,7 +326,6 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
# 根据model_dir_or_name检查是否存在并下载 | # 根据model_dir_or_name检查是否存在并下载 | ||||
PRETRAIN_URL = _get_base_url('elmo') | PRETRAIN_URL = _get_base_url('elmo') | ||||
# TODO 把baidu云上的加上去 | |||||
PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz', | PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz', | ||||
'cn': 'elmo_cn-5e9b34e2.tar.gz'} | 'cn': 'elmo_cn-5e9b34e2.tar.gz'} | ||||
@@ -24,7 +24,7 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
if not os.path.isfile(train_fp): | if not os.path.isfile(train_fp): | ||||
raise FileNotFoundError(f"train.txt is not found in folder {paths}.") | raise FileNotFoundError(f"train.txt is not found in folder {paths}.") | ||||
files = {'train': train_fp} | files = {'train': train_fp} | ||||
for filename in ['test.txt', 'dev.txt']: | |||||
for filename in ['dev.txt', 'test.txt']: | |||||
fp = os.path.join(paths, filename) | fp = os.path.join(paths, filename) | ||||
if os.path.isfile(fp): | if os.path.isfile(fp): | ||||
files[filename.split('.')[0]] = fp | files[filename.split('.')[0]] = fp | ||||
@@ -80,7 +80,7 @@ class TestTutorial(unittest.TestCase): | |||||
test_data.rename_field('label', 'label_seq') | test_data.rename_field('label', 'label_seq') | ||||
loss = CrossEntropyLoss(pred="output", target="label_seq") | loss = CrossEntropyLoss(pred="output", target="label_seq") | ||||
metric = AccuracyMetric(pred="predict", target="label_seq") | |||||
metric = AccuracyMetric(target="label_seq") | |||||
# 实例化Trainer,传入模型和数据,进行训练 | # 实例化Trainer,传入模型和数据,进行训练 | ||||
# 先在test_data拟合(确保模型的实现是正确的) | # 先在test_data拟合(确保模型的实现是正确的) | ||||
@@ -96,7 +96,7 @@ class TestTutorial(unittest.TestCase): | |||||
# 用train_data训练,在test_data验证 | # 用train_data训练,在test_data验证 | ||||
trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, | trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, | ||||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | loss=CrossEntropyLoss(pred="output", target="label_seq"), | ||||
metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||||
metrics=AccuracyMetric(target="label_seq"), | |||||
save_path=None, | save_path=None, | ||||
batch_size=32, | batch_size=32, | ||||
n_epochs=5) | n_epochs=5) | ||||
@@ -106,7 +106,7 @@ class TestTutorial(unittest.TestCase): | |||||
# 调用Tester在test_data上评价效果 | # 调用Tester在test_data上评价效果 | ||||
from fastNLP import Tester | from fastNLP import Tester | ||||
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||||
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(target="label_seq"), | |||||
batch_size=4) | batch_size=4) | ||||
acc = tester.test() | acc = tester.test() | ||||
print(acc) | print(acc) | ||||