
add __all__ and __doc__ for all files in module 'core', using 'undocumented' tags

tags/v0.4.10
ChenXin 5 years ago
parent commit efe88263bb
9 changed files with 178 additions and 81 deletions
  1. fastNLP/core/__init__.py (+65 -2)
  2. fastNLP/core/_logger.py (+20 -18)
  3. fastNLP/core/_parallel_utils.py (+13 -8)
  4. fastNLP/core/const.py (+18 -8)
  5. fastNLP/core/dist_trainer.py (+11 -11)
  6. fastNLP/core/field.py (+13 -6)
  7. fastNLP/core/predictor.py (+15 -13)
  8. fastNLP/core/vocabulary.py (+17 -11)
  9. fastNLP/embeddings/contextual_embedding.py (+6 -4)

fastNLP/core/__init__.py (+65 -2)

@@ -10,8 +10,72 @@ The core module implements fastNLP's core framework; commonly used functionality can be accessed from fa


For commonly used functionality, you only need to look in :doc:`fastNLP`. To learn what each submodule does in detail, you can find each submodule's documentation below.


""" """
__all__ = [
    "DataSet",
    "Instance",
    "FieldArray",
    "Padder",
    "AutoPadder",
    "EngChar2DPadder",
    "Vocabulary",
    "DataSetIter",
    "BatchIter",
    "TorchLoaderIter",
    "Const",
    "Tester",
    "Trainer",
    "cache_results",
    "seq_len_to_mask",
    "get_seq_len",
    "logger",
    "Callback",
    "GradientClipCallback",
    "EarlyStopCallback",
    "FitlogCallback",
    "EvaluateCallback",
    "LRScheduler",
    "ControlC",
    "LRFinder",
    "TensorboardCallback",
    "WarmupCallback",
    "SaveModelCallback",
    "EchoCallback",
    "TesterCallback",
    "CallbackException",
    "EarlyStopError",
    "LossFunc",
    "CrossEntropyLoss",
    "L1Loss",
    "BCELoss",
    "NLLLoss",
    "LossInForward",
    "AccuracyMetric",
    "SpanFPreRecMetric",
    "ExtractiveQAMetric",
    "Optimizer",
    "SGD",
    "Adam",
    "AdamW",
    "SequentialSampler",
    "BucketSampler",
    "RandomSampler",
    "Sampler",
]

from ._logger import logger
from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
    LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
@@ -28,4 +92,3 @@ from .tester import Tester
from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary
from ._logger import logger
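
With the consolidated __all__ above, every public name in core is importable directly from fastNLP.core. A minimal sketch of what that enables (assuming fastNLP v0.4.10 is installed):

```python
# Minimal sketch, assuming fastNLP v0.4.10 is installed.
# Everything listed in __all__ above is importable from fastNLP.core directly.
from fastNLP.core import DataSet, Instance, logger

data = DataSet()
data.append(Instance(raw_words="this is a demo".split()))
logger.info("dataset size: %d", len(data))
```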

fastNLP/core/_logger.py (+20 -18)

@@ -1,15 +1,15 @@
"""undocumented"""

__all__ = [
    'logger',
]

import logging
import logging.config
import torch
import _pickle as pickle
import os
import sys
import warnings


__all__ = [
    'logger',
]

ROOT_NAME = 'fastNLP'


try:
@@ -25,7 +25,7 @@ if tqdm is not None:
    class TqdmLoggingHandler(logging.Handler):
        def __init__(self, level=logging.INFO):
            super().__init__(level)

        def emit(self, record):
            try:
                msg = self.format(record)
@@ -59,14 +59,14 @@ def _add_file_handler(logger, path, level='INFO'):
            if os.path.abspath(path) == h.baseFilename:
                # file path already added
                return

    # File Handler
    if os.path.exists(path):
        assert os.path.isfile(path)
        warnings.warn('log already exists in {}'.format(path))
    dirname = os.path.abspath(os.path.dirname(path))
    os.makedirs(dirname, exist_ok=True)

    file_handler = logging.FileHandler(path, mode='a')
    file_handler.setLevel(_get_level(level))
    file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
@@ -87,7 +87,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
            break
    if stream_handler is not None:
        logger.removeHandler(stream_handler)

    # Stream Handler
    if stdout == 'plain':
        stream_handler = logging.StreamHandler(sys.stdout)
@@ -95,7 +95,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
        stream_handler = TqdmLoggingHandler(level)
    else:
        stream_handler = None

    if stream_handler is not None:
        stream_formatter = logging.Formatter('%(message)s')
        stream_handler.setLevel(level)
@@ -103,38 +103,40 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
        logger.addHandler(stream_handler)


class FastNLPLogger(logging.getLoggerClass()):
    def __init__(self, name):
        super().__init__(name)

    def add_file(self, path='./log.txt', level='INFO'):
        """add log output file and level"""
        _add_file_handler(self, path, level)

    def set_stdout(self, stdout='tqdm', level='INFO'):
        """set stdout format and level"""
        _set_stdout_handler(self, stdout, level)



logging.setLoggerClass(FastNLPLogger)

# print(logging.getLoggerClass())
# print(logging.getLogger())


def _init_logger(path=None, stdout='tqdm', level='INFO'):
    """initialize logger"""
    level = _get_level(level)

    # logger = logging.getLogger()
    logger = logging.getLogger(ROOT_NAME)
    logger.propagate = False
    logger.setLevel(level)

    _set_stdout_handler(logger, stdout, level)

    # File Handler
    if path is not None:
        _add_file_handler(logger, path, level)

    return logger
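
The FastNLPLogger subclass added here exposes add_file and set_stdout on the shared logger instance. A short usage sketch based on the signatures in this diff (the log path is just an example):

```python
from fastNLP.core import logger

# append INFO-level records to a file in addition to the console
logger.add_file(path='./log.txt', level='INFO')
# route console output through tqdm so progress bars are not broken up
logger.set_stdout(stdout='tqdm', level='INFO')
logger.info("logging configured")
```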






fastNLP/core/_parallel_utils.py (+13 -8)

@@ -1,11 +1,14 @@
"""undocumented"""

__all__ = []


import threading

import torch
from torch import nn
from torch.nn.parallel.parallel_apply import get_a_var

from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather




def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
@@ -27,11 +30,11 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
@@ -47,20 +50,20 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
@@ -79,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
    :param output_device: the output_device argument of nn.DataParallel
    :return:
    """

    def wrapper(network, *inputs, **kwargs):
        inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
        if len(device_ids) == 1:
@@ -86,6 +90,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
        replicas = replicate(network, device_ids[:len(inputs)])
        outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
        return gather(outputs, output_device)

    return wrapper




@@ -99,4 +104,4 @@ def _model_contains_inner_module(model):
    if isinstance(model, nn.Module):
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            return True
    return False
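
_model_contains_inner_module answers whether the real network is hidden behind a parallel wrapper's .module attribute. A small illustration (the function is private; shown here only to clarify the check):

```python
import torch
from torch import nn

from fastNLP.core._parallel_utils import _model_contains_inner_module

net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)

print(_model_contains_inner_module(net))      # False: a plain module
print(_model_contains_inner_module(wrapped))  # True: the real network is wrapped.module
```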

fastNLP/core/const.py (+18 -8)

@@ -1,3 +1,13 @@
"""
.. todo::
doc
"""

__all__ = [
    "Const"
]


class Const:
    """
    Field-naming constants used in fastNLP.
@@ -25,47 +35,47 @@ class Const:
    LOSS = 'loss'
    RAW_WORD = 'raw_words'
    RAW_CHAR = 'raw_chars'
    @staticmethod
    def INPUTS(i):
        """Get the name of the i-th ``INPUT``"""
        i = int(i) + 1
        return Const.INPUT + str(i)

    @staticmethod
    def CHAR_INPUTS(i):
        """Get the name of the i-th ``CHAR_INPUT``"""
        i = int(i) + 1
        return Const.CHAR_INPUT + str(i)

    @staticmethod
    def RAW_WORDS(i):
        i = int(i) + 1
        return Const.RAW_WORD + str(i)

    @staticmethod
    def RAW_CHARS(i):
        i = int(i) + 1
        return Const.RAW_CHAR + str(i)

    @staticmethod
    def INPUT_LENS(i):
        """Get the name of the i-th ``INPUT_LEN``"""
        i = int(i) + 1
        return Const.INPUT_LEN + str(i)

    @staticmethod
    def OUTPUTS(i):
        """Get the name of the i-th ``OUTPUT``"""
        i = int(i) + 1
        return Const.OUTPUT + str(i)

    @staticmethod
    def TARGETS(i):
        """Get the name of the i-th ``TARGET``"""
        i = int(i) + 1
        return Const.TARGET + str(i)

    @staticmethod
    def LOSSES(i):
        """Get the name of the i-th ``LOSS``"""


fastNLP/core/dist_trainer.py (+11 -11)

@@ -1,29 +1,29 @@
"""
"""undocumented
正在开发中的分布式训练代码 正在开发中的分布式训练代码
""" """
import logging
import os
import time
from datetime import datetime

import torch
import torch.cuda
import torch.optim
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
import torch.optim
from pkg_resources import parse_version
from torch.nn.parallel import DistributedDataParallel as DDP
import os
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
import time
from datetime import datetime, timedelta
from functools import partial

from ._logger import logger
from .batch import DataSetIter, BatchIter
from .callback import DistCallbackManager, CallbackException, TesterCallback
from .dataset import DataSet
from .losses import _prepare_losser
from .optimizer import Optimizer
from .utils import _build_args
from .utils import _move_dict_value_to_device
from .utils import _get_func_signature
from ._logger import logger
import logging
from pkg_resources import parse_version
from .utils import _move_dict_value_to_device


__all__ = [
    'get_local_rank',


fastNLP/core/field.py (+13 -6)

@@ -1,18 +1,25 @@
"""
.. todo::
doc
"""

__all__ = [
    "Padder",
    "AutoPadder",
    "EngChar2DPadder",
]


from numbers import Number
import torch
import numpy as np
from typing import Any
from abc import abstractmethod
from copy import deepcopy
from collections import Counter
from .utils import _is_iterable
from copy import deepcopy
from numbers import Number
from typing import Any

import numpy as np
import torch

from ._logger import logger
from .utils import _is_iterable




class SetInputOrTargetException(Exception):
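
field.py now re-exports the padder hierarchy through __all__. A hedged sketch of attaching AutoPadder to a DataSet field (the pad_val constructor argument is the usual fastNLP parameter, assumed here rather than shown in this diff):

```python
# Hedged sketch: pad_val=0 is assumed, not shown in this diff.
from fastNLP.core.field import AutoPadder

padder = AutoPadder(pad_val=0)
# typically attached to a DataSet field so ragged lists are padded per batch:
# dataset.set_padder('words', padder)
```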


fastNLP/core/predictor.py (+15 -13)

@@ -1,13 +1,15 @@
"""
..todo::
    Check whether this class is still needed
"""
"""undocumented"""

__all__ = [
"Predictor"
]

from collections import defaultdict

import torch


from . import DataSetIter
from . import DataSet
from . import DataSetIter
from . import SequentialSampler
from .utils import _build_args, _move_dict_value_to_device, _get_model_device


@@ -21,7 +23,7 @@ class Predictor(object):


    :param torch.nn.Module network: the model used to perform prediction
    """

    def __init__(self, network):
        if not isinstance(network, torch.nn.Module):
            raise ValueError(
@@ -29,7 +31,7 @@ class Predictor(object):
        self.network = network
        self.batch_size = 1
        self.batch_output = []

    def predict(self, data: DataSet, seq_len_field_name=None):
        """Run inference with an already trained model.


@@ -41,27 +43,27 @@ class Predictor(object):
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
prev_training = self.network.training prev_training = self.network.training
self.network.eval() self.network.eval()
network_device = _get_model_device(self.network) network_device = _get_model_device(self.network)
batch_output = defaultdict(list) batch_output = defaultdict(list)
data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
if hasattr(self.network, "predict"): if hasattr(self.network, "predict"):
predict_func = self.network.predict predict_func = self.network.predict
else: else:
predict_func = self.network.forward predict_func = self.network.forward
with torch.no_grad(): with torch.no_grad():
for batch_x, _ in data_iterator: for batch_x, _ in data_iterator:
_move_dict_value_to_device(batch_x, _, device=network_device) _move_dict_value_to_device(batch_x, _, device=network_device)
refined_batch_x = _build_args(predict_func, **batch_x) refined_batch_x = _build_args(predict_func, **batch_x)
prediction = predict_func(**refined_batch_x) prediction = predict_func(**refined_batch_x)
if seq_len_field_name is not None: if seq_len_field_name is not None:
seq_lens = batch_x[seq_len_field_name].tolist() seq_lens = batch_x[seq_len_field_name].tolist()
for key, value in prediction.items(): for key, value in prediction.items():
value = value.cpu().numpy() value = value.cpu().numpy()
if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1): if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
@@ -74,6 +76,6 @@ class Predictor(object):
                    batch_output[key].extend(tmp_batch)
                else:
                    batch_output[key].append(value)

        self.network.train(prev_training)
        return batch_output
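
Predictor runs the model's predict (or forward) under torch.no_grad(), batching the DataSet sequentially. A self-contained toy sketch (model and data are placeholders, not part of this commit):

```python
import torch
from fastNLP.core import DataSet, Instance
from fastNLP.core.predictor import Predictor

class ToyModel(torch.nn.Module):
    # forward must return a dict, since Predictor iterates prediction.items()
    def forward(self, words):
        return {'pred': words.float().sum(dim=1)}

data = DataSet()
data.append(Instance(words=[1, 2, 3]))
data.set_input('words')  # only input fields reach the model

outputs = Predictor(ToyModel()).predict(data)
print(outputs['pred'])   # one entry per example
```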

fastNLP/core/vocabulary.py (+17 -11)

@@ -1,16 +1,22 @@
"""
.. todo::
doc
"""

__all__ = [
    "Vocabulary",
    "VocabularyOption",
]


from functools import wraps
from collections import Counter
from functools import partial
from functools import wraps

from ._logger import logger
from .dataset import DataSet
from .utils import Option
from functools import partial
import numpy as np
from .utils import _is_iterable
from ._logger import logger



class VocabularyOption(Option):
    def __init__(self,
@@ -51,7 +57,7 @@ def _check_build_status(func):
            self.rebuild = True
            if self.max_size is not None and len(self.word_count) >= self.max_size:
                logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. "
                            "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
                    self.max_size, func.__name__))
        return func(self, *args, **kwargs)
@@ -199,7 +205,7 @@ class Vocabulary(object):
        self.build_reverse_vocab()
        self.rebuild = False
        return self

    def build_reverse_vocab(self):
        """
        Build the `index to word` dict based on the `word to index` dict.
@@ -279,19 +285,19 @@ class Vocabulary(object):
                if not isinstance(field[0][0], str) and _is_iterable(field[0][0]):
                    raise RuntimeError("Only support field with 2 dimensions.")
                return [[self.to_index(c) for c in w] for w in field]

        new_field_name = new_field_name or field_name

        if type(new_field_name) == type(field_name):
            if isinstance(new_field_name, list):
                assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \
                                                               "field_name."
            elif isinstance(new_field_name, str):
                field_name = [field_name]
                new_field_name = [new_field_name]
            else:
                raise TypeError("field_name and new_field_name can only be str or List[str].")

        for idx, dataset in enumerate(datasets):
            if isinstance(dataset, DataSet):
                try:
@@ -377,7 +383,7 @@ class Vocabulary(object):
        :return: bool
        """
        return word in self._no_create_word

    def to_index(self, w):
        """
        Convert a word to its index. A word not recorded in the vocabulary is treated as unknown; if ``unknown=None``,
        a ``ValueError`` is raised::
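
to_index is the lookup every indexing path goes through; unknown words fall back to the unk token unless unknown=None. A quick sketch (the word list is illustrative):

```python
from fastNLP.core.vocabulary import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst("the quick brown fox".split())
vocab.build_vocab()

print(vocab.to_index("fox"))    # index of a known word
print(vocab.to_index("zebra"))  # falls back to the unknown token's index
```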


fastNLP/embeddings/contextual_embedding.py (+6 -4)

@@ -8,15 +8,17 @@ __all__ = [
]

from abc import abstractmethod

import torch

from ..core.vocabulary import Vocabulary
from ..core.dataset import DataSet
from .embedding import TokenEmbedding
from ..core import logger
from ..core.batch import DataSetIter
from ..core.dataset import DataSet
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from .embedding import TokenEmbedding
from ..core import logger
from ..core.vocabulary import Vocabulary


class ContextualEmbedding(TokenEmbedding):
    def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):

