Browse Source

fix a serial bugs on importing

tags/v0.4.10
ChenXin 5 years ago
parent
commit
a8a21b169a
12 changed files with 30 additions and 20 deletions
  1. +5
    -1
      fastNLP/core/dist_trainer.py
  2. +4
    -1
      fastNLP/embeddings/contextual_embedding.py
  3. +3
    -1
      fastNLP/io/config_io.py
  4. +3
    -3
      fastNLP/io/loader/conll.py
  5. +2
    -2
      fastNLP/io/loader/cws.py
  6. +1
    -1
      fastNLP/io/loader/loader.py
  7. +3
    -3
      fastNLP/io/loader/matching.py
  8. +3
    -2
      fastNLP/io/pipe/classification.py
  9. +1
    -1
      fastNLP/io/pipe/conll.py
  10. +2
    -2
      fastNLP/io/pipe/matching.py
  11. +2
    -2
      fastNLP/io/pipe/utils.py
  12. +1
    -1
      reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py

+ 5
- 1
fastNLP/core/dist_trainer.py View File

@@ -1,3 +1,6 @@
"""
正在开发中的分布式训练代码
"""
import torch
import torch.cuda
import torch.optim
@@ -41,7 +44,8 @@ def get_local_rank():


class DistTrainer():
"""Distributed Trainer that support distributed and mixed precision training
"""
Distributed Trainer that support distributed and mixed precision training
"""
def __init__(self, train_data, model, optimizer=None, loss=None,
callbacks_all=None, callbacks_master=None,


+ 4
- 1
fastNLP/embeddings/contextual_embedding.py View File

@@ -1,4 +1,3 @@

from abc import abstractmethod
import torch

@@ -9,6 +8,10 @@ from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from .embedding import TokenEmbedding

__all__ = [
"ContextualEmbedding"
]


class ContextualEmbedding(TokenEmbedding):
def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0):


+ 3
- 1
fastNLP/io/config_io.py View File

@@ -1,7 +1,9 @@
"""
用于读入和处理和保存 config 文件
.. todo::

.. todo::
这个模块中的类可能被抛弃?

"""
__all__ = [
"ConfigLoader",


+ 3
- 3
fastNLP/io/loader/conll.py View File

@@ -1,12 +1,12 @@
from typing import Dict, Union

from .loader import Loader
from ... import DataSet
from ...core.dataset import DataSet
from ..file_reader import _read_conll
from ... import Instance
from ...core.instance import Instance
from .. import DataBundle
from ..utils import check_loader_paths
from ... import Const
from ...core.const import Const


class ConllLoader(Loader):


+ 2
- 2
fastNLP/io/loader/cws.py View File

@@ -1,6 +1,6 @@

from .loader import Loader
from ...core import DataSet, Instance
from ...core.dataset import DataSet
from ...core.instance import Instance


class CWSLoader(Loader):


+ 1
- 1
fastNLP/io/loader/loader.py View File

@@ -1,4 +1,4 @@
from ... import DataSet
from ...core.dataset import DataSet
from .. import DataBundle
from ..utils import check_loader_paths
from typing import Union, Dict


+ 3
- 3
fastNLP/io/loader/matching.py View File

@@ -1,12 +1,12 @@
import warnings
from .loader import Loader
from .json import JsonLoader
from ...core import Const
from ...core.const import Const
from .. import DataBundle
import os
from typing import Union, Dict
from ...core import DataSet
from ...core import Instance
from ...core.dataset import DataSet
from ...core.instance import Instance


class MNLILoader(Loader):


+ 3
- 2
fastNLP/io/pipe/classification.py View File

@@ -4,13 +4,14 @@ from ..base_loader import DataBundle
from ...core.vocabulary import Vocabulary
from ...core.const import Const
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
from ...core import DataSet, Instance
from ...core.dataset import DataSet
from ...core.instance import Instance

from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
from .pipe import Pipe
import re
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
from ...core import cache_results
from ...core.utils import cache_results

class _CLSPipe(Pipe):
"""


+ 1
- 1
fastNLP/io/pipe/conll.py View File

@@ -1,7 +1,7 @@
from .pipe import Pipe
from .. import DataBundle
from .utils import iob2, iob2bioes
from ... import Const
from ...core.const import Const
from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader
from .utils import _indexize, _add_words_field



+ 2
- 2
fastNLP/io/pipe/matching.py View File

@@ -2,8 +2,8 @@ import math

from .pipe import Pipe
from .utils import get_tokenizer
from ...core import Const
from ...core import Vocabulary
from ...core.const import Const
from ...core.vocabulary import Vocabulary
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader




+ 2
- 2
fastNLP/io/pipe/utils.py View File

@@ -1,6 +1,6 @@
from typing import List
from ...core import Vocabulary
from ...core import Const
from ...core.vocabulary import Vocabulary
from ...core.const import Const

def iob2(tags:List[str])->List[str]:
"""


+ 1
- 1
reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py View File

@@ -51,7 +51,7 @@ class ChineseNERLoader(DataSetLoader):
:param paths:
:param bool, bigrams: 是否包含生成bigram feature, [a, b, c, d] -> [ab, bc, cd, d<eos>]
:param bool, trigrams: 是否包含trigram feature,[a, b, c, d] -> [abc, bcd, cd<eos>, d<eos><eos>]
:return: DataBundle
:return: ~fastNLP.io.DataBundle
包含以下的fields
raw_chars: List[str]
chars: List[int]


Loading…
Cancel
Save