
Fix Embedding comment errors and a code error in core.utils.py

tags/v0.5.5
yh_cc 4 years ago
commit ff8b9a37a7
4 changed files with 35 additions and 6 deletions
  1. fastNLP/core/utils.py (+1, -1)
  2. fastNLP/embeddings/gpt2_embedding.py (+2, -2)
  3. fastNLP/embeddings/roberta_embedding.py (+3, -3)
  4. test/core/test_trainer.py (+29, -0)

fastNLP/core/utils.py (+1, -1)

@@ -779,7 +779,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ")
if _miss_out_dataset:
_tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. "
- if not dataset.collator.is_empty():
+ if not dataset.collater.is_empty():
_tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collate_fn. '
suggestions.append(_tmp)
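For context on the utils.py fix: the DataSet keeps its collate helper under the attribute name collater, so the old dataset.collator lookup would raise an AttributeError at the very moment the check tried to build this suggestion. A minimal sketch of the corrected attribute (attribute and method names are taken from the diff above; the printed value is an assumption for a DataSet with no collate_fn registered):

    from fastNLP import DataSet

    ds = DataSet({'x': [[1, 2], [3, 4]], 'y': [0, 1]})
    # `collater` (not `collator`) is the attribute consulted when composing the
    # "add it in the output of your collate_fn" hint in _check_forward_error.
    print(ds.collater.is_empty())  # assumed to print True while no collate_fn is registered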



fastNLP/embeddings/gpt2_embedding.py (+2, -2)

@@ -47,7 +47,7 @@ class GPT2Embedding(ContextualEmbedding):
>>> # torch.Size([1, 5, 3096])
"""

- def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-small', layers: str = '-1',
+ def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str = '-1',
pool_method: str = 'first', dropout=0, requires_grad: bool = True,
auto_truncate: bool = False, language_model: bool = False, **kwargs):
"""
@@ -152,7 +152,7 @@ class GPT2WordPieceEncoder(nn.Module):

"""

- def __init__(self, model_dir_or_name: str = 'en-small', layers: str = '-1',
+ def __init__(self, model_dir_or_name: str = 'en', layers: str = '-1',
word_dropout=0, dropout=0, requires_grad: bool = True, language_model:bool=False):
"""



fastNLP/embeddings/roberta_embedding.py (+3, -3)

@@ -47,7 +47,7 @@ class RobertaEmbedding(ContextualEmbedding):
>>> # torch.Size([1, 5, 2304])
"""

- def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1',
+ def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str = '-1',
pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False,
pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False, **kwargs):
r"""
@@ -373,13 +373,13 @@ class RobertaWordPieceEncoder(nn.Module):
r"""
Loads a bert model; after loading, call the index_dataset method to generate the word_pieces column in the dataset.

- BertWordPieceEncoder supports automatic weight download; currently supported models:
+ RobertaWordPieceEncoder supports automatic weight download; currently supported models:
en: roberta-base
en-large: roberta-large

"""

- def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
+ def __init__(self, model_dir_or_name: str = 'en', layers: str = '-1', pooled_cls: bool = False,
word_dropout=0, dropout=0, requires_grad: bool = True):
r"""



test/core/test_trainer.py (+29, -0)

@@ -1,9 +1,11 @@
import time
import unittest
import os

import numpy as np
import torch.nn.functional as F
from torch import nn
import torch

from fastNLP import DataSet
from fastNLP import Instance
@@ -228,6 +230,33 @@ class TrainerTestGround(unittest.TestCase):
        trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2, dev_data=dataset,
                          metrics=AccuracyMetric(), use_tqdm=False)

    @unittest.skipIf('TRAVIS' in os.environ, "Need to be tested in hosts with more than 1 gpus")
    def test_trainer_data_parallel(self):
        if torch.cuda.device_count()>1:
            from fastNLP import AccuracyMetric
            dataset = prepare_fake_dataset2('x1', 'x2')
            dataset.set_input('x1', 'x2', 'y', flag=True)

            class Model(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.fc = nn.Linear(5, 4)

                def forward(self, x1, x2, y=None):
                    x1 = self.fc(x1)
                    x2 = self.fc(x2)
                    x = x1 + x2
                    if self.training:
                        loss = F.cross_entropy(x, y)
                        return {'loss': loss}
                    else:
                        return {'pred':x, 'target':y}

            model = Model()
            trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False,
                              dev_data=dataset, metrics=AccuracyMetric(), device=[0, 1])
            trainer.train(load_best_model=False)

    def test_udf_dataiter(self):
        import random
        import torch

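The added test only exercises the multi-GPU path: it is skipped on TRAVIS and is a no-op when fewer than two GPUs are visible. With device=[0, 1], the Trainer runs the model on both devices, which (as far as fastNLP's device handling goes) follows torch.nn.DataParallel semantics, so the same forward() must hold on every replica. A standalone sketch of that execution model, independent of the Trainer:

    import torch
    from torch import nn

    if torch.cuda.device_count() > 1:
        # Replicate a small module on GPUs 0 and 1 and split the batch between them,
        # the execution model the new Trainer test relies on with device=[0, 1].
        model = nn.DataParallel(nn.Linear(5, 4), device_ids=[0, 1]).cuda()
        x = torch.randn(8, 5, device='cuda')
        print(model(x).shape)  # torch.Size([8, 4])
    else:
        print("Skipped: needs at least 2 visible GPUs, like the new test.")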
