From 839d712467b83a6bea1aab0b90e95f1432fc3ba6 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 17 Jun 2019 16:46:39 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=BC=BAfield=E4=B8=AD=E7=9A=84value?= =?UTF-8?q?=5Fcount=E6=94=AF=E6=8C=81=E5=AF=B9nested=E7=9A=84field?= =?UTF-8?q?=E7=9A=84=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/field.py | 9 ++++++++- fastNLP/modules/encoder/embedding.py | 1 - reproduction/utils.py | 2 +- test/test_tutorials.py | 6 +++--- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index faa306f3..b0a36765 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -350,8 +350,15 @@ class FieldArray: :return: Counter, key是label,value是出现次数 """ count = Counter() + + def cum(cell): + if _is_iterable(cell) and not isinstance(cell, str): + for cell_ in cell: + cum(cell_) + else: + count[cell] += 1 for cell in self.content: - count[cell] += 1 + cum(cell) return count def _after_process(self, new_contents, inplace): diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index 631f57e9..5f0b6c3b 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -326,7 +326,6 @@ class ElmoEmbedding(ContextualEmbedding): # 根据model_dir_or_name检查是否存在并下载 PRETRAIN_URL = _get_base_url('elmo') - # TODO 把baidu云上的加上去 PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz', 'cn': 'elmo_cn-5e9b34e2.tar.gz'} diff --git a/reproduction/utils.py b/reproduction/utils.py index 58883b43..0d06c99c 100644 --- a/reproduction/utils.py +++ b/reproduction/utils.py @@ -24,7 +24,7 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: if not os.path.isfile(train_fp): raise FileNotFoundError(f"train.txt is not found in folder {paths}.") files = {'train': train_fp} - for filename in ['test.txt', 'dev.txt']: + for filename in ['dev.txt', 'test.txt']: fp = os.path.join(paths, filename) if os.path.isfile(fp): files[filename.split('.')[0]] = fp diff --git a/test/test_tutorials.py b/test/test_tutorials.py index a38d5ae1..2e971a4f 100644 --- a/test/test_tutorials.py +++ b/test/test_tutorials.py @@ -80,7 +80,7 @@ class TestTutorial(unittest.TestCase): test_data.rename_field('label', 'label_seq') loss = CrossEntropyLoss(pred="output", target="label_seq") - metric = AccuracyMetric(pred="predict", target="label_seq") + metric = AccuracyMetric(target="label_seq") # 实例化Trainer,传入模型和数据,进行训练 # 先在test_data拟合(确保模型的实现是正确的) @@ -96,7 +96,7 @@ class TestTutorial(unittest.TestCase): # 用train_data训练,在test_data验证 trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, loss=CrossEntropyLoss(pred="output", target="label_seq"), - metrics=AccuracyMetric(pred="predict", target="label_seq"), + metrics=AccuracyMetric(target="label_seq"), save_path=None, batch_size=32, n_epochs=5) @@ -106,7 +106,7 @@ class TestTutorial(unittest.TestCase): # 调用Tester在test_data上评价效果 from fastNLP import Tester - tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), + tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(target="label_seq"), batch_size=4) acc = tester.test() print(acc)