Browse Source

增强field中的value_count支持对nested的field的支持

tags/v0.4.10
yh_cc 6 years ago
parent
commit
839d712467
4 changed files with 12 additions and 6 deletions
  1. +8
    -1
      fastNLP/core/field.py
  2. +0
    -1
      fastNLP/modules/encoder/embedding.py
  3. +1
    -1
      reproduction/utils.py
  4. +3
    -3
      test/test_tutorials.py

+ 8
- 1
fastNLP/core/field.py View File

@@ -350,8 +350,15 @@ class FieldArray:
:return: Counter, key是label,value是出现次数
"""
count = Counter()

def cum(cell):
if _is_iterable(cell) and not isinstance(cell, str):
for cell_ in cell:
cum(cell_)
else:
count[cell] += 1
for cell in self.content:
count[cell] += 1
cum(cell)
return count

def _after_process(self, new_contents, inplace):


+ 0
- 1
fastNLP/modules/encoder/embedding.py View File

@@ -326,7 +326,6 @@ class ElmoEmbedding(ContextualEmbedding):

# 根据model_dir_or_name检查是否存在并下载
PRETRAIN_URL = _get_base_url('elmo')
# TODO 把baidu云上的加上去
PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz',
'cn': 'elmo_cn-5e9b34e2.tar.gz'}



+ 1
- 1
reproduction/utils.py View File

@@ -24,7 +24,7 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
if not os.path.isfile(train_fp):
raise FileNotFoundError(f"train.txt is not found in folder {paths}.")
files = {'train': train_fp}
for filename in ['test.txt', 'dev.txt']:
for filename in ['dev.txt', 'test.txt']:
fp = os.path.join(paths, filename)
if os.path.isfile(fp):
files[filename.split('.')[0]] = fp


+ 3
- 3
test/test_tutorials.py View File

@@ -80,7 +80,7 @@ class TestTutorial(unittest.TestCase):
test_data.rename_field('label', 'label_seq')

loss = CrossEntropyLoss(pred="output", target="label_seq")
metric = AccuracyMetric(pred="predict", target="label_seq")
metric = AccuracyMetric(target="label_seq")

# 实例化Trainer,传入模型和数据,进行训练
# 先在test_data拟合(确保模型的实现是正确的)
@@ -96,7 +96,7 @@ class TestTutorial(unittest.TestCase):
# 用train_data训练,在test_data验证
trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,
loss=CrossEntropyLoss(pred="output", target="label_seq"),
metrics=AccuracyMetric(pred="predict", target="label_seq"),
metrics=AccuracyMetric(target="label_seq"),
save_path=None,
batch_size=32,
n_epochs=5)
@@ -106,7 +106,7 @@ class TestTutorial(unittest.TestCase):
# 调用Tester在test_data上评价效果
from fastNLP import Tester

tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"),
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(target="label_seq"),
batch_size=4)
acc = tester.test()
print(acc)


Loading…
Cancel
Save