@@ -1,3 +1,6 @@ | |||||
""" | |||||
正在开发中的分布式训练代码 | |||||
""" | |||||
import torch | import torch | ||||
import torch.cuda | import torch.cuda | ||||
import torch.optim | import torch.optim | ||||
@@ -41,7 +44,8 @@ def get_local_rank(): | |||||
class DistTrainer(): | class DistTrainer(): | ||||
"""Distributed Trainer that support distributed and mixed precision training | |||||
""" | |||||
Distributed Trainer that support distributed and mixed precision training | |||||
""" | """ | ||||
def __init__(self, train_data, model, optimizer=None, loss=None, | def __init__(self, train_data, model, optimizer=None, loss=None, | ||||
callbacks_all=None, callbacks_master=None, | callbacks_all=None, callbacks_master=None, | ||||
@@ -1,4 +1,3 @@ | |||||
from abc import abstractmethod | from abc import abstractmethod | ||||
import torch | import torch | ||||
@@ -9,6 +8,10 @@ from ..core.sampler import SequentialSampler | |||||
from ..core.utils import _move_model_to_device, _get_model_device | from ..core.utils import _move_model_to_device, _get_model_device | ||||
from .embedding import TokenEmbedding | from .embedding import TokenEmbedding | ||||
__all__ = [ | |||||
"ContextualEmbedding" | |||||
] | |||||
class ContextualEmbedding(TokenEmbedding): | class ContextualEmbedding(TokenEmbedding): | ||||
def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0): | def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0): | ||||
@@ -1,7 +1,9 @@ | |||||
""" | """ | ||||
用于读入和处理和保存 config 文件 | 用于读入和处理和保存 config 文件 | ||||
.. todo:: | |||||
.. todo:: | |||||
这个模块中的类可能被抛弃? | 这个模块中的类可能被抛弃? | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"ConfigLoader", | "ConfigLoader", | ||||
@@ -1,12 +1,12 @@ | |||||
from typing import Dict, Union | from typing import Dict, Union | ||||
from .loader import Loader | from .loader import Loader | ||||
from ... import DataSet | |||||
from ...core.dataset import DataSet | |||||
from ..file_reader import _read_conll | from ..file_reader import _read_conll | ||||
from ... import Instance | |||||
from ...core.instance import Instance | |||||
from .. import DataBundle | from .. import DataBundle | ||||
from ..utils import check_loader_paths | from ..utils import check_loader_paths | ||||
from ... import Const | |||||
from ...core.const import Const | |||||
class ConllLoader(Loader): | class ConllLoader(Loader): | ||||
@@ -1,6 +1,6 @@ | |||||
from .loader import Loader | from .loader import Loader | ||||
from ...core import DataSet, Instance | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
class CWSLoader(Loader): | class CWSLoader(Loader): | ||||
@@ -1,4 +1,4 @@ | |||||
from ... import DataSet | |||||
from ...core.dataset import DataSet | |||||
from .. import DataBundle | from .. import DataBundle | ||||
from ..utils import check_loader_paths | from ..utils import check_loader_paths | ||||
from typing import Union, Dict | from typing import Union, Dict | ||||
@@ -1,12 +1,12 @@ | |||||
import warnings | import warnings | ||||
from .loader import Loader | from .loader import Loader | ||||
from .json import JsonLoader | from .json import JsonLoader | ||||
from ...core import Const | |||||
from ...core.const import Const | |||||
from .. import DataBundle | from .. import DataBundle | ||||
import os | import os | ||||
from typing import Union, Dict | from typing import Union, Dict | ||||
from ...core import DataSet | |||||
from ...core import Instance | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
class MNLILoader(Loader): | class MNLILoader(Loader): | ||||
@@ -4,13 +4,14 @@ from ..base_loader import DataBundle | |||||
from ...core.vocabulary import Vocabulary | from ...core.vocabulary import Vocabulary | ||||
from ...core.const import Const | from ...core.const import Const | ||||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | ||||
from ...core import DataSet, Instance | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance | from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance | ||||
from .pipe import Pipe | from .pipe import Pipe | ||||
import re | import re | ||||
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | ||||
from ...core import cache_results | |||||
from ...core.utils import cache_results | |||||
class _CLSPipe(Pipe): | class _CLSPipe(Pipe): | ||||
""" | """ | ||||
@@ -1,7 +1,7 @@ | |||||
from .pipe import Pipe | from .pipe import Pipe | ||||
from .. import DataBundle | from .. import DataBundle | ||||
from .utils import iob2, iob2bioes | from .utils import iob2, iob2bioes | ||||
from ... import Const | |||||
from ...core.const import Const | |||||
from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader | from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader | ||||
from .utils import _indexize, _add_words_field | from .utils import _indexize, _add_words_field | ||||
@@ -2,8 +2,8 @@ import math | |||||
from .pipe import Pipe | from .pipe import Pipe | ||||
from .utils import get_tokenizer | from .utils import get_tokenizer | ||||
from ...core import Const | |||||
from ...core import Vocabulary | |||||
from ...core.const import Const | |||||
from ...core.vocabulary import Vocabulary | |||||
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader | from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader | ||||
@@ -1,6 +1,6 @@ | |||||
from typing import List | from typing import List | ||||
from ...core import Vocabulary | |||||
from ...core import Const | |||||
from ...core.vocabulary import Vocabulary | |||||
from ...core.const import Const | |||||
def iob2(tags:List[str])->List[str]: | def iob2(tags:List[str])->List[str]: | ||||
""" | """ | ||||
@@ -51,7 +51,7 @@ class ChineseNERLoader(DataSetLoader): | |||||
:param paths: | :param paths: | ||||
:param bool, bigrams: 是否包含生成bigram feature, [a, b, c, d] -> [ab, bc, cd, d<eos>] | :param bool, bigrams: 是否包含生成bigram feature, [a, b, c, d] -> [ab, bc, cd, d<eos>] | ||||
:param bool, trigrams: 是否包含trigram feature,[a, b, c, d] -> [abc, bcd, cd<eos>, d<eos><eos>] | :param bool, trigrams: 是否包含trigram feature,[a, b, c, d] -> [abc, bcd, cd<eos>, d<eos><eos>] | ||||
:return: DataBundle | |||||
:return: ~fastNLP.io.DataBundle | |||||
包含以下的fields | 包含以下的fields | ||||
raw_chars: List[str] | raw_chars: List[str] | ||||
chars: List[int] | chars: List[int] | ||||