@@ -10,8 +10,72 @@ The core module implements the core framework of fastNLP; commonly used functionality can all be imported from fa
 For commonly used functionality, you only need to look it up in :doc:`fastNLP`. If you want to know what each submodule does, you can find the detailed documentation of every submodule below.
 """
+__all__ = [
+    "DataSet",
+    "Instance",
+    "FieldArray",
+    "Padder",
+    "AutoPadder",
+    "EngChar2DPadder",
+    "Vocabulary",
+    "DataSetIter",
+    "BatchIter",
+    "TorchLoaderIter",
+    "Const",
+    "Tester",
+    "Trainer",
+    "cache_results",
+    "seq_len_to_mask",
+    "get_seq_len",
+    "logger",
+    "Callback",
+    "GradientClipCallback",
+    "EarlyStopCallback",
+    "FitlogCallback",
+    "EvaluateCallback",
+    "LRScheduler",
+    "ControlC",
+    "LRFinder",
+    "TensorboardCallback",
+    "WarmupCallback",
+    'SaveModelCallback',
+    "EchoCallback",
+    "TesterCallback",
+    "CallbackException",
+    "EarlyStopError",
+    "LossFunc",
+    "CrossEntropyLoss",
+    "L1Loss",
+    "BCELoss",
+    "NLLLoss",
+    "LossInForward",
+    "AccuracyMetric",
+    "SpanFPreRecMetric",
+    "ExtractiveQAMetric",
+    "Optimizer",
+    "SGD",
+    "Adam",
+    "AdamW",
+    "SequentialSampler",
+    "BucketSampler",
+    "RandomSampler",
+    "Sampler",
+]
+from ._logger import logger
 from .batch import DataSetIter, BatchIter, TorchLoaderIter
 from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
     LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
@@ -28,4 +92,3 @@ from .tester import Tester
 from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
-from ._logger import logger
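
The explicit ``__all__`` added above pins down exactly which names ``from fastNLP.core import *`` re-exports. A minimal sketch of the mechanism, with a hypothetical module and names::

    # pkg/mod.py -- hypothetical module illustrating __all__
    __all__ = ["public_fn"]        # a star-import binds only the names listed here

    def public_fn():
        return "exported"

    def internal_fn():             # omitted from __all__: `from pkg.mod import *` skips
        return "not exported"      # it, but `from pkg.mod import internal_fn` still works

Without ``__all__``, a star-import would pull in every non-underscore top-level name, including the imported submodule objects themselves.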
@@ -1,15 +1,15 @@
+"""undocumented"""
+__all__ = [
+    'logger',
+]
 import logging
 import logging.config
-import torch
-import _pickle as pickle
 import os
 import sys
 import warnings
-__all__ = [
-    'logger',
-]
 ROOT_NAME = 'fastNLP'
 try:
@@ -25,7 +25,7 @@ if tqdm is not None:
     class TqdmLoggingHandler(logging.Handler):
         def __init__(self, level=logging.INFO):
             super().__init__(level)
         def emit(self, record):
             try:
                 msg = self.format(record)
@@ -59,14 +59,14 @@ def _add_file_handler(logger, path, level='INFO'):
             if os.path.abspath(path) == h.baseFilename:
                 # file path already added
                 return
     # File Handler
     if os.path.exists(path):
         assert os.path.isfile(path)
         warnings.warn('log already exists in {}'.format(path))
     dirname = os.path.abspath(os.path.dirname(path))
     os.makedirs(dirname, exist_ok=True)
     file_handler = logging.FileHandler(path, mode='a')
     file_handler.setLevel(_get_level(level))
     file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
@@ -87,7 +87,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
             break
     if stream_handler is not None:
         logger.removeHandler(stream_handler)
     # Stream Handler
     if stdout == 'plain':
         stream_handler = logging.StreamHandler(sys.stdout)
@@ -95,7 +95,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
         stream_handler = TqdmLoggingHandler(level)
     else:
         stream_handler = None
     if stream_handler is not None:
         stream_formatter = logging.Formatter('%(message)s')
         stream_handler.setLevel(level)
@@ -103,38 +103,40 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
         logger.addHandler(stream_handler)
 class FastNLPLogger(logging.getLoggerClass()):
     def __init__(self, name):
         super().__init__(name)
     def add_file(self, path='./log.txt', level='INFO'):
         """add log output file and level"""
         _add_file_handler(self, path, level)
     def set_stdout(self, stdout='tqdm', level='INFO'):
         """set stdout format and level"""
         _set_stdout_handler(self, stdout, level)
 logging.setLoggerClass(FastNLPLogger)
 # print(logging.getLoggerClass())
 # print(logging.getLogger())
 def _init_logger(path=None, stdout='tqdm', level='INFO'):
     """initialize logger"""
     level = _get_level(level)
     # logger = logging.getLogger()
     logger = logging.getLogger(ROOT_NAME)
     logger.propagate = False
     logger.setLevel(level)
     _set_stdout_handler(logger, stdout, level)
     # File Handler
     if path is not None:
         _add_file_handler(logger, path, level)
     return logger
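
``FastNLPLogger`` extends the standard logger class with ``add_file`` and ``set_stdout``, and ``_init_logger`` registers a singleton under the ``fastNLP`` root name with propagation disabled. A hedged usage sketch (the import follows the ``__all__`` of ``fastNLP.core`` above; the log path is illustrative)::

    from fastNLP.core import logger

    logger.info("training started")          # routed through the tqdm-aware handler
    logger.add_file(path="./train.log",      # attach a FileHandler; _add_file_handler
                    level="INFO")            # skips paths that are already attached
    logger.set_stdout(stdout="plain",        # swap TqdmLoggingHandler for a plain
                      level="DEBUG")         # StreamHandler on sys.stdout

Because ``logging.setLoggerClass(FastNLPLogger)`` runs at import time, ``logging.getLogger(ROOT_NAME)`` returns an instance that already carries the two extra methods.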
@@ -1,11 +1,14 @@
+"""undocumented"""
+__all__ = []
 import threading
 import torch
 from torch import nn
 from torch.nn.parallel.parallel_apply import get_a_var
-from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
 from torch.nn.parallel.replicate import replicate
+from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
 def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
@@ -27,11 +30,11 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
         assert len(modules) == len(devices)
     else:
         devices = [None] * len(modules)
     lock = threading.Lock()
     results = {}
     grad_enabled = torch.is_grad_enabled()
     def _worker(i, module, input, kwargs, device=None):
         torch.set_grad_enabled(grad_enabled)
         if device is None:
@@ -47,20 +50,20 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
         except Exception as e:
             with lock:
                 results[i] = e
     if len(modules) > 1:
         threads = [threading.Thread(target=_worker,
                                     args=(i, module, input, kwargs, device))
                    for i, (module, input, kwargs, device) in
                    enumerate(zip(modules, inputs, kwargs_tup, devices))]
         for thread in threads:
             thread.start()
         for thread in threads:
             thread.join()
     else:
         _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
     outputs = []
     for i in range(len(inputs)):
         output = results[i]
@@ -79,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
     :param output_device: the output_device of nn.DataParallel
     :return:
     """
     def wrapper(network, *inputs, **kwargs):
         inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
         if len(device_ids) == 1:
@@ -86,6 +90,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
         replicas = replicate(network, device_ids[:len(inputs)])
         outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
         return gather(outputs, output_device)
     return wrapper
@@ -99,4 +104,4 @@ def _model_contains_inner_module(model):
     if isinstance(model, nn.Module):
         if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
             return True
-    return False
+    return False
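
``parallel_apply`` generalizes ``torch.nn.parallel.parallel_apply`` by invoking an arbitrary method named by ``func_name`` instead of ``forward``, one thread per replica, and ``_data_parallel_wrapper`` combines it with scatter/replicate/gather. A minimal CPU-only sketch; the toy module and the import path are assumptions, and running on CPU assumes a PyTorch build where ``Tensor.get_device()`` returns -1 for CPU tensors (otherwise pass explicit ``devices``)::

    import torch
    from torch import nn
    from fastNLP.core._parallel_utils import parallel_apply  # assumed module path

    class Toy(nn.Module):
        def predict(self, x):                  # a non-forward entry point
            return {"pred": x * 2}

    modules = [Toy(), Toy()]                   # stand-ins for replicas of one network
    inputs = [(torch.ones(2),), (torch.zeros(2),)]
    outputs = parallel_apply(modules, "predict", inputs)  # kwargs_tup/devices left as None
    # -> [{'pred': tensor([2., 2.])}, {'pred': tensor([0., 0.])}]

Exceptions raised inside a worker are stored in ``results`` and re-raised when the outputs are collected (mirroring torch's implementation), so a failure on one replica is not silently dropped.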
@@ -1,3 +1,13 @@
+"""
+.. todo::
+    doc
+"""
+__all__ = [
+    "Const"
+]
 class Const:
     """
     Naming constants for fields in fastNLP.
@@ -25,47 +35,47 @@ class Const:
     LOSS = 'loss'
     RAW_WORD = 'raw_words'
     RAW_CHAR = 'raw_chars'
     @staticmethod
     def INPUTS(i):
         """Get the name of the i-th ``INPUT``."""
         i = int(i) + 1
         return Const.INPUT + str(i)
     @staticmethod
     def CHAR_INPUTS(i):
         """Get the name of the i-th ``CHAR_INPUT``."""
         i = int(i) + 1
         return Const.CHAR_INPUT + str(i)
     @staticmethod
     def RAW_WORDS(i):
         i = int(i) + 1
         return Const.RAW_WORD + str(i)
     @staticmethod
     def RAW_CHARS(i):
         i = int(i) + 1
         return Const.RAW_CHAR + str(i)
     @staticmethod
     def INPUT_LENS(i):
         """Get the name of the i-th ``INPUT_LEN``."""
         i = int(i) + 1
         return Const.INPUT_LEN + str(i)
     @staticmethod
     def OUTPUTS(i):
         """Get the name of the i-th ``OUTPUT``."""
         i = int(i) + 1
         return Const.OUTPUT + str(i)
     @staticmethod
     def TARGETS(i):
         """Get the name of the i-th ``TARGET``."""
         i = int(i) + 1
         return Const.TARGET + str(i)
     @staticmethod
     def LOSSES(i):
         """Get the name of the i-th ``LOSS``."""
@@ -1,29 +1,29 @@
-"""
+"""undocumented
 Distributed training code, still under development
 """
+import logging
+import os
+import time
+from datetime import datetime
 import torch
 import torch.cuda
-import torch.optim
 import torch.distributed as dist
-from torch.utils.data.distributed import DistributedSampler
+import torch.optim
+from pkg_resources import parse_version
 from torch.nn.parallel import DistributedDataParallel as DDP
-import os
+from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm
-import time
-from datetime import datetime, timedelta
-from functools import partial
+from ._logger import logger
 from .batch import DataSetIter, BatchIter
 from .callback import DistCallbackManager, CallbackException, TesterCallback
 from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
 from .utils import _build_args
-from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
-from ._logger import logger
-import logging
-from pkg_resources import parse_version
+from .utils import _move_dict_value_to_device
 __all__ = [
     'get_local_rank',
@@ -1,18 +1,25 @@
+"""
+.. todo::
+    doc
+"""
 __all__ = [
     "Padder",
     "AutoPadder",
     "EngChar2DPadder",
 ]
-from numbers import Number
-import torch
-import numpy as np
-from typing import Any
 from abc import abstractmethod
-from copy import deepcopy
 from collections import Counter
-from .utils import _is_iterable
+from copy import deepcopy
+from numbers import Number
+from typing import Any
+import numpy as np
+import torch
 from ._logger import logger
+from .utils import _is_iterable
 class SetInputOrTargetException(Exception):
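
This hunk only adds a module docstring and regroups the imports around the ``Padder`` classes. For orientation, a sketch of the padding idea an auto-padder implements; this illustrates the concept only and is not fastNLP's exact ``Padder`` signature::

    import numpy as np

    def pad_batch(contents, pad_val=0):
        """Right-pad variable-length rows to the longest row in the batch."""
        max_len = max(len(row) for row in contents)
        out = np.full((len(contents), max_len), pad_val, dtype=np.int64)
        for i, row in enumerate(contents):
            out[i, :len(row)] = row
        return out

    pad_batch([[1, 2, 3], [4]])   # -> [[1, 2, 3], [4, 0, 0]]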
@@ -1,13 +1,15 @@
-"""
-..todo::
-    check whether this class is still needed
-"""
+"""undocumented"""
+__all__ = [
+    "Predictor"
+]
 from collections import defaultdict
 import torch
-from . import DataSetIter
 from . import DataSet
+from . import DataSetIter
 from . import SequentialSampler
 from .utils import _build_args, _move_dict_value_to_device, _get_model_device
@@ -21,7 +23,7 @@ class Predictor(object):
     :param torch.nn.Module network: the model used to perform the prediction task
     """
     def __init__(self, network):
         if not isinstance(network, torch.nn.Module):
             raise ValueError(
@@ -29,7 +31,7 @@ class Predictor(object):
         self.network = network
         self.batch_size = 1
         self.batch_output = []
     def predict(self, data: DataSet, seq_len_field_name=None):
         """Run inference with an already trained model.
@@ -41,27 +43,27 @@ class Predictor(object):
             raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
         if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
             raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
         prev_training = self.network.training
         self.network.eval()
         network_device = _get_model_device(self.network)
         batch_output = defaultdict(list)
         data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
         if hasattr(self.network, "predict"):
             predict_func = self.network.predict
         else:
             predict_func = self.network.forward
         with torch.no_grad():
             for batch_x, _ in data_iterator:
                 _move_dict_value_to_device(batch_x, _, device=network_device)
                 refined_batch_x = _build_args(predict_func, **batch_x)
                 prediction = predict_func(**refined_batch_x)
                 if seq_len_field_name is not None:
                     seq_lens = batch_x[seq_len_field_name].tolist()
                 for key, value in prediction.items():
                     value = value.cpu().numpy()
                     if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
@@ -74,6 +76,6 @@ class Predictor(object):
                             batch_output[key].extend(tmp_batch)
                         else:
                             batch_output[key].append(value)
         self.network.train(prev_training)
         return batch_output
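
Reading ``predict`` above: it iterates the DataSet sequentially with batch size 1, prefers a ``predict`` method over ``forward``, filters the batch dict down to the callable's signature via ``_build_args``, and accumulates each output key into a dict of lists. A hedged end-to-end sketch (the toy model and field name are hypothetical, and the module path of ``Predictor`` is assumed)::

    import torch
    from fastNLP.core import DataSet
    from fastNLP.core.predictor import Predictor    # assumed module path

    class ToyModel(torch.nn.Module):
        def predict(self, words):                       # matched to the 'words' field
            return {"pred": words.float().mean(dim=1)}  # by _build_args

    data = DataSet({"words": [[1, 2], [3, 4]]})
    data.set_input("words")                             # only input fields reach predict()

    result = Predictor(ToyModel()).predict(data)
    # result["pred"] -> [1.5, 3.5], one entry per instance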
@@ -1,16 +1,22 @@
+"""
+.. todo::
+    doc
+"""
 __all__ = [
     "Vocabulary",
     "VocabularyOption",
 ]
-from functools import wraps
 from collections import Counter
+from functools import partial
+from functools import wraps
+from ._logger import logger
 from .dataset import DataSet
 from .utils import Option
-from functools import partial
-import numpy as np
 from .utils import _is_iterable
-from ._logger import logger
 class VocabularyOption(Option):
     def __init__(self,
@@ -51,7 +57,7 @@ def _check_build_status(func):
             self.rebuild = True
             if self.max_size is not None and len(self.word_count) >= self.max_size:
                 logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. "
-                    "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
+                            "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
                     self.max_size, func.__name__))
         return func(self, *args, **kwargs)
@@ -199,7 +205,7 @@ class Vocabulary(object):
         self.build_reverse_vocab()
         self.rebuild = False
         return self
     def build_reverse_vocab(self):
         """
         Build the `index to word` dict from the `word to index` dict.
@@ -279,19 +285,19 @@ class Vocabulary(object):
             if not isinstance(field[0][0], str) and _is_iterable(field[0][0]):
                 raise RuntimeError("Only support field with 2 dimensions.")
             return [[self.to_index(c) for c in w] for w in field]
         new_field_name = new_field_name or field_name
         if type(new_field_name) == type(field_name):
             if isinstance(new_field_name, list):
                 assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \
-                    "field_name."
+                                                               "field_name."
             elif isinstance(new_field_name, str):
                 field_name = [field_name]
                 new_field_name = [new_field_name]
             else:
                 raise TypeError("field_name and new_field_name can only be str or List[str].")
         for idx, dataset in enumerate(datasets):
             if isinstance(dataset, DataSet):
                 try:
@@ -377,7 +383,7 @@ class Vocabulary(object):
         :return: bool
         """
         return word in self._no_create_word
     def to_index(self, w):
         """
         Convert a word to its index. If the word is not recorded in the vocabulary, it is treated as unknown; if ``unknown=None``, a ``ValueError`` is raised::
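
A short sketch of the lookup behavior the ``to_index`` docstring describes, under the assumption of the standard ``Vocabulary`` API (``add_word_lst`` and ``build_vocab``, with an unknown token enabled by default)::

    from fastNLP.core import Vocabulary

    vocab = Vocabulary()                        # assumes default unknown/padding tokens
    vocab.add_word_lst(["apple", "banana", "apple"])
    vocab.build_vocab()                         # freezes word2idx; adding more words
                                                # later flips the rebuild flag
    vocab.to_index("apple")                     # index of 'apple'
    vocab.to_index("durian")                    # unseen word -> index of the unknown token
    # with Vocabulary(unknown=None), the same lookup raises ValueError instead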
@@ -8,15 +8,17 @@ __all__ = [
 ]
 from abc import abstractmethod
 import torch
-from ..core.vocabulary import Vocabulary
-from ..core.dataset import DataSet
-from .embedding import TokenEmbedding
-from ..core import logger
 from ..core.batch import DataSetIter
+from ..core.dataset import DataSet
 from ..core.sampler import SequentialSampler
 from ..core.utils import _move_model_to_device, _get_model_device
+from .embedding import TokenEmbedding
+from ..core import logger
+from ..core.vocabulary import Vocabulary
 class ContextualEmbedding(TokenEmbedding):
     def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):