@@ -1,6 +1,9 @@ | |||
language: python | |||
python: | |||
- "3.6" | |||
env: | |||
- TRAVIS=1 | |||
# command to install dependencies | |||
install: | |||
- pip install --quiet -r requirements.txt | |||
@@ -14,7 +14,7 @@ help: | |||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) | |||
apidoc: | |||
$(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) | |||
$(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) && python3 format.py | |||
server: | |||
cd build/html && python -m http.server | |||
@@ -0,0 +1,65 @@ | |||
import os | |||
def shorten(file, to_delete, cut=False): | |||
if file.endswith("index.rst") or file.endswith("conf.py"): | |||
return | |||
res = [] | |||
with open(file, "r") as fin: | |||
lines = fin.readlines() | |||
for line in lines: | |||
if cut and line.rstrip() == "Submodules": | |||
break | |||
else: | |||
res.append(line.rstrip()) | |||
for i, line in enumerate(res): | |||
if line.endswith(" package"): | |||
res[i] = res[i][:-len(" package")] | |||
res[i + 1] = res[i + 1][:-len(" package")] | |||
elif line.endswith(" module"): | |||
res[i] = res[i][:-len(" module")] | |||
res[i + 1] = res[i + 1][:-len(" module")] | |||
else: | |||
for name in to_delete: | |||
if line.endswith(name): | |||
res[i] = "del" | |||
with open(file, "w") as fout: | |||
for line in res: | |||
if line != "del": | |||
print(line, file=fout) | |||
def clear(path='./source/'): | |||
files = os.listdir(path) | |||
to_delete = [ | |||
"fastNLP.core.dist_trainer", | |||
"fastNLP.core.predictor", | |||
"fastNLP.io.file_reader", | |||
"fastNLP.io.config_io", | |||
"fastNLP.embeddings.contextual_embedding", | |||
"fastNLP.modules.dropout", | |||
"fastNLP.models.base_model", | |||
"fastNLP.models.bert", | |||
"fastNLP.models.enas_utils", | |||
"fastNLP.models.enas_controller", | |||
"fastNLP.models.enas_model", | |||
"fastNLP.models.enas_trainer", | |||
] | |||
for file in files: | |||
if not os.path.isdir(path + file): | |||
res = file.split('.') | |||
if len(res) > 4: | |||
to_delete.append(file[:-4]) | |||
elif len(res) == 4: | |||
shorten(path + file, to_delete, True) | |||
else: | |||
shorten(path + file, to_delete) | |||
for file in to_delete: | |||
os.remove(path + file + ".rst") | |||
clear() |
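A quick sketch (not part of the patch) of the heading rewrite that shorten() performs on sphinx-apidoc output: both the title line and its underline lose the trailing suffix, so the RST underline stays aligned.

    # Sketch of the " package"/" module" suffix trimming done by shorten() above.
    title = "fastNLP.core package"
    underline = "=" * len(title)
    suffix = " package"
    title = title[:-len(suffix)]          # -> "fastNLP.core"
    underline = underline[:-len(suffix)]  # shrink the underline by the same amount
    print(title)
    print(underline)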
@@ -6,11 +6,10 @@ fastNLP.core | |||
:undoc-members: | |||
:show-inheritance: | |||
子模块 | |||
Submodules | |||
---------- | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.core.batch | |||
fastNLP.core.callback | |||
@@ -6,11 +6,10 @@ fastNLP.embeddings | |||
:undoc-members: | |||
:show-inheritance: | |||
子模块 | |||
Submodules | |||
---------- | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.embeddings.bert_embedding | |||
fastNLP.embeddings.char_embedding | |||
@@ -1,7 +1,8 @@ | |||
fastNLP.io.data\_loader | |||
========================== | |||
======================= | |||
.. automodule:: fastNLP.io.data_loader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:show-inheritance: | |||
@@ -0,0 +1,7 @@ | |||
fastNLP.io.file\_utils | |||
====================== | |||
.. automodule:: fastNLP.io.file_utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -0,0 +1,8 @@ | |||
fastNLP.io.loader | |||
================= | |||
.. automodule:: fastNLP.io.loader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
@@ -0,0 +1,8 @@ | |||
fastNLP.io.pipe | |||
=============== | |||
.. automodule:: fastNLP.io.pipe | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
@@ -6,14 +6,23 @@ fastNLP.io | |||
:undoc-members: | |||
:show-inheritance: | |||
子模块 | |||
Subpackages | |||
----------- | |||
.. toctree:: | |||
fastNLP.io.data_loader | |||
fastNLP.io.loader | |||
fastNLP.io.pipe | |||
Submodules | |||
---------- | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.io.base_loader | |||
fastNLP.io.embed_loader | |||
fastNLP.io.dataset_loader | |||
fastNLP.io.data_loader | |||
fastNLP.io.embed_loader | |||
fastNLP.io.file_utils | |||
fastNLP.io.model_io | |||
fastNLP.io.utils |
@@ -0,0 +1,7 @@ | |||
fastNLP.io.utils | |||
================ | |||
.. automodule:: fastNLP.io.utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -6,11 +6,10 @@ fastNLP.models | |||
:undoc-members: | |||
:show-inheritance: | |||
子模块 | |||
Submodules | |||
---------- | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.models.biaffine_parser | |||
fastNLP.models.cnn_text_classification | |||
@@ -5,3 +5,4 @@ fastNLP.modules.encoder | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
@@ -6,12 +6,17 @@ fastNLP.modules | |||
:undoc-members: | |||
:show-inheritance: | |||
子模块 | |||
Subpackages | |||
----------- | |||
.. toctree:: | |||
:titlesonly: | |||
:maxdepth: 1 | |||
fastNLP.modules.decoder | |||
fastNLP.modules.encoder | |||
fastNLP.modules.encoder | |||
Submodules | |||
---------- | |||
.. toctree:: | |||
fastNLP.modules.utils |
@@ -0,0 +1,7 @@ | |||
fastNLP.modules.utils | |||
===================== | |||
.. automodule:: fastNLP.modules.utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,16 +1,15 @@ | |||
API 文档 | |||
=============== | |||
fastNLP | |||
======= | |||
.. automodule:: fastNLP | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
内部模块 | |||
Subpackages | |||
----------- | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.core | |||
fastNLP.embeddings | |||
@@ -2,7 +2,6 @@ fastNLP | |||
======= | |||
.. toctree:: | |||
:titlesonly: | |||
:maxdepth: 4 | |||
fastNLP |
@@ -14,6 +14,7 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa | |||
""" | |||
from .batch import DataSetIter, BatchIter, TorchLoaderIter | |||
from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC | |||
from .callback import EvaluateCallback, FitlogCallback, SaveModelCallback | |||
from .const import Const | |||
from .dataset import DataSet | |||
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | |||
@@ -24,5 +25,5 @@ from .optimizer import Optimizer, SGD, Adam | |||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | |||
from .tester import Tester | |||
from .trainer import Trainer | |||
from .utils import cache_results, seq_len_to_mask | |||
from .utils import cache_results, seq_len_to_mask, get_seq_len | |||
from .vocabulary import Vocabulary |
@@ -1,6 +1,7 @@ | |||
import threading | |||
import torch | |||
from torch import nn | |||
from torch.nn.parallel.parallel_apply import get_a_var | |||
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather | |||
@@ -86,3 +87,16 @@ def _data_parallel_wrapper(func_name, device_ids, output_device): | |||
outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)]) | |||
return gather(outputs, output_device) | |||
return wrapper | |||
def _model_contains_inner_module(model): | |||
""" | |||
:param nn.Module model: the model to check for an inner model.module; mainly used to tell whether the model is an
nn.DataParallel or nn.parallel.DistributedDataParallel wrapper, since argument matching must use the innermost model's functions.
:return: bool | |||
""" | |||
if isinstance(model, nn.Module): | |||
if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)): | |||
return True | |||
return False |
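A small usage sketch of the new helper; the import path fastNLP.core._parallel_utils is taken from how this diff imports it elsewhere, so treat it as an assumption.

    # Sketch: _model_contains_inner_module detects DataParallel-style wrappers.
    import torch.nn as nn
    from fastNLP.core._parallel_utils import _model_contains_inner_module

    plain = nn.Linear(3, 3)
    wrapped = nn.DataParallel(plain)
    print(_model_contains_inner_module(plain))    # False: use plain.forward directly
    print(_model_contains_inner_module(wrapped))  # True: match arguments against wrapped.module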
@@ -48,6 +48,11 @@ class DataSetGetter: | |||
return len(self.dataset) | |||
def collate_fn(self, batch: list): | |||
""" | |||
:param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]] | |||
:return: | |||
""" | |||
# TODO: support defining collate_fn on the DataSet itself, since different fields sometimes need to be merged (e.g. the BERT case)
batch_x = {n:[] for n in self.inputs.keys()} | |||
batch_y = {n:[] for n in self.targets.keys()} | |||
@@ -93,9 +98,13 @@ class DataSetGetter: | |||
class SamplerAdapter(torch.utils.data.Sampler): | |||
def __init__(self, sampler, dataset): | |||
super().__init__(dataset) | |||
self.sampler = sampler | |||
self.dataset = dataset | |||
def __len__(self): | |||
return len(self.dataset) | |||
def __iter__(self): | |||
return iter(self.sampler(self.dataset)) | |||
@@ -165,15 +174,19 @@ class DataSetIter(BatchIter): | |||
timeout=0, worker_init_fn=None): | |||
super().__init__() | |||
assert isinstance(dataset, DataSet) | |||
sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) | |||
if not isinstance(sampler, torch.utils.data.Sampler): | |||
self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) | |||
else: | |||
self.sampler = sampler | |||
dataset = DataSetGetter(dataset, as_numpy) | |||
collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None | |||
self.dataiter = torch.utils.data.DataLoader( | |||
dataset=dataset, batch_size=batch_size, sampler=sampler, | |||
dataset=dataset, batch_size=batch_size, sampler=self.sampler, | |||
collate_fn=collate_fn, num_workers=num_workers, | |||
pin_memory=pin_memory, drop_last=drop_last, | |||
timeout=timeout, worker_init_fn=worker_init_fn) | |||
self.num_batches = self.get_num_batches(len(dataset), batch_size, drop_last) | |||
# Use the sampler's length: with a DistributedSampler each process only sees part of the data.
self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last) | |||
self.batch_size = batch_size | |||
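A minimal sketch of what this change enables: a plain torch.utils.data.Sampler can now be handed to DataSetIter directly, while fastNLP samplers keep being wrapped in SamplerAdapter (constructor arguments assumed from this hunk).

    # Sketch: passing a torch Sampler straight into DataSetIter.
    import torch
    from fastNLP import DataSet, DataSetIter

    ds = DataSet({"x": [[1, 2], [3, 4], [5, 6]], "y": [0, 1, 0]})
    ds.set_input("x")
    ds.set_target("y")

    torch_sampler = torch.utils.data.SequentialSampler(ds)  # not a fastNLP Sampler
    it = DataSetIter(dataset=ds, batch_size=2, sampler=torch_sampler)
    for batch_x, batch_y in it:
        print(batch_x["x"], batch_y["y"])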
@@ -182,7 +195,7 @@ class TorchLoaderIter(BatchIter): | |||
super().__init__() | |||
assert isinstance(dataset, torch.utils.data.DataLoader) | |||
self.dataiter = dataset | |||
self.num_batches = self.get_num_batches(len(dataset), dataset.batch_size, dataset.drop_last) | |||
self.num_batches = self.get_num_batches(len(dataset.sampler), dataset.batch_size, dataset.drop_last) | |||
self.batch_size = dataset.batch_size | |||
@@ -200,6 +213,13 @@ class OnlineDataIter(BatchIter): | |||
def _to_tensor(batch, field_dtype): | |||
""" | |||
:param batch: np.array() | |||
:param field_dtype: data type of the field
:return: batch, flag. If the input can be converted to a tensor, the returned batch is a tensor and flag is True;
otherwise the original data is returned unchanged and flag is False.
""" | |||
try: | |||
if field_dtype is not None and isinstance(field_dtype, type)\ | |||
and issubclass(field_dtype, Number) \ | |||
@@ -57,6 +57,7 @@ __all__ = [ | |||
"FitlogCallback", | |||
"LRScheduler", | |||
"ControlC", | |||
"EvaluateCallback", | |||
"CallbackException", | |||
"EarlyStopError" | |||
@@ -79,6 +80,7 @@ except: | |||
from ..io.model_io import ModelSaver, ModelLoader | |||
from .dataset import DataSet | |||
from .tester import Tester | |||
import logging | |||
try: | |||
import fitlog | |||
@@ -100,7 +102,8 @@ class Callback(object): | |||
def __init__(self): | |||
super(Callback, self).__init__() | |||
self._trainer = None  # reassigned inside Trainer
self._disabled = False | |||
@property | |||
def trainer(self): | |||
""" | |||
@@ -158,7 +161,19 @@ class Callback(object): | |||
def batch_per_epoch(self): | |||
"""每个epoch一共有多少个batch,只有在on_epoch_begin之后才能调用该属性。""" | |||
return self._trainer.batch_per_epoch | |||
@property | |||
def is_master(self): | |||
return self._trainer.is_master() | |||
@property | |||
def disabled(self): | |||
return self._disabled | |||
@property | |||
def logger(self): | |||
return getattr(self._trainer, 'logger', logging) | |||
def on_train_begin(self): | |||
""" | |||
Called before the training process starts.
@@ -250,6 +265,14 @@ class Callback(object): | |||
:return: | |||
""" | |||
pass | |||
def on_validation(self): | |||
""" | |||
Called each time validation is about to run, if validation is configured in the Trainer.
:return: | |||
""" | |||
pass | |||
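For context, a sketch of a user-defined callback hooking the new on_validation event (the callback below is hypothetical, not part of this patch).

    # Hypothetical callback using the on_validation hook added above.
    from fastNLP import Callback

    class LogValidationCallback(Callback):
        def on_validation(self):
            # runs each time the Trainer is about to validate
            self.logger.info("validation triggered at step {}/{}".format(self.step, self.n_steps))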
def on_epoch_end(self): | |||
""" | |||
@@ -281,6 +304,8 @@ def _transfer(func): | |||
def wrapper(manager, *arg): | |||
returns = [] | |||
for callback in manager.callbacks: | |||
if callback.disabled: | |||
continue | |||
returns.append(getattr(callback, func.__name__)(*arg)) | |||
return returns | |||
@@ -297,22 +322,28 @@ class CallbackManager(Callback): | |||
""" | |||
super(CallbackManager, self).__init__() | |||
# set attribute of trainer environment | |||
self._env = env | |||
self.callbacks = [] | |||
if callbacks is not None: | |||
if isinstance(callbacks, list): | |||
if all([isinstance(cb, Callback) for cb in callbacks]) is True: | |||
self.callbacks.extend(callbacks) | |||
else: | |||
obj = [not isinstance(cb, Callback) for cb in callbacks][0] | |||
raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}") | |||
if callbacks: | |||
self.callbacks = self.prepare_callbacks(callbacks) | |||
def prepare_callbacks(self, callbacks): | |||
if not callbacks: | |||
return [] | |||
if isinstance(callbacks, list): | |||
if all([isinstance(cb, Callback) for cb in callbacks]) is True: | |||
pass | |||
else: | |||
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") | |||
for env_name, env_val in env.items(): | |||
for callback in self.callbacks: | |||
obj = [not isinstance(cb, Callback) for cb in callbacks][0] | |||
raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}") | |||
else: | |||
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") | |||
for env_name, env_val in self._env.items(): | |||
for callback in callbacks: | |||
setattr(callback, '_' + env_name, env_val) # Callback.trainer | |||
return callbacks | |||
@_transfer | |||
def on_train_begin(self): | |||
pass | |||
@@ -352,6 +383,10 @@ class CallbackManager(Callback): | |||
@_transfer | |||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||
pass | |||
@_transfer | |||
def on_validation(self): | |||
pass | |||
@_transfer | |||
def on_epoch_end(self): | |||
@@ -366,6 +401,25 @@ class CallbackManager(Callback): | |||
pass | |||
class DistCallbackManager(CallbackManager): | |||
def __init__(self, env, callbacks_all=None, callbacks_master=None): | |||
super(DistCallbackManager, self).__init__(env) | |||
assert 'trainer' in env | |||
is_master = env['trainer'].is_master | |||
self.patch_callback(callbacks_master, disabled=not is_master) | |||
self.callbacks_all = self.prepare_callbacks(callbacks_all) | |||
self.callbacks_master = self.prepare_callbacks(callbacks_master) | |||
self.callbacks = self.callbacks_all + self.callbacks_master | |||
def patch_callback(self, callbacks, disabled): | |||
if not callbacks: | |||
return | |||
if not isinstance(callbacks, (list, tuple)): | |||
callbacks = [callbacks] | |||
for cb in callbacks: | |||
cb._disabled = disabled | |||
class GradientClipCallback(Callback): | |||
""" | |||
Alias: :class:`fastNLP.GradientClipCallback` :class:`fastNLP.core.callback.GradientClipCallback`
@@ -403,6 +457,9 @@ class GradientClipCallback(Callback): | |||
def on_backward_end(self): | |||
if self.step%self.update_every==0: | |||
if self.parameters is None: | |||
if getattr(self.trainer, 'fp16', ''): | |||
from apex import amp | |||
self.clip_fun(amp.master_params(self.optimizer), self.clip_value) | |||
self.clip_fun(self.model.parameters(), self.clip_value) | |||
else: | |||
self.clip_fun(self.parameters, self.clip_value) | |||
@@ -448,10 +505,9 @@ class FitlogCallback(Callback): | |||
and writes the evaluation results to fitlog. Results for these datasets are reported at the best dev checkpoint,
i.e. if dev peaks at the 3rd epoch, the results recorded in fitlog for these datasets come from the 3rd epoch.
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: a DataSet that will be evaluated with the Trainer's metrics. To pass several
DataSets, use a dict; each key is forwarded to fitlog as the name of the corresponding dataset. If tester is not None, data must be
passed as a dict. A bare DataSet is named `test`.
:param ~fastNLP.Tester tester: a Tester object, called in on_valid_end. Its DataSet is referred to as `test`.
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: a DataSet that will be evaluated with the Trainer's metrics. To pass several
DataSets, use a dict; each key is forwarded to fitlog as the name of the corresponding dataset. Result names for data start with 'data'.
:param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: a Tester object, called in on_valid_end. Result names for tester start with 'tester'.
:param int log_loss_every: log the loss every this many steps (the average loss of those batches). For large datasets use a larger
value, otherwise the log file becomes huge. Default 0, i.e. the loss is not logged.
:param int verbose: whether to print evaluation results in the terminal; 0 means no printing.
@@ -465,21 +521,24 @@ class FitlogCallback(Callback): | |||
self._log_exception = log_exception | |||
assert isinstance(log_loss_every, int) and log_loss_every>=0 | |||
if tester is not None: | |||
assert isinstance(tester, Tester), "Only fastNLP.Tester allowed." | |||
assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data." | |||
if data is not None: | |||
assert 'test' not in data, "Cannot use `test` as DataSet key, when tester is passed." | |||
setattr(tester, 'verbose', 0) | |||
self.testers['test'] = tester | |||
if isinstance(tester, dict): | |||
for name, test in tester.items(): | |||
if not isinstance(test, Tester): | |||
raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.") | |||
self.testers['tester-' + name] = test | |||
if isinstance(tester, Tester): | |||
self.testers['tester-test'] = tester | |||
for tester in self.testers.values(): | |||
setattr(tester, 'verbose', 0) | |||
if isinstance(data, dict): | |||
for key, value in data.items(): | |||
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." | |||
for key, value in data.items(): | |||
self.datasets[key] = value | |||
self.datasets['data-' + key] = value | |||
elif isinstance(data, DataSet): | |||
self.datasets['test'] = data | |||
else: | |||
self.datasets['data-test'] = data | |||
elif data is not None: | |||
raise TypeError("data receives dict[DataSet] or DataSet object.") | |||
self.verbose = verbose | |||
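A usage sketch of the renaming introduced here ('data-' / 'tester-' prefixes); model, metrics and the extra datasets below are placeholders for objects already available when building the Trainer.

    # Sketch: extra datasets and testers registered with FitlogCallback get prefixed names.
    from fastNLP import FitlogCallback, Tester

    fitlog_cb = FitlogCallback(
        data={'other': other_ds},                                             # logged as 'data-other'
        tester={'dev2': Tester(dev2_ds, model, metrics=metrics, verbose=0)},  # logged as 'tester-dev2'
        log_loss_every=100)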
@@ -492,8 +551,11 @@ class FitlogCallback(Callback): | |||
if len(self.datasets) > 0: | |||
for key, data in self.datasets.items(): | |||
tester = Tester(data=data, model=self.model, batch_size=self.batch_size, metrics=self.trainer.metrics, | |||
verbose=0) | |||
tester = Tester(data=data, model=self.model, | |||
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), | |||
metrics=self.trainer.metrics, | |||
verbose=0, | |||
use_tqdm=self.trainer.use_tqdm) | |||
self.testers[key] = tester | |||
fitlog.add_progress(total_steps=self.n_steps) | |||
@@ -533,6 +595,65 @@ class FitlogCallback(Callback): | |||
fitlog.add_other(repr(exception), name='except_info') | |||
class EvaluateCallback(Callback): | |||
""" | |||
Alias: :class:`fastNLP.EvaluateCallback` :class:`fastNLP.core.callback.EvaluateCallback`
This callback lifts the limitation that the Trainer can only evaluate on the dev data during training.
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: a DataSet that will be evaluated with the Trainer's metrics. To pass several
DataSets, use a dict.
:param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: a Tester object, called in on_valid_end.
""" | |||
def __init__(self, data=None, tester=None): | |||
super().__init__() | |||
self.datasets = {} | |||
self.testers = {} | |||
if tester is not None: | |||
if isinstance(tester, dict): | |||
for name, test in tester.items(): | |||
if not isinstance(test, Tester): | |||
raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.") | |||
self.testers['tester-' + name] = test | |||
if isinstance(tester, Tester): | |||
self.testers['tester-test'] = tester | |||
for tester in self.testers.values(): | |||
setattr(tester, 'verbose', 0) | |||
if isinstance(data, dict): | |||
for key, value in data.items(): | |||
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." | |||
for key, value in data.items(): | |||
self.datasets['data-' + key] = value | |||
elif isinstance(data, DataSet): | |||
self.datasets['data-test'] = data | |||
elif data is not None: | |||
raise TypeError("data receives dict[DataSet] or DataSet object.") | |||
def on_train_begin(self): | |||
if len(self.datasets) > 0 and self.trainer.dev_data is None:
raise RuntimeError("Trainer has no dev data, you cannot pass extra DataSet to do evaluation.") | |||
if len(self.datasets) > 0: | |||
for key, data in self.datasets.items(): | |||
tester = Tester(data=data, model=self.model, | |||
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), | |||
metrics=self.trainer.metrics, verbose=0, | |||
use_tqdm=self.trainer.use_tqdm) | |||
self.testers[key] = tester | |||
def on_valid_end(self, eval_result, metric_key, optimizer, better_result): | |||
if len(self.testers) > 0: | |||
for key, tester in self.testers.items(): | |||
try: | |||
eval_result = tester.test() | |||
self.pbar.write("Evaluation on {}:".format(key)) | |||
self.pbar.write(tester._format_eval_results(eval_result)) | |||
except Exception: | |||
self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) | |||
class LRScheduler(Callback): | |||
""" | |||
Alias: :class:`fastNLP.LRScheduler` :class:`fastNLP.core.callback.LRScheduler`
@@ -884,3 +1005,59 @@ class EarlyStopError(CallbackException): | |||
def __init__(self, msg): | |||
super(EarlyStopError, self).__init__(msg) | |||
class EchoCallback(Callback): | |||
def __init__(self, name, out=sys.stdout): | |||
super(EchoCallback, self).__init__() | |||
self.name = name | |||
self.out = out | |||
def __getattribute__(self, item): | |||
if item.startswith('on_'): | |||
print('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()), | |||
file=self.out) | |||
return super(EchoCallback, self).__getattribute__(item) | |||
class TesterCallback(Callback): | |||
def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None): | |||
super(TesterCallback, self).__init__() | |||
self.tester = Tester(data, model, | |||
metrics=metrics, batch_size=batch_size, | |||
num_workers=num_workers, verbose=0) | |||
# parse metric_key | |||
# increase_better is True. It means the exp result gets better if the indicator increases. | |||
# It is true by default. | |||
self.increase_better = True | |||
if metric_key is not None: | |||
self.increase_better = False if metric_key[0] == "-" else True | |||
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | |||
else: | |||
self.metric_key = None | |||
self.score = None | |||
def on_validation(self): | |||
cur_score = self.tester.test() | |||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format( | |||
self.epoch, self.n_epochs, self.step, self.n_steps, | |||
self.tester._format_eval_results(cur_score)) | |||
self.logger.info(eval_str) | |||
is_better = self.compare_better(cur_score) | |||
if is_better: | |||
self.score = cur_score | |||
return cur_score, is_better | |||
def compare_better(self, a): | |||
if self.score is None: | |||
return True | |||
k = self.metric_key | |||
is_increase = self.score[k] <= a[k] # if equal, prefer more recent results | |||
if self.increase_better: | |||
return is_increase | |||
else: | |||
return not is_increase | |||
def on_train_end(self): | |||
self.logger.info('Evaluate on training ends.') | |||
self.on_validation() |
@@ -7,12 +7,14 @@ class Const: | |||
Full list::
INPUT       sequence input to the model     words   (plural: words1, words2)
CHAR_INPUT  character input to the model    chars   (plural: chars1, chars2)
INPUT_LEN   sequence length                 seq_len (plural: seq_len1, seq_len2)
OUTPUT      model output                    pred    (plural: pred1, pred2)
TARGET      ground truth                    target  (plural: target1, target2)
LOSS        loss function                   loss    (plural: loss1, loss2)
INPUT       sequence input to the model     words   (with multiple columns, use words1, words2, ...)
CHAR_INPUT  character input to the model    chars   (with multiple columns, use chars1, chars2)
INPUT_LEN   sequence length                 seq_len (with multiple columns, use seq_len1, seq_len2)
OUTPUT      model output                    pred    (with multiple columns, use pred1, pred2)
TARGET      ground truth                    target  (with multiple columns, use target1, target2)
LOSS        loss function                   loss    (with multiple columns, use loss1, loss2)
RAW_WORD    raw words of the text           raw_words (with multiple columns, use raw_words1, raw_words2)
RAW_CHAR    raw characters of the text      raw_chars (with multiple columns, use raw_chars1, raw_chars2)
""" | |||
INPUT = 'words' | |||
@@ -21,6 +23,8 @@ class Const: | |||
OUTPUT = 'pred' | |||
TARGET = 'target' | |||
LOSS = 'loss' | |||
RAW_WORD = 'raw_words' | |||
RAW_CHAR = 'raw_chars' | |||
@staticmethod | |||
def INPUTS(i): | |||
@@ -34,6 +38,16 @@ class Const: | |||
i = int(i) + 1 | |||
return Const.CHAR_INPUT + str(i) | |||
@staticmethod | |||
def RAW_WORDS(i): | |||
i = int(i) + 1 | |||
return Const.RAW_WORD + str(i) | |||
@staticmethod | |||
def RAW_CHARS(i): | |||
i = int(i) + 1 | |||
return Const.RAW_CHAR + str(i) | |||
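The new constants follow the same numbering convention as the existing ones; a quick sketch, assuming Const is importable from the package root as it is elsewhere in fastNLP.

    # Sketch: 0-based index, 1-based column name, same as Const.INPUTS / Const.TARGETS.
    from fastNLP import Const

    print(Const.RAW_WORD)      # 'raw_words'
    print(Const.RAW_WORDS(0))  # 'raw_words1'
    print(Const.RAW_CHARS(1))  # 'raw_chars2'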
@staticmethod | |||
def INPUT_LENS(i): | |||
"""得到第 i 个 ``INPUT_LEN`` 的命名""" | |||
@@ -291,6 +291,7 @@ import _pickle as pickle | |||
import warnings | |||
import numpy as np | |||
from copy import deepcopy | |||
from .field import AutoPadder | |||
from .field import FieldArray | |||
@@ -298,6 +299,7 @@ from .instance import Instance | |||
from .utils import _get_func_signature | |||
from .field import AppendToTargetOrInputException | |||
from .field import SetInputOrTargetException | |||
from .const import Const | |||
class DataSet(object): | |||
""" | |||
@@ -349,7 +351,11 @@ class DataSet(object): | |||
self.idx]) | |||
assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) | |||
return self.dataset.field_arrays[item][self.idx] | |||
def items(self): | |||
ins = self.dataset[self.idx] | |||
return ins.items() | |||
def __repr__(self): | |||
return self.dataset[self.idx].__repr__() | |||
@@ -487,7 +493,7 @@ class DataSet(object): | |||
""" | |||
Delete the index-th instance.
:param int index: index of the instance to delete, starting from 0
:param int index: index of the instance to delete; numbering starts from 0.
""" | |||
assert isinstance(index, int), "Only integer supported." | |||
if len(self) <= index: | |||
@@ -497,6 +503,7 @@ class DataSet(object): | |||
else: | |||
for field in self.field_arrays.values(): | |||
field.pop(index) | |||
return self | |||
def delete_field(self, field_name): | |||
""" | |||
@@ -505,7 +512,22 @@ class DataSet(object): | |||
:param str field_name: name of the field to delete.
""" | |||
self.field_arrays.pop(field_name) | |||
return self | |||
def copy_field(self, field_name, new_field_name): | |||
""" | |||
Deep-copy the field named field_name into new_field_name.
:param str field_name: the field to copy.
:param str new_field_name: name of the field produced by the copy
:return: self | |||
""" | |||
if not self.has_field(field_name): | |||
raise KeyError(f"Field:{field_name} not found in DataSet.") | |||
fieldarray = deepcopy(self.get_field(field_name)) | |||
self.add_fieldarray(field_name=new_field_name, fieldarray=fieldarray) | |||
return self | |||
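A small sketch of the new chaining behaviour: copy_field deep-copies a column, and delete_field / delete_instance now return self.

    # Sketch: DataSet methods that now return self can be chained.
    from fastNLP import DataSet

    ds = DataSet({"words": [["a", "b"], ["c"]], "target": [0, 1]})
    ds.copy_field("words", "raw_words").delete_field("target")
    print(ds.get_field("raw_words").content)   # deep copy of the original column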
def has_field(self, field_name): | |||
""" | |||
Check whether the DataSet contains a field named field_name.
@@ -566,7 +588,7 @@ class DataSet(object): | |||
raise KeyError("DataSet has no field named {}.".format(old_name)) | |||
return self | |||
def set_target(self, *field_names, flag=True): | |||
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | |||
""" | |||
Mark the fields in field_names as target fields.
@@ -577,11 +599,14 @@ class DataSet(object): | |||
:param str field_names: names of the fields
:param bool flag: set the target status of each field to flag
:param bool use_1st_ins_infer_dim_type: if True, do not check that every entry in the column has the same dimension and
type; the type and dimension of the column are inferred from the first row only.
""" | |||
assert isinstance(flag, bool), "Only bool type supported." | |||
for name in field_names: | |||
if name in self.field_arrays: | |||
try: | |||
self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) | |||
self.field_arrays[name].is_target = flag | |||
except SetInputOrTargetException as e: | |||
print(f"Cannot set field:{name} as target.") | |||
@@ -589,7 +614,7 @@ class DataSet(object): | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
def set_input(self, *field_names, flag=True): | |||
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | |||
""" | |||
Mark the fields in field_names as input fields::
@@ -598,10 +623,13 @@ class DataSet(object): | |||
:param str field_names: names of the fields
:param bool flag: set the input status of each field to flag
:param bool use_1st_ins_infer_dim_type: if True, do not check that every entry in the column has the same dimension and
type; the type and dimension of the column are inferred from the first row only.
""" | |||
for name in field_names: | |||
if name in self.field_arrays: | |||
try: | |||
self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) | |||
self.field_arrays[name].is_input = flag | |||
except SetInputOrTargetException as e: | |||
print(f"Cannot set field:{name} as input, exception happens at the {e.index} value.") | |||
@@ -695,7 +723,7 @@ class DataSet(object): | |||
results.append(func(ins[field_name])) | |||
except Exception as e: | |||
if idx != -1: | |||
print("Exception happens at the `{}`th instance.".format(idx)) | |||
print("Exception happens at the `{}`th(from 1) instance.".format(idx+1)) | |||
raise e | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | |||
@@ -760,10 +788,11 @@ class DataSet(object): | |||
results = [] | |||
for idx, ins in enumerate(self._inner_iter()): | |||
results.append(func(ins)) | |||
except Exception as e: | |||
except BaseException as e: | |||
if idx != -1: | |||
print("Exception happens at the `{}`th instance.".format(idx)) | |||
raise e | |||
# results = [func(ins) for ins in self._inner_iter()] | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | |||
@@ -773,7 +802,7 @@ class DataSet(object): | |||
return results | |||
def add_seq_len(self, field_name:str, new_field_name='seq_len'): | |||
def add_seq_len(self, field_name:str, new_field_name=Const.INPUT_LEN): | |||
""" | |||
Apply len() to every element in field_name and store the result as the sequence length in the seq_len field.
@@ -0,0 +1,355 @@ | |||
""" | |||
Distributed training code, still under development.
""" | |||
import torch | |||
import torch.cuda | |||
import torch.optim | |||
import torch.distributed as dist | |||
from torch.utils.data.distributed import DistributedSampler | |||
from torch.nn.parallel import DistributedDataParallel as DDP | |||
import os | |||
from tqdm import tqdm | |||
import logging | |||
import time | |||
from datetime import datetime, timedelta | |||
from functools import partial | |||
from .batch import DataSetIter, BatchIter | |||
from .callback import DistCallbackManager, CallbackException, TesterCallback | |||
from .dataset import DataSet | |||
from .losses import _prepare_losser | |||
from .optimizer import Optimizer | |||
from .utils import _build_args | |||
from .utils import _move_dict_value_to_device | |||
from .utils import _get_func_signature | |||
from pkg_resources import parse_version | |||
__all__ = [ | |||
'get_local_rank', | |||
'DistTrainer', | |||
] | |||
def get_local_rank(): | |||
if 'LOCAL_RANK' in os.environ: | |||
return int(os.environ['LOCAL_RANK']) | |||
from argparse import ArgumentParser | |||
parser = ArgumentParser() | |||
parser.add_argument('--local_rank', type=int) | |||
args, _ = parser.parse_known_args() | |||
if 'local_rank' in args and args.local_rank: | |||
os.environ['LOCAL_RANK'] = str(args.local_rank) # for multiple calls for this function | |||
return args.local_rank | |||
raise RuntimeError('Please use "python -m torch.distributed.launch train_script.py"')
class DistTrainer(): | |||
""" | |||
Distributed Trainer that support distributed and mixed precision training | |||
""" | |||
def __init__(self, train_data, model, optimizer=None, loss=None, | |||
callbacks_all=None, callbacks_master=None, | |||
batch_size_per_gpu=8, n_epochs=1, | |||
num_data_workers=1, drop_last=False, | |||
dev_data=None, metrics=None, metric_key=None, | |||
update_every=1, print_every=10, validate_every=-1, | |||
log_path=None, | |||
save_every=-1, save_path=None, device='auto', | |||
fp16='', backend=None, init_method=None): | |||
assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in ['auto', 'cuda', 'cpu']"
if device == 'auto': | |||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
if backend is None: | |||
backend = 'nccl' if device == 'cuda' else 'gloo' | |||
# init distributed | |||
if device == 'cuda': | |||
torch.cuda.set_device(get_local_rank()) | |||
self.device = torch.device("cuda", get_local_rank()) | |||
else: | |||
self.device = torch.device(device) | |||
dist.init_process_group(backend=backend, init_method=init_method) | |||
self.world_size = dist.get_world_size() | |||
self.rank = dist.get_rank() # unique id for each process | |||
self.model = model | |||
self.train_data = train_data | |||
self.batch_size_per_gpu = int(batch_size_per_gpu) | |||
self.n_epochs = int(n_epochs) | |||
self.num_data_workers = int(num_data_workers) | |||
self.drop_last = drop_last | |||
self.update_every = int(update_every) | |||
self.print_every = int(print_every) | |||
self.validate_every = int(validate_every) | |||
self.save_every = int(save_every) | |||
self.save_path = save_path | |||
self.losser = _prepare_losser(loss) | |||
self.fp16 = fp16 | |||
self.init_method = init_method | |||
self.backend = backend | |||
self.local_rank = get_local_rank() | |||
self._forward_func = model.forward | |||
self.callback_manager = DistCallbackManager( | |||
env={"trainer": self}, callbacks_all=callbacks_all, | |||
callbacks_master=callbacks_master) | |||
self.metric_key = metric_key | |||
model.to(self.device) | |||
optimizer = self._get_optimizer(optimizer) | |||
# init fp16, must before DataParallel init | |||
if len(self.fp16): | |||
assert isinstance(self.fp16, str), "Please set Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']" | |||
try: | |||
from apex import amp | |||
except ImportError: | |||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") | |||
assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." | |||
assert device == 'cuda', "Amp requires cuda device" | |||
model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16) | |||
# init DataParallel | |||
if parse_version(torch.__version__)>=parse_version('1.1'): | |||
self.model = DDP(model, device_ids=[self.local_rank], | |||
output_device=self.local_rank, find_unused_parameters=True) | |||
else: | |||
self.model = DDP(model, device_ids=[self.local_rank], | |||
output_device=self.local_rank) | |||
self.optimizer = optimizer | |||
self.sampler = DistributedSampler(self.train_data) | |||
self.data_iterator = self._get_data_iter(self.train_data) | |||
self.n_steps = self._get_n_steps() | |||
# for evaluation, only run eval on master proc | |||
if dev_data and metrics: | |||
cb = TesterCallback( | |||
dev_data, model, metrics, | |||
batch_size=batch_size_per_gpu, num_workers=num_data_workers) | |||
self.callback_manager.callbacks_master += \ | |||
self.callback_manager.prepare_callbacks([cb]) | |||
# Setup logging | |||
dist.barrier() | |||
self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M') | |||
if self.save_path: | |||
self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time) | |||
else: | |||
self.cp_save_path = None | |||
# use INFO in the master, WARN for others | |||
logging.basicConfig(filename=log_path, | |||
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', | |||
datefmt='%m/%d/%Y %H:%M:%S', | |||
level=logging.INFO if self.is_master else logging.WARN) | |||
self.logger = logging.getLogger(__name__) | |||
self.logger.info("Setup Distributed Trainer") | |||
self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format( | |||
os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False)) | |||
self.logger.info("Num of processes: {}".format(self.world_size)) | |||
self.logger.info("Use device: {}".format(device)) | |||
self.logger.info("Training with fp16: {}, optimization level: {}".format( | |||
len(self.fp16) > 0, self.fp16 if self.fp16 else None)) | |||
def _get_n_steps(self): | |||
batch_size = self.world_size * self.batch_size_per_gpu | |||
return (len(self.train_data) // batch_size + int( | |||
len(self.train_data) % batch_size != 0)) * int(self.drop_last == 0) * self.n_epochs | |||
def _get_data_iter(self, dataset): | |||
if isinstance(dataset, DataSet): | |||
return DataSetIter( | |||
dataset=dataset, batch_size=self.batch_size_per_gpu, | |||
num_workers=self.num_data_workers, sampler=self.sampler, | |||
drop_last=self.drop_last | |||
) | |||
elif isinstance(dataset, BatchIter): | |||
return dataset | |||
else: | |||
raise TypeError("train_data type {} not support".format(type(dataset))) | |||
def _get_optimizer(self, optimizer): | |||
if isinstance(optimizer, torch.optim.Optimizer): | |||
return optimizer | |||
elif isinstance(optimizer, Optimizer): | |||
return optimizer.construct_from_pytorch(self.model.parameters()) | |||
elif optimizer is None: | |||
return torch.optim.Adam(self.model.parameters(), lr=4e-3) | |||
else: | |||
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) | |||
@property | |||
def is_master(self): | |||
return self.rank == 0 | |||
def train(self, on_exception='auto'): | |||
try: | |||
self.logger.info("###### Training epochs started ######") | |||
self.logger.info('Total epochs: %d'% self.n_epochs) | |||
self.logger.info('Total steps: %d'% self.n_steps) | |||
self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu) | |||
self.logger.info('Total batch_size: %d' % (self.batch_size_per_gpu * dist.get_world_size()))
self.logger.info('Total num of samples: %d'% len(self.train_data)) | |||
self.logger.info("Num of callbacks for all workers: {}".format( | |||
len(self.callback_manager.callbacks_all))) | |||
self.logger.info("Num of callbacks for master workers: {}".format( | |||
len(self.callback_manager.callbacks_master))) | |||
self.logger.info("Callbacks for all workers: {}".format( | |||
[repr(cb) for cb in self.callback_manager.callbacks_all])) | |||
self.logger.info("Callbacks for master workers: {}".format( | |||
[repr(cb) for cb in self.callback_manager.callbacks_master])) | |||
start_time = time.time() | |||
results = {} | |||
if self.n_epochs <= 0: | |||
self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs)) | |||
results['seconds'] = 0. | |||
return results | |||
try: | |||
self.callback_manager.on_train_begin() | |||
self._train() | |||
self.callback_manager.on_train_end() | |||
except BaseException as e: | |||
self.callback_manager.on_exception(e) | |||
if on_exception == 'auto': | |||
if not isinstance(e, (CallbackException, KeyboardInterrupt)): | |||
raise e | |||
else: | |||
self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__)) | |||
elif on_exception == 'raise': | |||
raise e | |||
results['seconds'] = round(time.time() - start_time, 2) | |||
self.logger.info("###### Train finished ######") | |||
self.logger.info('Total train time: {} seconds.'. format(results['seconds'])) | |||
return results | |||
finally: | |||
self.close() | |||
def _train(self): | |||
if self.fp16: | |||
# skip check, done in __init__() | |||
from apex import amp | |||
self.step = 0 | |||
self.epoch = 0 | |||
self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', | |||
leave=False, dynamic_ncols=True, disable=not self.is_master) | |||
pbar = self.pbar | |||
avg_loss = 0 | |||
data_iterator = self.data_iterator | |||
self.model.zero_grad() | |||
for epoch in range(1, self.n_epochs + 1): | |||
self.epoch = epoch | |||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||
# early stopping | |||
self.callback_manager.on_epoch_begin() | |||
for batch_x, batch_y in data_iterator: | |||
self.model.train() | |||
self.step += 1 | |||
_move_dict_value_to_device(batch_x, batch_y, device=self.device) | |||
indices = data_iterator.get_batch_indices() | |||
# negative sampling; replace unknown; re-weight batch_y | |||
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) | |||
prediction = self._data_forward(self.model, batch_x) | |||
# edit prediction | |||
self.callback_manager.on_loss_begin(batch_y, prediction) | |||
loss = self._compute_loss(prediction, batch_y) | |||
avg_loss += loss.item() | |||
# Is loss NaN or inf? requires_grad = False | |||
self.callback_manager.on_backward_begin(loss) | |||
if self.fp16: | |||
with amp.scale_loss(loss, self.optimizer) as scale_loss: | |||
scale_loss.backward() | |||
else: | |||
loss.backward() | |||
self.callback_manager.on_backward_end() | |||
self._update() | |||
self.callback_manager.on_step_end() | |||
if self.step % self.print_every == 0: | |||
avg_loss = float(avg_loss) / self.print_every | |||
print_output = "loss:{:<6.5f}".format(avg_loss) | |||
pbar.update(self.print_every) | |||
pbar.set_postfix_str(print_output) | |||
avg_loss = 0 | |||
self.callback_manager.on_batch_end() | |||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | |||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)): | |||
self.callback_manager.on_valid_begin() | |||
eval_res = self.callback_manager.on_validation() | |||
eval_res = list(filter(lambda x: x is not None, eval_res)) | |||
if len(eval_res): | |||
eval_res, is_better = list(zip(*eval_res)) | |||
else: | |||
eval_res, is_better = None, None | |||
self.callback_manager.on_valid_end( | |||
eval_res, self.metric_key, self.optimizer, is_better) | |||
dist.barrier() | |||
if self.cp_save_path and \ | |||
self.save_every > 0 and \ | |||
self.step % self.save_every == 0: | |||
self.save_check_point() | |||
# ================= mini-batch end ==================== # | |||
if self.save_every < 0 and self.cp_save_path: | |||
self.save_check_point() | |||
# lr decay; early stopping | |||
self.callback_manager.on_epoch_end() | |||
# =============== epochs end =================== # | |||
pbar.close() | |||
self.pbar = None | |||
# ============ tqdm end ============== # | |||
def _update(self): | |||
"""Perform weight update on a model. | |||
""" | |||
if self.step % self.update_every == 0: | |||
self.optimizer.step() | |||
self.model.zero_grad() | |||
def _data_forward(self, network, x): | |||
x = _build_args(self._forward_func, **x) | |||
y = network(**x) | |||
if not isinstance(y, dict): | |||
raise TypeError( | |||
f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.") | |||
return y | |||
def _compute_loss(self, predict, truth): | |||
"""Compute loss given prediction and ground truth. | |||
:param predict: prediction dict, produced by model.forward | |||
:param truth: ground truth dict, produced by batch_y | |||
:return: a scalar | |||
""" | |||
loss = self.losser(predict, truth) | |||
if self.update_every > 1: | |||
loss = loss / self.update_every | |||
return loss.mean() | |||
def save_check_point(self, only_params=False): | |||
# only master save models | |||
if self.is_master: | |||
os.makedirs(self.cp_save_path, exist_ok=True) | |||
path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step)) | |||
self.logger.info("Save checkpoint to {}".format(path)) | |||
model_to_save = self.model.module | |||
if only_params: | |||
model_to_save = model_to_save.state_dict() | |||
torch.save(model_to_save, path) | |||
def close(self): | |||
dist.destroy_process_group() |
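A usage sketch of the new trainer (train_ds, dev_ds, model, loss and metric are placeholders; the module path follows the file layout referenced elsewhere in this diff). The script is meant to be started through torch.distributed.launch.

    # Sketch: run with  python -m torch.distributed.launch --nproc_per_node=N train_script.py
    from fastNLP.core.dist_trainer import DistTrainer

    trainer = DistTrainer(train_data=train_ds, model=model, loss=loss,
                          dev_data=dev_ds, metrics=metric,
                          batch_size_per_gpu=8, n_epochs=3,
                          fp16='O1')   # requires NVIDIA apex; leave '' for fp32
    trainer.train()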
@@ -7,6 +7,7 @@ from typing import Any | |||
from abc import abstractmethod | |||
from copy import deepcopy | |||
from collections import Counter | |||
from .utils import _is_iterable | |||
class SetInputOrTargetException(Exception): | |||
def __init__(self, msg, index=None, field_name=None): | |||
@@ -23,7 +24,8 @@ class AppendToTargetOrInputException(Exception): | |||
self.field_name = field_name # 标示当前field的名称 | |||
class FieldArray: | |||
def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False): | |||
def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False, | |||
use_1st_ins_infer_dim_type=True): | |||
if len(content)==0: | |||
raise RuntimeError("Empty fieldarray is not allowed.") | |||
_content = content | |||
@@ -38,6 +40,7 @@ class FieldArray: | |||
# 根据input的情况设置input,target等 | |||
self._cell_ndim = None # 多少维度 | |||
self.dtype = None # 最内层的element都是什么类型的 | |||
self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) | |||
self._is_input = False | |||
self._is_target = False | |||
@@ -77,7 +80,7 @@ class FieldArray: | |||
if value is True and \ | |||
self._is_target is False and \ | |||
self._ignore_type is False: | |||
self._check_dtype_and_ndim() | |||
self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type) | |||
if value is False and self._is_target is False: | |||
self.dtype = None | |||
self._cell_ndim = None | |||
@@ -95,32 +98,34 @@ class FieldArray: | |||
if value is True and \ | |||
self._is_input is False and \ | |||
self._ignore_type is False: | |||
self._check_dtype_and_ndim() | |||
self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type) | |||
if value is False and self._is_input is False: | |||
self.dtype = None | |||
self._cell_ndim = None | |||
self._is_target = value | |||
def _check_dtype_and_ndim(self): | |||
def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True): | |||
""" | |||
Check that every element of the current content has the same type and the same number of dimensions. On success,
set the _cell_ndim and _ele_type attributes; otherwise raise an error.
:param bool only_check_1st_ins_dim_type: whether to check the type and dim of the first element only
:return: | |||
""" | |||
cell_0 = self.content[0] | |||
index = 0 | |||
try: | |||
type_0, dim_0 = _get_ele_type_and_dim(cell_0) | |||
for cell in self.content[1:]: | |||
index += 1 | |||
type_i, dim_i = _get_ele_type_and_dim(cell) | |||
if type_i!=type_0: | |||
raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." | |||
".".format(type_i, index, type_0)) | |||
if dim_0!=dim_i: | |||
raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " | |||
"dimension:{}.".format(dim_i, index, dim_0)) | |||
if not only_check_1st_ins_dim_type: | |||
for cell in self.content[1:]: | |||
index += 1 | |||
type_i, dim_i = _get_ele_type_and_dim(cell) | |||
if type_i!=type_0: | |||
raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." | |||
".".format(type_i, index, type_0)) | |||
if dim_0!=dim_i: | |||
raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " | |||
"dimension:{}.".format(dim_i, index, dim_0)) | |||
self._cell_ndim = dim_0 | |||
self.dtype = type_0 | |||
except SetInputOrTargetException as e: | |||
@@ -132,7 +137,7 @@ class FieldArray: | |||
:param val: the value to append to this FieldArray.
:return: | |||
""" | |||
if (self._is_target or self._is_input) and self._ignore_type is False: | |||
if (self._is_target or self._is_input) and self._ignore_type is False and not self._use_1st_ins_infer_dim_type: | |||
type_, dim_ = _get_ele_type_and_dim(val) | |||
if self.dtype!=type_: | |||
raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with " | |||
@@ -144,6 +149,14 @@ class FieldArray: | |||
else: | |||
self.content.append(val) | |||
def pop(self, index): | |||
""" | |||
Remove the element at position index from this field.
:param int index: zero-based index of the element.
:return: | |||
""" | |||
self.content.pop(index) | |||
def __getitem__(self, indices): | |||
return self.get(indices, pad=False) | |||
@@ -431,15 +444,6 @@ def _get_ele_type_and_dim(cell:Any, dim=0): | |||
raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") | |||
def _is_iterable(value): | |||
# check whether the value is iterable (duck typing)
try: | |||
iter(value) | |||
return True | |||
except BaseException as e: | |||
return False | |||
class Padder: | |||
""" | |||
Alias: :class:`fastNLP.Padder` :class:`fastNLP.core.field.Padder`
@@ -35,6 +35,13 @@ class Instance(object): | |||
:param Any field: 新增field的内容 | |||
""" | |||
self.fields[field_name] = field | |||
def items(self): | |||
""" | |||
返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value | |||
:return: | |||
""" | |||
return self.fields.items() | |||
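A quick sketch of the new items() helper, which mirrors dict.items() for an Instance (and therefore for a DataSet row via dataset[i].items()).

    # Sketch: iterating over the fields of an Instance.
    from fastNLP import Instance

    ins = Instance(words=["a", "b"], target=1)
    for field_name, field_value in ins.items():
        print(field_name, field_value)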
def __getitem__(self, name): | |||
if name in self.fields: | |||
@@ -28,6 +28,7 @@ from .utils import _check_arg_dict_list | |||
from .utils import _check_function_or_method | |||
from .utils import _get_func_signature | |||
from .utils import seq_len_to_mask | |||
import warnings | |||
class LossBase(object): | |||
@@ -225,8 +226,9 @@ class CrossEntropyLoss(LossBase): | |||
def get_loss(self, pred, target, seq_len=None): | |||
if pred.dim() > 2: | |||
if pred.size(1) != target.size(1): | |||
pred = pred.transpose(1, 2) | |||
if pred.size(1) != target.size(1): # the order may have been swapped
raise RuntimeError("It seems like that your prediction's shape is (batch_size, num_labels, max_len)." | |||
" It should be (batch_size, max_len, num_labels).") | |||
pred = pred.reshape(-1, pred.size(-1)) | |||
target = target.reshape(-1) | |||
if seq_len is not None: | |||
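The transpose added above accepts predictions shaped (batch_size, num_labels, max_len) as well as (batch_size, max_len, num_labels); a standalone sketch of that shape handling using plain torch.

    # Sketch: normalise pred to (batch, max_len, num_labels) before flattening.
    import torch
    import torch.nn.functional as F

    pred = torch.randn(4, 7, 10)              # (batch_size, num_labels=7, max_len=10)
    target = torch.randint(0, 7, (4, 10))     # (batch_size, max_len)
    if pred.dim() > 2 and pred.size(1) != target.size(1):
        pred = pred.transpose(1, 2)           # -> (4, 10, 7)
    loss = F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))
    print(loss.item())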
@@ -624,7 +624,7 @@ class SpanFPreRecMetric(MetricBase): | |||
f, pre, rec = self._compute_f_pre_rec(tp, fn, fp) | |||
f_sum += f | |||
pre_sum += pre | |||
rec_sum + rec | |||
rec_sum += rec | |||
if not self.only_gross and tag != '':  # tag != '' guards against the no-tag case
f_key = 'f-{}'.format(tag) | |||
pre_key = 'pre-{}'.format(tag) | |||
@@ -814,8 +814,8 @@ class ExtractiveQAMetric(MetricBase): | |||
if not self.right_open: | |||
e += 1 | |||
te += 1 | |||
if ts == 0 and te == int(not self.right_open): | |||
if s == 0 and e == int(not self.right_open): | |||
if ts == 0 and te == 1: | |||
if s == 0 and e == 1: | |||
self.no_ans_correct += 1 | |||
self.no2no += 1 | |||
else: | |||
@@ -49,7 +49,7 @@ class NullOptimizer(Optimizer): | |||
super().__init__(None) | |||
def construct_from_pytorch(self, model_params): | |||
pass | |||
return self | |||
def __getattr__(self, item): | |||
def pass_func(*args, **kwargs): | |||
@@ -25,9 +25,9 @@ class Sampler(object): | |||
def __call__(self, data_set): | |||
""" | |||
:param DataSet data_set: the `DataSet` object whose data is to be sampled
:return result: list(int), a sequence of indices; elements of ``data_set`` are taken in the order given by ``result``
"""
:param DataSet data_set: the `DataSet` object whose data is to be sampled
:return result: list(int), a sequence of indices; elements of ``data_set`` are taken in the order given by ``result``
""" | |||
raise NotImplementedError | |||
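A sketch of a custom sampler implementing the __call__ contract documented above: return a list of indices, and the DataSet is read in that order.

    # Sketch: a minimal shuffling sampler built on the fastNLP Sampler base class.
    import random
    from fastNLP import Sampler

    class ShuffleSampler(Sampler):
        def __call__(self, data_set):
            indices = list(range(len(data_set)))
            random.shuffle(indices)
            return indices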
@@ -32,9 +32,16 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation | |||
""" | |||
import time | |||
import torch | |||
import torch.nn as nn | |||
try: | |||
from tqdm.auto import tqdm | |||
except: | |||
from .utils import _pseudo_tqdm as tqdm | |||
from .batch import BatchIter, DataSetIter | |||
from .dataset import DataSet | |||
from .metrics import _prepare_metrics | |||
@@ -47,6 +54,7 @@ from .utils import _get_func_signature | |||
from .utils import _get_model_device | |||
from .utils import _move_model_to_device | |||
from ._parallel_utils import _data_parallel_wrapper | |||
from ._parallel_utils import _model_contains_inner_module | |||
from functools import partial | |||
__all__ = [ | |||
@@ -79,13 +87,12 @@ class Tester(object): | |||
If the model predicts via predict(), evaluation cannot use multiple GPUs (DataParallel); only the model on the first card is used.
:param int verbose: 0 prints nothing; 1 prints the evaluation results.
:param bool use_tqdm: whether to show test progress with tqdm; if False, nothing is displayed.
""" | |||
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1): | |||
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True): | |||
super(Tester, self).__init__() | |||
if not isinstance(data, DataSet): | |||
raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.") | |||
if not isinstance(model, nn.Module): | |||
raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.") | |||
@@ -95,6 +102,7 @@ class Tester(object): | |||
self._model = _move_model_to_device(model, device=device) | |||
self.batch_size = batch_size | |||
self.verbose = verbose | |||
self.use_tqdm = use_tqdm | |||
if isinstance(data, DataSet): | |||
self.data_iterator = DataSetIter( | |||
@@ -106,19 +114,22 @@ class Tester(object): | |||
# check predict | |||
if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \ | |||
(isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and | |||
callable(self._model.module.predict)): | |||
(_model_contains_inner_module(self._model) and hasattr(self._model.module, 'predict') and | |||
callable(self._model.module.predict)): | |||
if isinstance(self._model, nn.DataParallel): | |||
self._predict_func_wrapper = partial(_data_parallel_wrapper('predict', | |||
self._model.device_ids, | |||
self._model.output_device), | |||
network=self._model.module) | |||
self._predict_func = self._model.module.predict # used for argument matching
elif isinstance(self._model, nn.parallel.DistributedDataParallel): | |||
self._predict_func = self._model.module.predict | |||
self._predict_func_wrapper = self._model.module.predict # used for the actual call
else: | |||
self._predict_func = self._model.predict | |||
self._predict_func_wrapper = self._model.predict | |||
else: | |||
if isinstance(self._model, nn.DataParallel): | |||
if _model_contains_inner_module(model): | |||
self._predict_func_wrapper = self._model.forward | |||
self._predict_func = self._model.module.forward | |||
else: | |||
@@ -139,21 +150,39 @@ class Tester(object): | |||
eval_results = {} | |||
try: | |||
with torch.no_grad(): | |||
for batch_x, batch_y in data_iterator: | |||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||
pred_dict = self._data_forward(self._predict_func, batch_x) | |||
if not isinstance(pred_dict, dict): | |||
raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} " | |||
f"must be `dict`, got {type(pred_dict)}.") | |||
if not self.use_tqdm: | |||
from .utils import _pseudo_tqdm as inner_tqdm | |||
else: | |||
inner_tqdm = tqdm | |||
with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar: | |||
pbar.set_description_str(desc="Test") | |||
start_time = time.time() | |||
for batch_x, batch_y in data_iterator: | |||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||
pred_dict = self._data_forward(self._predict_func, batch_x) | |||
if not isinstance(pred_dict, dict): | |||
raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} " | |||
f"must be `dict`, got {type(pred_dict)}.") | |||
for metric in self.metrics: | |||
metric(pred_dict, batch_y) | |||
if self.use_tqdm: | |||
pbar.update() | |||
for metric in self.metrics: | |||
metric(pred_dict, batch_y) | |||
for metric in self.metrics: | |||
eval_result = metric.get_metric() | |||
if not isinstance(eval_result, dict): | |||
raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be " | |||
f"`dict`, got {type(eval_result)}") | |||
metric_name = metric.__class__.__name__ | |||
eval_results[metric_name] = eval_result | |||
eval_result = metric.get_metric() | |||
if not isinstance(eval_result, dict): | |||
raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be " | |||
f"`dict`, got {type(eval_result)}") | |||
metric_name = metric.__class__.__name__ | |||
eval_results[metric_name] = eval_result | |||
end_time = time.time() | |||
test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!' | |||
pbar.write(test_str) | |||
pbar.close() | |||
except _CheckError as e: | |||
prev_func_signature = _get_func_signature(self._predict_func) | |||
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, | |||
@@ -352,6 +352,7 @@ from .utils import _move_dict_value_to_device | |||
from .utils import _get_func_signature | |||
from .utils import _get_model_device | |||
from .utils import _move_model_to_device | |||
from ._parallel_utils import _model_contains_inner_module | |||
class Trainer(object): | |||
@@ -389,8 +390,8 @@ class Trainer(object): | |||
specify which metric is the criterion. Some metrics are better when smaller, e.g. a language model's perplexity; in that case
prefix the key with '-' to indicate that smaller is better during validation (e.g. "-ppl"). Only effective when dev_data is given.
:param int validate_every: evaluate on the dev set every this many steps; if -1, evaluate once at the end of each epoch. Only effective when dev_data is given.
:param str,None save_path: path for saving the model. If None, the model is not saved. If dev_data is None, the model of the last
iteration is saved. Both the parameters and the model structure are saved; even with DataParallel, only the model itself is saved.
:param str,None save_path: path for saving the model; if the path does not exist, the directory is created automatically. If None, the model is not
saved. If dev_data is None, the model of the last iteration is saved. Both the parameters and the model structure are saved; even with DataParallel, only the model itself is saved.
:param bool use_tqdm: whether to show training progress with tqdm; if False, the loss is printed in the terminal.
:param str,int,torch.device,list(int) device: device to load the model onto. Defaults to None, i.e. the Trainer does not manage
where the model is computed. The following inputs are supported:
@@ -421,7 +422,7 @@ class Trainer(object): | |||
num_workers=0, n_epochs=10, print_every=5, | |||
dev_data=None, metrics=None, metric_key=None, | |||
validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False, | |||
callbacks=None, check_code_level=0): | |||
callbacks=None, check_code_level=0, **kwargs): | |||
if prefetch and num_workers==0: | |||
num_workers = 1 | |||
if prefetch: | |||
@@ -430,23 +431,23 @@ class Trainer(object): | |||
super(Trainer, self).__init__() | |||
if not isinstance(model, nn.Module): | |||
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") | |||
# check metrics and dev_data | |||
if (not metrics) and dev_data is not None: | |||
raise ValueError("No metric for dev_data evaluation.") | |||
if metrics and (dev_data is None): | |||
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") | |||
# check update every | |||
assert update_every >= 1, "update_every must be no less than 1." | |||
self.update_every = int(update_every) | |||
# check save_path | |||
if not (save_path is None or isinstance(save_path, str)): | |||
raise ValueError("save_path can only be None or `str`.") | |||
# prepare evaluate | |||
metrics = _prepare_metrics(metrics) | |||
# parse metric_key | |||
# increase_better is True. It means the exp result gets better if the indicator increases. | |||
# It is true by default. | |||
@@ -458,30 +459,69 @@ class Trainer(object): | |||
self.metric_key = None | |||
# prepare loss | |||
losser = _prepare_losser(loss) | |||
# sampler check | |||
if sampler is not None and not isinstance(sampler, Sampler): | |||
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) | |||
if sampler is None: | |||
sampler = RandomSampler() | |||
elif hasattr(sampler, 'set_batch_size'): | |||
sampler.set_batch_size(batch_size) | |||
if isinstance(train_data, BatchIter): | |||
if sampler is not None: | |||
warnings.warn("sampler is ignored when train_data is a BatchIter.") | |||
if num_workers>0: | |||
warnings.warn("num_workers is ignored when train_data is BatchIter.") | |||
if drop_last: | |||
warnings.warn("drop_last is ignored when train_data is BatchIter.") | |||
if isinstance(model, nn.parallel.DistributedDataParallel): # 如果是分布式的 | |||
# device为None | |||
if device is not None: | |||
warnings.warn("device is ignored when model is nn.parallel.DistributedDataParallel.") | |||
device = None | |||
# Sampler要是分布式的 | |||
if sampler is None: | |||
sampler = torch.utils.data.DistributedSampler(train_data) | |||
elif not isinstance(sampler, torch.utils.data.DistributedSampler): | |||
raise TypeError("When using nn.parallel.DistributedDataParallel, " | |||
"sampler must be None or torch.utils.data.DistributedSampler.") | |||
# 不能保存模型 | |||
if save_path: | |||
raise RuntimeError("Saving model in Distributed situation is not allowed right now.") | |||
else: | |||
# sampler check | |||
if sampler is not None and not isinstance(sampler, (Sampler, torch.utils.data.Sampler)): | |||
raise ValueError(f"The type of sampler should be fastNLP.BaseSampler or pytorch's Sampler, got {type(sampler)}") | |||
if sampler is None: | |||
sampler = RandomSampler() | |||
elif hasattr(sampler, 'set_batch_size'): | |||
sampler.set_batch_size(batch_size) | |||
if isinstance(train_data, DataSet): | |||
self.data_iterator = DataSetIter( | |||
dataset=train_data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, drop_last=drop_last) | |||
elif isinstance(train_data, BatchIter): | |||
self.data_iterator = train_data | |||
train_data = train_data.dataset | |||
else: | |||
raise TypeError("train_data type {} is not supported.".format(type(train_data)))
if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): | |||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | |||
metric_key=self.metric_key, check_level=check_code_level, | |||
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) | |||
# _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 | |||
self.model = _move_model_to_device(model, device=device) | |||
if _model_contains_inner_module(self.model): | |||
self._forward_func = self.model.module.forward | |||
else: | |||
self._forward_func = self.model.forward | |||
if check_code_level > -1: | |||
# _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的field名与模型的输入 | |||
# 名是否匹配 | |||
dev_dataset = dev_data | |||
if isinstance(dev_data, BatchIter): | |||
dev_dataset = None | |||
warnings.warn("dev_data is of BatchIter type, ignore validation checking.") | |||
check_batch_size = min(batch_size, DEFAULT_CHECK_BATCH_SIZE) | |||
if isinstance(self.model, nn.DataParallel): | |||
_num_devices = len(self.model.device_ids) | |||
if batch_size//_num_devices>1: # 如果多卡是每个卡可以分多个数据的,则用每个卡给两个sample | |||
check_batch_size = max(len(self.model.device_ids)*2, check_batch_size) | |||
else: | |||
check_batch_size = max(len(self.model.device_ids), check_batch_size) | |||
_check_code(dataset=train_data, model=self.model, losser=losser, forward_func=self._forward_func, metrics=metrics, | |||
dev_data=dev_dataset, metric_key=self.metric_key, check_level=check_code_level, | |||
batch_size=check_batch_size) | |||
self.train_data = train_data | |||
self.dev_data = dev_data # If None, No validation. | |||
@@ -496,8 +536,7 @@ class Trainer(object): | |||
self.best_dev_epoch = None | |||
self.best_dev_step = None | |||
self.best_dev_perf = None | |||
self.n_steps = (len(self.train_data) // self.batch_size + int( | |||
len(self.train_data) % self.batch_size != 0)) * int(drop_last==0) * self.n_epochs | |||
self.n_steps = len(self.data_iterator) * self.n_epochs | |||
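# Note: len(self.data_iterator) is the number of batches per epoch (drop_last is already
# handled by the iterator itself), so the old "// batch_size" remainder arithmetic is not needed.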
if isinstance(optimizer, torch.optim.Optimizer): | |||
self.optimizer = optimizer | |||
@@ -507,22 +546,23 @@ class Trainer(object): | |||
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3) | |||
else: | |||
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) | |||
self.use_tqdm = use_tqdm | |||
self.pbar = None | |||
self.print_every = abs(self.print_every) | |||
self.kwargs = kwargs | |||
if self.dev_data is not None: | |||
self.tester = Tester(model=self.model, | |||
data=self.dev_data, | |||
metrics=self.metrics, | |||
batch_size=self.batch_size, | |||
batch_size=kwargs.get("dev_batch_size", self.batch_size), | |||
device=None, # 由上面的部分处理device | |||
verbose=0) | |||
verbose=0, | |||
use_tqdm=self.use_tqdm) | |||
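# Note: the evaluation batch size can be overridden with Trainer(..., dev_batch_size=N);
# when the kwarg is absent, the Tester above falls back to the training batch_size.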
self.step = 0 | |||
self.start_time = None # start timestamp | |||
self.callback_manager = CallbackManager(env={"trainer": self}, | |||
callbacks=callbacks) | |||
@@ -558,7 +598,7 @@ class Trainer(object): | |||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||
start_time = time.time() | |||
print("training epochs started " + self.start_time, flush=True) | |||
try: | |||
self.callback_manager.on_train_begin() | |||
self._train() | |||
@@ -571,7 +611,7 @@ class Trainer(object): | |||
raise e | |||
elif on_exception == 'raise': | |||
raise e | |||
if self.dev_data is not None and self.best_dev_perf is not None: | |||
print( | |||
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||
@@ -589,21 +629,17 @@ class Trainer(object): | |||
finally: | |||
pass | |||
results['seconds'] = round(time.time() - start_time, 2) | |||
return results | |||
def _train(self): | |||
if not self.use_tqdm: | |||
from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm | |||
from .utils import _pseudo_tqdm as inner_tqdm | |||
else: | |||
inner_tqdm = tqdm | |||
self.step = 0 | |||
self.epoch = 0 | |||
start = time.time() | |||
if isinstance(self.model, nn.DataParallel): | |||
self._forward_func = self.model.module.forward | |||
else: | |||
self._forward_func = self.model.forward | |||
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||
self.pbar = pbar | |||
avg_loss = 0 | |||
@@ -621,21 +657,21 @@ class Trainer(object): | |||
# negative sampling; replace unknown; re-weight batch_y | |||
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) | |||
prediction = self._data_forward(self.model, batch_x) | |||
# edit prediction | |||
self.callback_manager.on_loss_begin(batch_y, prediction) | |||
loss = self._compute_loss(prediction, batch_y).mean() | |||
avg_loss += loss.item() | |||
loss = loss / self.update_every | |||
# Is loss NaN or inf? requires_grad = False | |||
self.callback_manager.on_backward_begin(loss) | |||
self._grad_backward(loss) | |||
self.callback_manager.on_backward_end() | |||
self._update() | |||
self.callback_manager.on_step_end() | |||
if self.step % self.print_every == 0: | |||
avg_loss = float(avg_loss) / self.print_every | |||
if self.use_tqdm: | |||
@@ -649,29 +685,29 @@ class Trainer(object): | |||
pbar.set_postfix_str(print_output) | |||
avg_loss = 0 | |||
self.callback_manager.on_batch_end() | |||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | |||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||
and self.dev_data is not None: | |||
eval_res = self._do_validation(epoch=epoch, step=self.step) | |||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||
eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||
self.n_steps) + \ | |||
self.tester._format_eval_results(eval_res) | |||
pbar.write(eval_str + '\n') | |||
# ================= mini-batch end ==================== # | |||
# lr decay; early stopping | |||
self.callback_manager.on_epoch_end() | |||
# =============== epochs end =================== # | |||
pbar.close() | |||
self.pbar = None | |||
# ============ tqdm end ============== # | |||
def _do_validation(self, epoch, step): | |||
self.callback_manager.on_valid_begin() | |||
res = self.tester.test() | |||
is_better_eval = False | |||
if self._better_eval_result(res): | |||
if self.save_path is not None: | |||
@@ -686,7 +722,7 @@ class Trainer(object): | |||
# get validation results; adjust optimizer | |||
self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval) | |||
return res | |||
def _mode(self, model, is_test=False): | |||
"""Train mode or Test mode. This is for PyTorch currently. | |||
@@ -698,14 +734,14 @@ class Trainer(object): | |||
model.eval() | |||
else: | |||
model.train() | |||
def _update(self): | |||
"""Perform weight update on a model. | |||
""" | |||
if self.step % self.update_every == 0: | |||
self.optimizer.step() | |||
def _data_forward(self, network, x): | |||
x = _build_args(self._forward_func, **x) | |||
y = network(**x) | |||
@@ -713,7 +749,7 @@ class Trainer(object): | |||
raise TypeError( | |||
f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.") | |||
return y | |||
def _grad_backward(self, loss): | |||
"""Compute gradient with link rules. | |||
@@ -724,7 +760,7 @@ class Trainer(object): | |||
if (self.step-1) % self.update_every == 0: | |||
self.model.zero_grad() | |||
loss.backward() | |||
def _compute_loss(self, predict, truth): | |||
"""Compute loss given prediction and ground truth. | |||
@@ -733,7 +769,7 @@ class Trainer(object): | |||
:return: a scalar | |||
""" | |||
return self.losser(predict, truth) | |||
def _save_model(self, model, model_name, only_param=False): | |||
""" 存储不含有显卡信息的state_dict或model | |||
:param model: | |||
@@ -745,7 +781,7 @@ class Trainer(object): | |||
model_path = os.path.join(self.save_path, model_name) | |||
if not os.path.exists(self.save_path): | |||
os.makedirs(self.save_path, exist_ok=True) | |||
if isinstance(model, nn.DataParallel): | |||
if _model_contains_inner_module(model): | |||
model = model.module | |||
if only_param: | |||
state_dict = model.state_dict() | |||
@@ -756,7 +792,7 @@ class Trainer(object): | |||
model.cpu() | |||
torch.save(model, model_path) | |||
model.to(self._model_device) | |||
def _load_model(self, model, model_name, only_param=False): | |||
# 返回bool值指示是否成功reload模型 | |||
if self.save_path is not None: | |||
@@ -765,7 +801,7 @@ class Trainer(object): | |||
states = torch.load(model_path) | |||
else: | |||
states = torch.load(model_path).state_dict() | |||
if isinstance(model, nn.DataParallel): | |||
if _model_contains_inner_module(model): | |||
model.module.load_state_dict(states) | |||
else: | |||
model.load_state_dict(states) | |||
@@ -774,7 +810,7 @@ class Trainer(object): | |||
else: | |||
return False | |||
return True | |||
def _better_eval_result(self, metrics): | |||
"""Check if the current epoch yields better validation results. | |||
@@ -800,6 +836,9 @@ class Trainer(object): | |||
is_better = False | |||
return is_better | |||
@property | |||
def is_master(self): | |||
return True | |||
DEFAULT_CHECK_BATCH_SIZE = 2 | |||
DEFAULT_CHECK_NUM_BATCH = 2 | |||
@@ -821,14 +860,15 @@ def _get_value_info(_dict): | |||
strs.append(_str) | |||
return strs | |||
from numbers import Number | |||
from .batch import _to_tensor | |||
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, | |||
dev_data=None, metric_key=None, | |||
check_level=0): | |||
def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE, | |||
dev_data=None, metric_key=None, check_level=0): | |||
# check get_loss 方法 | |||
model_devcie = _get_model_device(model=model) | |||
model_device = _get_model_device(model=model) | |||
def _iter(): | |||
start_idx = 0 | |||
while start_idx<len(dataset): | |||
@@ -849,7 +889,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
start_idx += batch_size | |||
for batch_count, (batch_x, batch_y) in enumerate(_iter()): | |||
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie) | |||
_move_dict_value_to_device(batch_x, batch_y, device=model_device) | |||
# forward check | |||
if batch_count == 0: | |||
info_str = "" | |||
@@ -868,15 +908,11 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
else: | |||
info_str += 'There is no target field.' | |||
print(info_str) | |||
_check_forward_error(forward_func=model.forward, dataset=dataset, | |||
_check_forward_error(forward_func=forward_func, dataset=dataset, | |||
batch_x=batch_x, check_level=check_level) | |||
if isinstance(model, nn.DataParallel): | |||
forward_func = model.module.forward | |||
else: | |||
forward_func = model.forward | |||
refined_batch_x = _build_args(forward_func, **batch_x) | |||
pred_dict = model(**refined_batch_x) | |||
func_signature = _get_func_signature(model.forward) | |||
func_signature = _get_func_signature(forward_func) | |||
if not isinstance(pred_dict, dict): | |||
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") | |||
@@ -896,7 +932,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
loss.backward() | |||
except _CheckError as e: | |||
# TODO: another error raised if _CheckError caught | |||
pre_func_signature = _get_func_signature(model.forward) | |||
pre_func_signature = _get_func_signature(forward_func) | |||
_check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature, | |||
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, | |||
dataset=dataset, check_level=check_level) | |||
@@ -906,7 +942,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
if dev_data is not None: | |||
tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, | |||
batch_size=batch_size, verbose=-1) | |||
batch_size=batch_size, verbose=-1, use_tqdm=False) | |||
evaluate_results = tester.test() | |||
_check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics) | |||
@@ -4,6 +4,7 @@ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户 | |||
__all__ = [ | |||
"cache_results", | |||
"seq_len_to_mask", | |||
"get_seq_len" | |||
] | |||
import _pickle | |||
@@ -62,7 +63,6 @@ def _prepare_cache_filepath(filepath): | |||
os.makedirs(cache_dir) | |||
# TODO 可以保存下缓存时的参数,如果load的时候发现参数不一致,发出警告。 | |||
def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
""" | |||
别名::class:`fastNLP.cache_results` :class:`fastNLP.core.uitls.cache_results` | |||
@@ -188,50 +188,6 @@ def _save_model(model, model_name, save_dir, only_param=False): | |||
torch.save(model, model_path) | |||
model.to(_model_device) | |||
# def save_pickle(obj, pickle_path, file_name): | |||
# """Save an object into a pickle file. | |||
# | |||
# :param obj: an object | |||
# :param pickle_path: str, the directory where the pickle file is to be saved | |||
# :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl". | |||
# """ | |||
# if not os.path.exists(pickle_path): | |||
# os.mkdir(pickle_path) | |||
# print("make dir {} before saving pickle file".format(pickle_path)) | |||
# with open(os.path.join(pickle_path, file_name), "wb") as f: | |||
# _pickle.dump(obj, f) | |||
# print("{} saved in {}".format(file_name, pickle_path)) | |||
# | |||
# | |||
# def load_pickle(pickle_path, file_name): | |||
# """Load an object from a given pickle file. | |||
# | |||
# :param pickle_path: str, the directory where the pickle file is. | |||
# :param file_name: str, the name of the pickle file. | |||
# :return obj: an object stored in the pickle | |||
# """ | |||
# with open(os.path.join(pickle_path, file_name), "rb") as f: | |||
# obj = _pickle.load(f) | |||
# print("{} loaded from {}".format(file_name, pickle_path)) | |||
# return obj | |||
# | |||
# | |||
# def pickle_exist(pickle_path, pickle_name): | |||
# """Check if a given pickle file exists in the directory. | |||
# | |||
# :param pickle_path: the directory of target pickle file | |||
# :param pickle_name: the filename of target pickle file | |||
# :return: True if file exists else False | |||
# """ | |||
# if not os.path.exists(pickle_path): | |||
# os.makedirs(pickle_path) | |||
# file_name = os.path.join(pickle_path, pickle_name) | |||
# if os.path.exists(file_name): | |||
# return True | |||
# else: | |||
# return False | |||
def _move_model_to_device(model, device): | |||
""" | |||
将model移动到device | |||
@@ -254,8 +210,8 @@ def _move_model_to_device(model, device): | |||
:return: torch.nn.DataParallel or torch.nn.Module | |||
""" | |||
if isinstance(model, torch.nn.parallel.DistributedDataParallel): | |||
raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") | |||
# if isinstance(model, torch.nn.parallel.DistributedDataParallel): | |||
# raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") | |||
if device is None: | |||
if isinstance(model, torch.nn.DataParallel): | |||
@@ -352,7 +308,6 @@ def _map_args(maps: dict, **kwargs): | |||
output.update({name: val}) | |||
for keys in maps.keys(): | |||
if keys not in output.keys(): | |||
# TODO: add UNUSED warning. | |||
pass | |||
return output | |||
@@ -570,18 +525,6 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
else: | |||
_tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.' | |||
suggestions.append(_tmp) | |||
# for _miss in unmapped_missing: | |||
# if _miss in dataset: | |||
# suggestions.append(f"Set `{_miss}` as target.") | |||
# else: | |||
# _tmp = '' | |||
# if check_res.unused: | |||
# _tmp = f"Specify your assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." | |||
# if _tmp: | |||
# _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.' | |||
# else: | |||
# _tmp = f'Provide `{_miss}` in output of {prev_func_signature} or DataSet.' | |||
# suggestions.append(_tmp) | |||
if check_res.duplicated: | |||
errs.append(f"\tduplicated param: {check_res.duplicated}.") | |||
@@ -788,3 +731,23 @@ def iob2bioes(tags: List[str]) -> List[str]: | |||
else: | |||
raise TypeError("Invalid IOB format.") | |||
return new_tags | |||
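# Illustrative conversion (example values only):
#   iob2bioes(['B-PER', 'I-PER', 'O', 'B-LOC'])  ->  ['B-PER', 'E-PER', 'O', 'S-LOC']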
def _is_iterable(value): | |||
# 检查是否是iterable的, duck typing | |||
try: | |||
iter(value) | |||
return True | |||
except BaseException as e: | |||
return False | |||
def get_seq_len(words, pad_value=0): | |||
""" | |||
给定batch_size x max_len的words矩阵,返回句子长度 | |||
:param words: batch_size x max_len的LongTensor
:param pad_value: int, 被视为padding的值, 默认为0
:return: (batch_size,)
""" | |||
mask = words.ne(pad_value) | |||
return mask.sum(dim=-1) |
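# Illustrative check (example values only): non-pad tokens are counted per row.
#   >>> import torch
#   >>> words = torch.LongTensor([[3, 4, 5, 0, 0],
#   ...                           [7, 8, 0, 0, 0]])
#   >>> get_seq_len(words)      # pad_value defaults to 0
#   tensor([3, 2])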
@@ -4,12 +4,12 @@ __all__ = [ | |||
] | |||
from functools import wraps | |||
from collections import Counter, defaultdict | |||
from collections import Counter | |||
from .dataset import DataSet | |||
from .utils import Option | |||
from functools import partial | |||
import numpy as np | |||
from .utils import _is_iterable | |||
class VocabularyOption(Option): | |||
def __init__(self, | |||
@@ -131,11 +131,11 @@ class Vocabulary(object): | |||
""" | |||
在新加入word时,检查_no_create_word的设置。 | |||
:param str, List[str] word: | |||
:param str,List[str] word:
:param bool no_create_entry: | |||
:return: | |||
""" | |||
if isinstance(word, str): | |||
if isinstance(word, str) or not _is_iterable(word): | |||
word = [word] | |||
for w in word: | |||
if no_create_entry and self.word_count.get(w, 0) == self._no_create_word.get(w, 0): | |||
@@ -257,35 +257,45 @@ class Vocabulary(object): | |||
vocab.index_dataset(train_data, dev_data, test_data, field_name='words') | |||
:param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 | |||
:param str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. | |||
目前仅支持 ``str`` , ``List[str]`` , ``List[List[str]]`` | |||
:param str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. | |||
Default: ``None`` | |||
:param list,str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. | |||
目前支持 ``str`` , ``List[str]`` | |||
:param list,str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. | |||
Default: ``None``. | |||
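Example (illustrative only; the dataset and field names below are placeholders)::

    vocab.index_dataset(train_data, dev_data,
                        field_name=['words', 'chars'],
                        new_field_name=['word_index', 'char_index'])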
""" | |||
def index_instance(ins): | |||
def index_instance(field): | |||
""" | |||
有几种情况, str, 1d-list, 2d-list | |||
:param field:
:return: | |||
""" | |||
field = ins[field_name] | |||
if isinstance(field, str): | |||
if isinstance(field, str) or not _is_iterable(field): | |||
return self.to_index(field) | |||
elif isinstance(field, list): | |||
if not isinstance(field[0], list): | |||
else: | |||
if isinstance(field[0], str) or not _is_iterable(field[0]): | |||
return [self.to_index(w) for w in field] | |||
else: | |||
if isinstance(field[0][0], list): | |||
if not isinstance(field[0][0], str) and _is_iterable(field[0][0]): | |||
raise RuntimeError("Only support field with 2 dimensions.") | |||
return [[self.to_index(c) for c in w] for w in field] | |||
if new_field_name is None: | |||
new_field_name = field_name | |||
new_field_name = new_field_name or field_name | |||
if type(new_field_name) == type(field_name): | |||
if isinstance(new_field_name, list): | |||
assert len(new_field_name) == len(field_name), "new_field_name should have the same number of " \
    "elements as field_name."
elif isinstance(new_field_name, str): | |||
field_name = [field_name] | |||
new_field_name = [new_field_name] | |||
else: | |||
raise TypeError("field_name and new_field_name can only be str or List[str].") | |||
for idx, dataset in enumerate(datasets): | |||
if isinstance(dataset, DataSet): | |||
try: | |||
dataset.apply(index_instance, new_field_name=new_field_name) | |||
for f_n, n_f_n in zip(field_name, new_field_name): | |||
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n) | |||
except Exception as e: | |||
print("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||
raise e | |||
@@ -306,9 +316,8 @@ class Vocabulary(object): | |||
:param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 | |||
:param str,List[str] field_name: 可为 ``str`` 或 ``List[str]`` . | |||
构建词典所使用的 field(s), 支持一个或多个field | |||
若有多个 DataSet, 每个DataSet都必须有这些field. | |||
目前仅支持的field结构: ``str`` , ``List[str]`` , ``list[List[str]]`` | |||
构建词典所使用的 field(s), 支持一个或多个field,若有多个 DataSet, 每个DataSet都必须有这些field. 目前支持的field结构 | |||
: ``str`` , ``List[str]`` | |||
:param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认),该选项用在接下来的模型会使用pretrain | |||
的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev | |||
中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。 | |||
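Example (illustrative; dataset names are placeholders): build the vocab from train only, while
letting dev/test words share pretrained vectors without creating their own entries::

    vocab.from_dataset(train_data, field_name='words',
                       no_create_entry_dataset=[dev_data, test_data])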
@@ -326,14 +335,14 @@ class Vocabulary(object): | |||
def construct_vocab(ins, no_create_entry=False): | |||
for fn in field_name: | |||
field = ins[fn] | |||
if isinstance(field, str): | |||
if isinstance(field, str) or not _is_iterable(field): | |||
self.add_word(field, no_create_entry=no_create_entry) | |||
elif isinstance(field, (list, np.ndarray)): | |||
if not isinstance(field[0], (list, np.ndarray)): | |||
else: | |||
if isinstance(field[0], str) or not _is_iterable(field[0]): | |||
for word in field: | |||
self.add_word(word, no_create_entry=no_create_entry) | |||
else: | |||
if isinstance(field[0][0], (list, np.ndarray)): | |||
if not isinstance(field[0][0], str) and _is_iterable(field[0][0]): | |||
raise RuntimeError("Only support field with 2 dimensions.") | |||
for words in field: | |||
for word in words: | |||
@@ -343,8 +352,8 @@ class Vocabulary(object): | |||
if isinstance(dataset, DataSet): | |||
try: | |||
dataset.apply(construct_vocab) | |||
except Exception as e: | |||
print("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||
except BaseException as e: | |||
print("When processing the `{}` dataset, the following error occurred:".format(idx)) | |||
raise e | |||
else: | |||
raise TypeError("Only DataSet type is allowed.") | |||
@@ -367,7 +376,7 @@ class Vocabulary(object): | |||
:return: bool | |||
""" | |||
return word in self._no_create_word | |||
def to_index(self, w): | |||
""" | |||
将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出``ValueError``:: | |||
@@ -10,6 +10,7 @@ __all__ = [ | |||
"StaticEmbedding", | |||
"ElmoEmbedding", | |||
"BertEmbedding", | |||
"BertWordPieceEncoder", | |||
"StackEmbedding", | |||
"LSTMCharEmbedding", | |||
"CNNCharEmbedding", | |||
@@ -20,7 +21,7 @@ __all__ = [ | |||
from .embedding import Embedding | |||
from .static_embedding import StaticEmbedding | |||
from .elmo_embedding import ElmoEmbedding | |||
from .bert_embedding import BertEmbedding | |||
from .bert_embedding import BertEmbedding, BertWordPieceEncoder | |||
from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding | |||
from .stack_embedding import StackEmbedding | |||
from .utils import get_embeddings |
@@ -8,10 +8,10 @@ import numpy as np | |||
from itertools import chain | |||
from ..core.vocabulary import Vocabulary | |||
from ..io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer | |||
from .contextual_embedding import ContextualEmbedding | |||
import warnings | |||
class BertEmbedding(ContextualEmbedding): | |||
""" | |||
@@ -38,8 +38,8 @@ class BertEmbedding(ContextualEmbedding): | |||
:param ~fastNLP.Vocabulary vocab: 词表 | |||
:param str model_dir_or_name: 模型所在目录或者模型的名称。当传入模型所在目录时,目录中应该包含一个词表文件(以.txt作为后缀名), | |||
权重文件(以.bin作为文件后缀名), 配置文件(以.json作为后缀名)。 | |||
:param str layers: 输出embedding表示来自于哪些层,不同层的结果按照layers中的顺序在最后一维concat起来。以','隔开层数,可以以负数 | |||
去索引倒数几层。 | |||
:param str layers: 输出embedding表示来自于哪些层,不同层的结果按照layers中的顺序在最后一维concat起来。以','隔开层数,层的序号是 | |||
从0开始,可以以负数去索引倒数几层。 | |||
:param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces | |||
中计算得到它对应的表示。支持 ``last`` , ``first`` , ``avg`` , ``max``。 | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
@@ -47,27 +47,35 @@ class BertEmbedding(ContextualEmbedding): | |||
:param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样 | |||
会使得word embedding的结果比输入的结果长两个token。如果该值为True,则在使用 :class::StackEmbedding 可能会与其它类型的 | |||
embedding长度不匹配。 | |||
:param bool pooled_cls: 返回的[CLS]是否使用预训练中的BertPool映射一下,仅在include_cls_sep时有效。如果下游任务只取[CLS]做预测, | |||
一般该值为True。 | |||
:param bool requires_grad: 是否需要gradient以更新Bert的权重。 | |||
:param bool auto_truncate: 当句子words拆分为word pieces长度超过bert最大允许长度(一般为512), 自动截掉拆分后的超过510个 | |||
word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS] | |||
来进行分类的任务将auto_truncate置为True。 | |||
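Example (illustrative sketch; assumes the 'en-base-uncased' weights can be downloaded or are
already cached, and the toy sentence is a placeholder)::

    vocab = Vocabulary()
    vocab.add_word_lst("this is a demo .".split())
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='-1',
                          pool_method='first', include_cls_sep=False)
    words = torch.LongTensor([[vocab.to_index(w) for w in "this is a demo .".split()]])
    embed(words).size()  # -> torch.Size([1, 5, 768]) for one layer of a base model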
""" | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1', | |||
pool_method: str='first', word_dropout=0, dropout=0, requires_grad: bool=False, | |||
include_cls_sep: bool=False): | |||
pool_method: str='first', word_dropout=0, dropout=0, include_cls_sep: bool=False, | |||
pooled_cls=True, requires_grad: bool=False, auto_truncate:bool=False): | |||
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
# 根据model_dir_or_name检查是否存在并下载 | |||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
PRETRAIN_URL = _get_base_url('bert') | |||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url) | |||
model_url = _get_embedding_url('bert', model_dir_or_name.lower()) | |||
model_dir = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_dir = model_dir_or_name | |||
model_dir = os.path.expanduser(os.path.abspath(model_dir_or_name)) | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
self._word_sep_index = None | |||
if '[SEP]' in vocab: | |||
self._word_sep_index = vocab['[SEP]'] | |||
self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers, | |||
pool_method=pool_method, include_cls_sep=include_cls_sep) | |||
pool_method=pool_method, include_cls_sep=include_cls_sep, | |||
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2) | |||
self.requires_grad = requires_grad | |||
self._embed_size = len(self.model.layers)*self.model.encoder.hidden_size | |||
@@ -83,7 +91,11 @@ class BertEmbedding(ContextualEmbedding): | |||
:param torch.LongTensor words: [batch_size, max_len] | |||
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) | |||
""" | |||
if self._word_sep_index: # 不能drop sep | |||
sep_mask = words.eq(self._word_sep_index) | |||
words = self.drop_word(words) | |||
if self._word_sep_index: | |||
words.masked_fill_(sep_mask, self._word_sep_index) | |||
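# Note: the mask captured before drop_word puts the [SEP] ids back afterwards, so word
# dropout never replaces the separator token with unk.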
outputs = self._get_sent_reprs(words) | |||
if outputs is not None: | |||
return self.dropout(outputs)
@@ -120,24 +132,24 @@ class BertWordPieceEncoder(nn.Module): | |||
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` | |||
:param str layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | |||
:param bool pooled_cls: 返回的句子开头的[CLS]是否使用预训练中的BertPool映射一下,仅在include_cls_sep时有效。如果下游任务只取 | |||
[CLS]做预测,一般该值为True。 | |||
:param bool requires_grad: 是否需要gradient。 | |||
""" | |||
def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1', | |||
requires_grad: bool=False): | |||
pooled_cls: bool = False, requires_grad: bool=False): | |||
super().__init__() | |||
PRETRAIN_URL = _get_base_url('bert') | |||
if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: | |||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url) | |||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
model_url = _get_embedding_url('bert', model_dir_or_name.lower()) | |||
model_dir = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isdir(model_dir_or_name): | |||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_dir = model_dir_or_name | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers) | |||
self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls) | |||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||
self.requires_grad = requires_grad | |||
@@ -162,16 +174,25 @@ class BertWordPieceEncoder(nn.Module): | |||
def embed_size(self): | |||
return self._embed_size | |||
def index_datasets(self, *datasets, field_name): | |||
@property | |||
def embedding_dim(self): | |||
return self._embed_size | |||
@property | |||
def num_embedding(self): | |||
return self.model.encoder.config.vocab_size | |||
def index_datasets(self, *datasets, field_name, add_cls_sep=True): | |||
""" | |||
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是 | |||
[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。 | |||
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了 | |||
bert的pad value。 | |||
:param datasets: DataSet对象 | |||
:param field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。 | |||
:param ~fastNLP.DataSet datasets: DataSet对象 | |||
:param str field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。 | |||
:param bool add_cls_sep: 如果首尾不是[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP]。 | |||
:return: | |||
""" | |||
self.model.index_dataset(*datasets, field_name=field_name) | |||
self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep) | |||
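# Illustrative call (dataset/field names are placeholders); afterwards each DataSet gains a
# 'word_pieces' input column padded with bert's pad id:
#   encoder.index_datasets(train_ds, dev_ds, field_name='raw_words', add_cls_sep=True)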
def forward(self, word_pieces, token_type_ids=None): | |||
""" | |||
@@ -188,11 +209,13 @@ class BertWordPieceEncoder(nn.Module): | |||
class _WordBertModel(nn.Module): | |||
def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first', include_cls_sep:bool=False): | |||
def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first', | |||
include_cls_sep:bool=False, pooled_cls:bool=False, auto_truncate:bool=False, min_freq=2): | |||
super().__init__() | |||
self.tokenzier = BertTokenizer.from_pretrained(model_dir) | |||
self.encoder = BertModel.from_pretrained(model_dir) | |||
self._max_position_embeddings = self.encoder.config.max_position_embeddings | |||
# 检查encoder_layer_number是否合理 | |||
encoder_layer_number = len(self.encoder.encoder.layer) | |||
self.layers = list(map(int, layers.split(','))) | |||
@@ -207,12 +230,21 @@ class _WordBertModel(nn.Module): | |||
assert pool_method in ('avg', 'max', 'first', 'last') | |||
self.pool_method = pool_method | |||
self.include_cls_sep = include_cls_sep | |||
self.pooled_cls = pooled_cls | |||
self.auto_truncate = auto_truncate | |||
# 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP] | |||
print("Start to generate word pieces for words.")
# 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 | |||
word_piece_dict = {'[CLS]':1, '[SEP]':1} # 用到的word_piece以及新增的 | |||
found_count = 0 | |||
self._has_sep_in_vocab = '[SEP]' in vocab # 用来判断传入的数据是否需要生成token_ids | |||
if '[sep]' in vocab: | |||
warnings.warn("Lower-cased [sep] detected; it cannot be correctly recognized as [SEP] by BertEmbedding.")
if "[CLS]" in vocab: | |||
warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CLS] and [SEP] to the beginning "
    "and end of the input automatically, so make sure you don't add [CLS] and [SEP] at the "
    "beginning and end yourself.")
for word, index in vocab: | |||
if index == vocab.padding_idx: # pad是个特殊的符号 | |||
word = '[PAD]' | |||
@@ -222,7 +254,8 @@ class _WordBertModel(nn.Module): | |||
if len(word_pieces)==1: | |||
if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 | |||
if index!=vocab.unknown_idx and word_pieces[0]=='[UNK]': # 说明这个词不在原始的word里面 | |||
word_piece_dict[word] = 1 # 新增一个值 | |||
if vocab.word_count[word]>=min_freq and not vocab._is_word_no_create_entry(word):  # 出现次数不低于min_freq才新增
word_piece_dict[word] = 1 # 新增一个值 | |||
continue | |||
for word_piece in word_pieces: | |||
word_piece_dict[word_piece] = 1 | |||
@@ -258,7 +291,7 @@ class _WordBertModel(nn.Module): | |||
print("Found (or segmented into word pieces) {} words out of {}.".format(found_count, len(vocab)))
self._cls_index = self.tokenzier.vocab['[CLS]'] | |||
self._sep_index = self.tokenzier.vocab['[SEP]'] | |||
self._pad_index = vocab.padding_idx | |||
self._word_pad_index = vocab.padding_idx | |||
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # 需要用于生成word_piece | |||
self.word_to_wordpieces = np.array(word_to_wordpieces) | |||
self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False) | |||
@@ -270,29 +303,50 @@ class _WordBertModel(nn.Module): | |||
:param words: torch.LongTensor, batch_size x max_len | |||
:return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size | |||
""" | |||
batch_size, max_word_len = words.size() | |||
seq_len = words.ne(self._pad_index).sum(dim=-1) | |||
batch_word_pieces_length = self.word_pieces_lengths[words] # batch_size x max_len | |||
word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) | |||
max_word_piece_length = word_pieces_lengths.max().item() | |||
# +2是由于需要加入[CLS]与[SEP] | |||
word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index) | |||
word_pieces[:, 0].fill_(self._cls_index) | |||
batch_indexes = torch.arange(batch_size).to(words) | |||
word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index | |||
attn_masks = torch.zeros_like(word_pieces) | |||
# 1. 获取words的word_pieces的id,以及对应的span范围 | |||
word_indexes = words.tolist() | |||
for i in range(batch_size): | |||
word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]])) | |||
word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i) | |||
attn_masks[i, :len(word_pieces_i)+2].fill_(1) | |||
# TODO 截掉长度超过的部分。 | |||
with torch.no_grad(): | |||
batch_size, max_word_len = words.size() | |||
word_mask = words.ne(self._word_pad_index) # 为1的地方有word | |||
seq_len = word_mask.sum(dim=-1) | |||
batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(0), 0) # batch_size x max_len | |||
word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size | |||
word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item() # 表示word piece的长度(包括padding) | |||
if word_piece_length+2>self._max_position_embeddings: | |||
if self.auto_truncate: | |||
word_pieces_lengths = word_pieces_lengths.masked_fill(word_pieces_lengths+2>self._max_position_embeddings, | |||
self._max_position_embeddings-2) | |||
else: | |||
raise RuntimeError("After split words into word pieces, the lengths of word pieces are longer than the " | |||
f"maximum allowed sequence length:{self._max_position_embeddings} of bert.") | |||
# +2是由于需要加入[CLS]与[SEP] | |||
word_pieces = words.new_full((batch_size, min(word_piece_length+2, self._max_position_embeddings)), | |||
fill_value=self._wordpiece_pad_index) | |||
attn_masks = torch.zeros_like(word_pieces) | |||
# 1. 获取words的word_pieces的id,以及对应的span范围 | |||
word_indexes = words.cpu().numpy() | |||
for i in range(batch_size): | |||
word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]])) | |||
if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2: | |||
word_pieces_i = word_pieces_i[:self._max_position_embeddings-2] | |||
word_pieces[i, 1:word_pieces_lengths[i]+1] = torch.LongTensor(word_pieces_i) | |||
attn_masks[i, :word_pieces_lengths[i]+2].fill_(1) | |||
# 添加[cls]和[sep] | |||
word_pieces[:, 0].fill_(self._cls_index) | |||
batch_indexes = torch.arange(batch_size).to(words) | |||
word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index | |||
if self._has_sep_in_vocab:  # 只有当[SEP]出现在vocab中时,才需要生成token_type_ids
sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len | |||
sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) | |||
token_type_ids = sep_mask_cumsum.fmod(2) | |||
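# Note: flip+cumsum+flip counts, for every position, how many [SEP] tokens occur at or after
# it, so fmod(2) yields alternating 0/1 segment ids; the eq(0) correction below only fires
# when the parity starts at 1, guaranteeing that the first segment is labeled 0.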
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 | |||
token_type_ids = token_type_ids.eq(0).float() | |||
else: | |||
token_type_ids = torch.zeros_like(word_pieces) | |||
# 2. 获取hidden的结果,根据word_pieces进行对应的pool计算 | |||
# all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] | |||
bert_outputs, _ = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, | |||
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, | |||
output_all_encoded_layers=True) | |||
# output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size | |||
# output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size | |||
if self.include_cls_sep: | |||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, | |||
@@ -306,6 +360,12 @@ class _WordBertModel(nn.Module): | |||
batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len | |||
for l_index, l in enumerate(self.layers): | |||
output_layer = bert_outputs[l] | |||
real_word_piece_length = output_layer.size(1) - 2 | |||
if word_piece_length > real_word_piece_length: # 如果实际上是截取出来的 | |||
paddings = output_layer.new_zeros(batch_size, | |||
word_piece_length-real_word_piece_length, | |||
output_layer.size(2)) | |||
output_layer = torch.cat((output_layer, paddings), dim=1).contiguous() | |||
# 从word_piece collapse到word的表示 | |||
truncate_output_layer = output_layer[:, 1:-1] # 删除[CLS]与[SEP] batch_size x len x hidden_size | |||
outputs_seq_len = seq_len + s_shift | |||
@@ -328,7 +388,10 @@ class _WordBertModel(nn.Module): | |||
start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j+1] | |||
outputs[l_index, i, j+s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2) | |||
if self.include_cls_sep: | |||
outputs[l_index, :, 0] = output_layer[:, 0] | |||
if l in (len(bert_outputs)-1, -1) and self.pooled_cls: | |||
outputs[l_index, :, 0] = pooled_cls | |||
else: | |||
outputs[l_index, :, 0] = output_layer[:, 0] | |||
outputs[l_index, batch_indexes, seq_len+s_shift] = output_layer[batch_indexes, seq_len+s_shift] | |||
# 3. 最终的embedding结果 | |||
return outputs | |||
@@ -95,7 +95,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||
for i in range(len(kernel_sizes))]) | |||
self._embed_size = embed_size | |||
self.fc = nn.Linear(sum(filter_nums), embed_size) | |||
self.init_param() | |||
self.reset_parameters() | |||
def forward(self, words): | |||
""" | |||
@@ -152,7 +152,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||
continue | |||
param.requires_grad = value | |||
def init_param(self): | |||
def reset_parameters(self): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能reset | |||
continue | |||
@@ -1,4 +1,3 @@ | |||
from abc import abstractmethod | |||
import torch | |||
@@ -9,6 +8,10 @@ from ..core.sampler import SequentialSampler | |||
from ..core.utils import _move_model_to_device, _get_model_device | |||
from .embedding import TokenEmbedding | |||
__all__ = [ | |||
"ContextualEmbedding" | |||
] | |||
class ContextualEmbedding(TokenEmbedding): | |||
def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0): | |||
@@ -8,7 +8,7 @@ import json | |||
import codecs | |||
from ..core.vocabulary import Vocabulary | |||
from ..io.file_utils import cached_path, _get_base_url, PRETRAINED_ELMO_MODEL_DIR | |||
from ..io.file_utils import cached_path, _get_embedding_url, PRETRAINED_ELMO_MODEL_DIR | |||
from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder | |||
from .contextual_embedding import ContextualEmbedding | |||
@@ -40,7 +40,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
:param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo所在文件夹,该文件夹下面应该有两个文件, | |||
其中一个是以json为后缀的配置文件,另一个是以pkl为后缀的权重文件;第二种是传入ELMo版本的名称,将自动查看缓存中是否存在该模型, | |||
没有的话将自动下载并缓存。 | |||
:param layers: str, 指定返回的层数, 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果 | |||
:param layers: str, 指定返回的层数(从0开始), 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果 | |||
按照这个顺序concat起来,默认为'2'。'mix'会使用可学习的权重结合不同层的表示(权重是否可训练与requires_grad保持一致, | |||
初始化权重对三层结果进行mean-pooling, 可以通过ElmoEmbedding.set_mix_weights_requires_grad()方法只将mix weights设置为可学习。) | |||
:param requires_grad: bool, 该层是否需要gradient, 默认为False. | |||
@@ -56,10 +56,8 @@ class ElmoEmbedding(ContextualEmbedding): | |||
# 根据model_dir_or_name检查是否存在并下载 | |||
if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: | |||
PRETRAIN_URL = _get_base_url('elmo') | |||
model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] | |||
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url) | |||
model_url = _get_embedding_url('elmo', model_dir_or_name.lower()) | |||
model_dir = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_dir = model_dir_or_name | |||
@@ -185,8 +183,8 @@ class _ElmoModel(nn.Module): | |||
raise Exception(f"Multiple config files(*.json) or weight files(*.hdf5) detected in {model_dir}.") | |||
elif config_count == 0 or weight_count == 0: | |||
raise Exception(f"No config file or weight file found in {model_dir}") | |||
config = json.load(open(os.path.join(model_dir, config_file), 'r')) | |||
with open(os.path.join(model_dir, config_file), 'r') as config_f: | |||
config = json.load(config_f) | |||
self.weight_file = os.path.join(model_dir, weight_file) | |||
self.config = config | |||
@@ -42,7 +42,12 @@ class Embedding(nn.Module): | |||
self.dropout = nn.Dropout(dropout) | |||
if not isinstance(self.embed, TokenEmbedding): | |||
self._embed_size = self.embed.weight.size(1) | |||
if hasattr(self.embed, 'embed_size'): | |||
self._embed_size = self.embed.embed_size | |||
elif hasattr(self.embed, 'embedding_dim'): | |||
self._embed_size = self.embed.embedding_dim | |||
else: | |||
self._embed_size = self.embed.weight.size(1) | |||
if word_dropout>0 and not isinstance(unk_index, int): | |||
raise ValueError("When drop word is set, you need to pass in the unk_index.") | |||
else: | |||
@@ -7,9 +7,11 @@ import numpy as np | |||
import warnings | |||
from ..core.vocabulary import Vocabulary | |||
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_base_url, cached_path | |||
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | |||
from .embedding import TokenEmbedding | |||
from ..modules.utils import _get_file_name_base_on_postfix | |||
from copy import deepcopy | |||
from collections import defaultdict | |||
class StaticEmbedding(TokenEmbedding): | |||
""" | |||
@@ -45,15 +47,16 @@ class StaticEmbedding(TokenEmbedding): | |||
如果输入为None则使用embedding_dim的维度随机初始化一个embedding。 | |||
:param int embedding_dim: 随机初始化的embedding的维度,仅在model_dir_or_name为None时有效。 | |||
:param bool requires_grad: 是否需要gradient. 默认为True | |||
:param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对象。 | |||
:param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对象。
:param bool lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独 | |||
为大写的词语开辟一个vector表示,则将lower设置为False。 | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
:param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
:param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。 | |||
:param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。 | |||
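Example (illustrative sketch with a randomly initialized embedding, so no pretrained file is
needed; the toy words are placeholders)::

    vocab = Vocabulary()
    vocab.add_word_lst("the quick brown fox".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=50, lower=True)
    embed(torch.LongTensor([[vocab.to_index('fox')]])).size()  # -> torch.Size([1, 1, 50])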
""" | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', embedding_dim=100, requires_grad: bool=True, | |||
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False): | |||
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs): | |||
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
# 得到cache_path | |||
@@ -62,10 +65,8 @@ class StaticEmbedding(TokenEmbedding): | |||
embedding_dim = int(embedding_dim) | |||
model_path = None | |||
elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | |||
PRETRAIN_URL = _get_base_url('static') | |||
model_name = PRETRAIN_STATIC_FILES[model_dir_or_name] | |||
model_url = PRETRAIN_URL + model_name | |||
model_path = cached_path(model_url) | |||
model_url = _get_embedding_url('static', model_dir_or_name.lower()) | |||
model_path = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isfile(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_path = model_dir_or_name | |||
@@ -74,27 +75,49 @@ class StaticEmbedding(TokenEmbedding): | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
# 根据min_freq缩小vocab | |||
truncate_vocab = (vocab.min_freq is None and min_freq>1) or (vocab.min_freq and vocab.min_freq<min_freq) | |||
if truncate_vocab: | |||
truncated_vocab = deepcopy(vocab) | |||
truncated_vocab.min_freq = min_freq | |||
truncated_vocab.word2idx = None | |||
if lower: # 如果有lower,将大小写的的freq需要同时考虑到 | |||
lowered_word_count = defaultdict(int) | |||
for word, count in truncated_vocab.word_count.items(): | |||
lowered_word_count[word.lower()] += count | |||
for word in truncated_vocab.word_count.keys(): | |||
word_count = truncated_vocab.word_count[word] | |||
if lowered_word_count[word.lower()]>=min_freq and word_count<min_freq: | |||
truncated_vocab.add_word_lst([word]*(min_freq-word_count), | |||
no_create_entry=truncated_vocab._is_word_no_create_entry(word)) | |||
# 只限制在train里面的词语使用min_freq筛选 | |||
if kwargs.get('only_train_min_freq', False): | |||
for word in truncated_vocab.word_count.keys(): | |||
if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word]<min_freq: | |||
truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]), | |||
no_create_entry=True) | |||
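# Note: the two loops above pad word counts up to min_freq so that (a) with lower=True a word
# whose lowercased form is frequent enough is not dropped, and (b) with only_train_min_freq the
# frequency filter only removes words coming from the training data (no_create_entry words
# are always kept).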
truncated_vocab.build_vocab() | |||
truncated_words_to_words = torch.arange(len(vocab)).long() | |||
for word, index in vocab: | |||
truncated_words_to_words[index] = truncated_vocab.to_index(word) | |||
print(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.") | |||
vocab = truncated_vocab | |||
# 读取embedding | |||
if lower: | |||
lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown) | |||
for word, index in vocab: | |||
if not vocab._is_word_no_create_entry(word): | |||
if vocab._is_word_no_create_entry(word): | |||
lowered_vocab.add_word(word.lower(), no_create_entry=True) | |||
else: | |||
lowered_vocab.add_word(word.lower()) # 先加入需要创建entry的 | |||
for word in vocab._no_create_word.keys(): # 不需要创建entry的 | |||
if word in vocab: | |||
lowered_word = word.lower() | |||
if lowered_word not in lowered_vocab.word_count: | |||
lowered_vocab.add_word(lowered_word) | |||
lowered_vocab._no_create_word[lowered_word] += 1 | |||
print(f"All word in vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} unique lowered " | |||
f"words.") | |||
print(f"All words in the vocab have been lowered before looking up pretrained vectors. There are {len(vocab)} "
    f"words, {len(lowered_vocab)} unique lowered words.")
if model_path: | |||
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method) | |||
else: | |||
embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method) | |||
# 需要适配一下 | |||
if not hasattr(self, 'words_to_words'): | |||
self.words_to_words = torch.arange(len(lowered_vocab)).long()
if lowered_vocab.unknown: | |||
unknown_idx = lowered_vocab.unknown_idx | |||
else: | |||
@@ -104,10 +127,11 @@ class StaticEmbedding(TokenEmbedding): | |||
for word, index in vocab: | |||
if word not in lowered_vocab: | |||
word = word.lower() | |||
if lowered_vocab._is_word_no_create_entry(word): # 如果不需要创建entry,已经默认unknown了 | |||
continue | |||
if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word): | |||
continue # 如果不需要创建entry,已经默认unknown了 | |||
words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)] | |||
self.words_to_words = words_to_words | |||
self._word_unk_index = lowered_vocab.unknown_idx # 替换一下unknown的index | |||
else: | |||
if model_path: | |||
embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method) | |||
@@ -115,6 +139,14 @@ class StaticEmbedding(TokenEmbedding): | |||
embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method) | |||
if normalize: | |||
embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12) | |||
if truncate_vocab: | |||
for i in range(len(truncated_words_to_words)): | |||
index_in_truncated_vocab = truncated_words_to_words[i] | |||
truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab] | |||
del self.words_to_words | |||
self.words_to_words = nn.Parameter(truncated_words_to_words, requires_grad=False) | |||
self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], | |||
padding_idx=vocab.padding_idx, | |||
max_norm=None, norm_type=2, scale_grad_by_freq=False, | |||
@@ -191,6 +223,10 @@ class StaticEmbedding(TokenEmbedding): | |||
dim = len(parts) - 1 | |||
f.seek(0) | |||
matrix = {} | |||
if vocab.padding: | |||
matrix[vocab.padding_idx] = torch.zeros(dim) | |||
if vocab.unknown: | |||
matrix[vocab.unknown_idx] = torch.zeros(dim) | |||
found_count = 0 | |||
for idx, line in enumerate(f, start_idx): | |||
try: | |||
@@ -219,26 +255,21 @@ class StaticEmbedding(TokenEmbedding): | |||
matrix[index] = matrix[vocab.unknown_idx] | |||
else: | |||
matrix[index] = None | |||
# matrix中代表是需要建立entry的词 | |||
vectors = self._randomly_init_embed(len(matrix), dim, init_method) | |||
if vocab._no_create_word_length>0: | |||
if vocab.unknown is None: # 创建一个专门的unknown | |||
unknown_idx = len(matrix) | |||
vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous() | |||
else: | |||
unknown_idx = vocab.unknown_idx | |||
words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), | |||
requires_grad=False) | |||
for order, (index, vec) in enumerate(matrix.items()): | |||
if vec is not None: | |||
vectors[order] = vec | |||
words_to_words[index] = order | |||
self.words_to_words = words_to_words | |||
if vocab.unknown is None: # 创建一个专门的unknown | |||
unknown_idx = len(matrix) | |||
vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous() | |||
else: | |||
for index, vec in matrix.items(): | |||
if vec is not None: | |||
vectors[index] = vec | |||
unknown_idx = vocab.unknown_idx | |||
self.words_to_words = nn.Parameter(torch.full((len(vocab), ), fill_value=unknown_idx).long(), | |||
requires_grad=False) | |||
for index, (index_in_vocab, vec) in enumerate(matrix.items()): | |||
if vec is not None: | |||
vectors[index] = vec | |||
self.words_to_words[index_in_vocab] = index | |||
return vectors | |||
@@ -3,9 +3,9 @@ | |||
1. 用于读入 embedding 的 :doc:`EmbedLoader <fastNLP.io.embed_loader>` 类, | |||
2. 用于读入不同格式数据的 :doc:`DataSetLoader <fastNLP.io.dataset_loader>` 类 | |||
2. 用于读入不同格式数据的 :doc:`Loader <fastNLP.io.loader>` 类 | |||
3. 用于读入不同数据集并进行预处理的 :doc:`DataLoader <fastNLP.io.data_loader>` 类 | |||
3. 用于处理读入数据的 :doc:`Pipe <fastNLP.io.pipe>` 类 | |||
4. 用于保存和载入模型的类, 参考 :doc:`model_io文档</fastNLP.io.model_io>` | |||
@@ -14,27 +14,56 @@ | |||
__all__ = [ | |||
'EmbedLoader', | |||
'CSVLoader', | |||
'JsonLoader', | |||
'DataBundle', | |||
'DataSetLoader', | |||
'ConllLoader', | |||
'Conll2003Loader', | |||
'YelpLoader', | |||
'YelpFullLoader', | |||
'YelpPolarityLoader', | |||
'IMDBLoader', | |||
'MatchingLoader', | |||
'SNLILoader', | |||
'MNLILoader', | |||
'MTL16Loader', | |||
'PeopleDailyCorpusLoader', | |||
'QNLILoader', | |||
'QuoraLoader', | |||
'RTELoader', | |||
'SSTLoader', | |||
'SST2Loader', | |||
'YelpLoader', | |||
'ConllLoader', | |||
'Conll2003Loader', | |||
'Conll2003NERLoader', | |||
'OntoNotesNERLoader', | |||
'CTBLoader', | |||
'Loader', | |||
'CSVLoader', | |||
'JsonLoader', | |||
'CWSLoader', | |||
'MNLILoader', | |||
"QuoraLoader", | |||
"SNLILoader", | |||
"QNLILoader", | |||
"RTELoader", | |||
"YelpFullPipe", | |||
"YelpPolarityPipe", | |||
"SSTPipe", | |||
"SST2Pipe", | |||
"IMDBPipe", | |||
"Conll2003NERPipe", | |||
"OntoNotesNERPipe", | |||
"MatchingBertPipe", | |||
"RTEBertPipe", | |||
"SNLIBertPipe", | |||
"QuoraBertPipe", | |||
"QNLIBertPipe", | |||
"MNLIBertPipe", | |||
"MatchingPipe", | |||
"RTEPipe", | |||
"SNLIPipe", | |||
"QuoraPipe", | |||
"QNLIPipe", | |||
"MNLIPipe", | |||
'ModelLoader', | |||
'ModelSaver', | |||
] | |||
@@ -44,4 +73,5 @@ from .base_loader import DataBundle, DataSetLoader | |||
from .dataset_loader import CSVLoader, JsonLoader | |||
from .model_io import ModelLoader, ModelSaver | |||
from .data_loader import * | |||
from .loader import * | |||
from .pipe import * |
@@ -5,10 +5,10 @@ __all__ = [ | |||
] | |||
import _pickle as pickle | |||
import os | |||
from typing import Union, Dict | |||
import os | |||
from ..core.dataset import DataSet | |||
from ..core.vocabulary import Vocabulary | |||
class BaseLoader(object): | |||
@@ -111,7 +111,10 @@ def _uncompress(src, dst): | |||
class DataBundle: | |||
""" | |||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。 | |||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种 | |||
DataSetLoader的load函数生成,可以通过以下的方法获取里面的内容 | |||
Example:: | |||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | |||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | |||
@@ -121,6 +124,88 @@ class DataBundle: | |||
self.vocabs = vocabs or {} | |||
self.datasets = datasets or {} | |||
def set_vocab(self, vocab, field_name): | |||
""" | |||
向DataBundle中增加vocab
:param ~fastNLP.Vocabulary vocab: 词表 | |||
:param str field_name: 这个vocab对应的field名称 | |||
:return: | |||
""" | |||
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
self.vocabs[field_name] = vocab | |||
def set_dataset(self, dataset, name): | |||
""" | |||
:param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet | |||
:param str name: dataset的名称 | |||
:return: | |||
""" | |||
self.datasets[name] = dataset | |||
def get_dataset(self, name:str)->DataSet: | |||
""" | |||
获取名为name的dataset | |||
:param str name: dataset的名称,一般为'train', 'dev', 'test' | |||
:return: DataSet | |||
""" | |||
return self.datasets[name] | |||
def get_vocab(self, field_name:str)->Vocabulary: | |||
""" | |||
获取field名为field_name对应的vocab | |||
:param str field_name: 名称 | |||
:return: Vocabulary | |||
""" | |||
return self.vocabs[field_name] | |||
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_field=True): | |||
""" | |||
将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: | |||
data_bundle.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True | |||
data_bundle.set_input('words', flag=False) # 将words这个field的input属性设置为False | |||
:param str field_names: field的名称 | |||
:param bool flag: 将field_name的input状态设置为flag | |||
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度和类型,而是直接使用第一
行的数据推断本列数据的类型和维度。
:param bool ignore_miss_field: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错 | |||
""" | |||
for field_name in field_names: | |||
for name, dataset in self.datasets.items(): | |||
if not ignore_miss_field and not dataset.has_field(field_name): | |||
raise KeyError(f"Field:{field_name} was not found in DataSet:{name}") | |||
if not dataset.has_field(field_name): | |||
continue | |||
else: | |||
dataset.set_input(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) | |||
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_field=True): | |||
""" | |||
将field_names中的field设置为target, 对data_bundle中所有的dataset执行该操作:: | |||
data_bundle.set_target('target', 'seq_len') # 将target和seq_len这两个field的target属性设置为True
data_bundle.set_target('target', flag=False) # 将target这个field的target属性设置为False
:param str field_names: field的名称 | |||
:param bool flag: 将field_name的target状态设置为flag | |||
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度和类型,而是直接使用第一
行的数据推断本列数据的类型和维度。
:param bool ignore_miss_field: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错 | |||
""" | |||
for field_name in field_names: | |||
for name, dataset in self.datasets.items(): | |||
if not ignore_miss_field and not dataset.has_field(field_name): | |||
raise KeyError(f"Field:{field_name} was not found in DataSet:{name}") | |||
if not dataset.has_field(field_name): | |||
continue | |||
else: | |||
dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) | |||
def __repr__(self): | |||
_str = 'In total {} datasets:\n'.format(len(self.datasets)) | |||
for name, dataset in self.datasets.items(): | |||
@@ -1,7 +1,9 @@ | |||
""" | |||
用于读入和处理和保存 config 文件 | |||
.. todo:: | |||
.. todo:: | |||
这个模块中的类可能被抛弃? | |||
""" | |||
__all__ = [ | |||
"ConfigLoader", | |||
@@ -1,4 +1,8 @@ | |||
""" | |||
.. warning:: | |||
本模块在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。 | |||
用于读数据集的模块, 可以读取文本分类、序列标注、Matching任务的数据集 | |||
这些模块的具体介绍如下,您可以通过阅读 :doc:`教程</tutorials/tutorial_2_load_dataset>` 来进行了解。 | |||
@@ -3,38 +3,47 @@ from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ..base_loader import DataSetLoader | |||
from ..file_reader import _read_conll | |||
from typing import Union, Dict | |||
from ..utils import check_loader_paths | |||
from ..base_loader import DataBundle | |||
class ConllLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader` | |||
读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 数据中以"-DOCSTART-"开头的行将被忽略,因为 | |||
该符号在conll 2003中被用为文档分割符。 | |||
列号从0开始, 每列对应内容为:: | |||
Column Type | |||
0 Document ID | |||
1 Part number | |||
2 Word number | |||
3 Word itself | |||
4 Part-of-Speech | |||
5 Parse bit | |||
6 Predicate lemma | |||
7 Predicate Frameset ID | |||
8 Word sense | |||
9 Speaker/Author | |||
10 Named Entities | |||
11:N Predicate Arguments | |||
N Coreference | |||
:param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | |||
:param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` | |||
该ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: | |||
Example:: | |||
# 文件中的内容 | |||
Nadim NNP B-NP B-PER | |||
Ladki NNP I-NP I-PER | |||
AL-AIN NNP B-NP B-LOC | |||
United NNP B-NP B-LOC | |||
Arab NNP I-NP I-LOC | |||
Emirates NNPS I-NP I-LOC | |||
1996-12-06 CD I-NP O | |||
... | |||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列 | |||
dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll') | |||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列 | |||
dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll') | |||
# 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field | |||
dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll') | |||
dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll')中DataSet的raw_words | |||
列与pos列的内容都是List[str] | |||
数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 | |||
:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | |||
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | |||
""" | |||
def __init__(self, headers, indexes=None, dropna=False): | |||
def __init__(self, headers, indexes=None, dropna=True): | |||
super(ConllLoader, self).__init__() | |||
if not isinstance(headers, (list, tuple)): | |||
raise TypeError( | |||
@@ -49,25 +58,74 @@ class ConllLoader(DataSetLoader): | |||
self.indexes = indexes | |||
def _load(self, path): | |||
""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由Loader初始化时指定的headers决定。 | |||
:param str path: 文件的路径 | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||
ds.append(Instance(**ins)) | |||
return ds | |||
def load(self, paths: Union[str, Dict[str, str]]) -> DataBundle: | |||
""" | |||
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 | |||
读取的field根据ConllLoader初始化时传入的headers决定。 | |||
:param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式 | |||
(1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 | |||
名包含'train'、 'dev'、 'test'则会报错 | |||
Example:: | |||
data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train, dev, test等有所变化 | |||
# 可以通过以下的方式取出DataSet | |||
tr_data = data_bundle.datasets['train'] | |||
te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 | |||
(2) 传入文件path | |||
Example:: | |||
data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | |||
tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet | |||
(3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test | |||
Example:: | |||
paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | |||
data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | |||
dev_data = data_bundle.datasets['dev'] | |||
:return: :class:`~fastNLP.DataSet` 类的对象或 :class:`~fastNLP.io.DataBundle` 的字典 | |||
""" | |||
paths = check_loader_paths(paths) | |||
datasets = {name: self._load(path) for name, path in paths.items()} | |||
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
class Conll2003Loader(ConllLoader): | |||
""" | |||
别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.data_loader.Conll2003Loader` | |||
读取Conll2003数据 | |||
该Loader用以读取Conll2003数据,conll2003的数据可以在https://github.com/davidsbatista/NER-datasets/tree/master/CONLL2003 | |||
找到。数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 | |||
返回的DataSet将具有以下['raw_words', 'pos', 'chunks', 'ner']四个field, 每个field中的内容都是List[str]。 | |||
.. csv-table:: Conll2003Loader处理之后的DataSet
:header: "raw_words", "words", "target", "seq_len"
"[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2
"[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 5
"[...]", "[...]", "[...]", "..."
关于数据集的更多信息,参考: | |||
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'tokens', 'pos', 'chunks', 'ner', | |||
'raw_words', 'pos', 'chunks', 'ner', | |||
] | |||
super(Conll2003Loader, self).__init__(headers=headers) |
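# Hedged example for the data_loader version above (deprecated in favor of fastNLP.io.loader);
# the directory is hypothetical. With the renamed first header, each instance now carries
# 'raw_words', 'pos', 'chunks' and 'ner' fields, all List[str].
from fastNLP.io.data_loader import Conll2003Loader

bundle = Conll2003Loader().load('/path/to/conll2003/')   # dir with files named *train*/*dev*/*test*
print(bundle.datasets['train'])                          # DataSet with raw_words/pos/chunks/ner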
@@ -121,7 +121,7 @@ class MatchingLoader(DataSetLoader): | |||
PRETRAIN_URL = _get_base_url('bert') | |||
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] | |||
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url) | |||
model_dir = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isdir(bert_tokenizer): | |||
model_dir = bert_tokenizer | |||
@@ -5,7 +5,7 @@ from ..base_loader import DataBundle | |||
from ..dataset_loader import CSVLoader | |||
from ...core.vocabulary import Vocabulary, VocabularyOption | |||
from ...core.const import Const | |||
from ..utils import check_dataloader_paths | |||
from ..utils import check_loader_paths | |||
class MTL16Loader(CSVLoader): | |||
@@ -38,7 +38,7 @@ class MTL16Loader(CSVLoader): | |||
src_vocab_opt: VocabularyOption = None, | |||
tgt_vocab_opt: VocabularyOption = None,): | |||
paths = check_dataloader_paths(paths) | |||
paths = check_loader_paths(paths) | |||
datasets = {} | |||
info = DataBundle() | |||
for name, path in paths.items(): | |||
@@ -8,7 +8,7 @@ from ...core.vocabulary import VocabularyOption, Vocabulary | |||
from ...core.dataset import DataSet | |||
from ...core.const import Const | |||
from ...core.instance import Instance | |||
from ..utils import check_dataloader_paths, get_tokenizer | |||
from ..utils import check_loader_paths, get_tokenizer | |||
class SSTLoader(DataSetLoader): | |||
@@ -67,7 +67,7 @@ class SSTLoader(DataSetLoader): | |||
paths, train_subtree=True, | |||
src_vocab_op: VocabularyOption = None, | |||
tgt_vocab_op: VocabularyOption = None,): | |||
paths = check_dataloader_paths(paths) | |||
paths = check_loader_paths(paths) | |||
input_name, target_name = 'words', 'target' | |||
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
@@ -129,11 +129,12 @@ class SST2Loader(CSVLoader): | |||
tgt_vocab_opt: VocabularyOption = None, | |||
char_level_op=False): | |||
paths = check_dataloader_paths(paths) | |||
paths = check_loader_paths(paths) | |||
datasets = {} | |||
info = DataBundle() | |||
for name, path in paths.items(): | |||
dataset = self.load(path) | |||
dataset.apply_field(lambda words:words.copy(), field_name='words', new_field_name='raw_words') | |||
datasets[name] = dataset | |||
def wordtochar(words): | |||
@@ -154,7 +155,9 @@ class SST2Loader(CSVLoader): | |||
for dataset in datasets.values(): | |||
dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) | |||
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT) | |||
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT, no_create_entry_dataset=[ | |||
dataset for name, dataset in datasets.items() if name!='train' | |||
]) | |||
src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
@@ -8,7 +8,7 @@ from ...core.instance import Instance | |||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||
from ..base_loader import DataBundle, DataSetLoader | |||
from typing import Union, Dict | |||
from ..utils import check_dataloader_paths, get_tokenizer | |||
from ..utils import check_loader_paths, get_tokenizer | |||
class YelpLoader(DataSetLoader): | |||
@@ -62,7 +62,7 @@ class YelpLoader(DataSetLoader): | |||
src_vocab_op: VocabularyOption = None, | |||
tgt_vocab_op: VocabularyOption = None, | |||
char_level_op=False): | |||
paths = check_dataloader_paths(paths) | |||
paths = check_loader_paths(paths) | |||
info = DataBundle(datasets=self.load(paths)) | |||
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
@@ -1,4 +1,8 @@ | |||
""" | |||
.. warning:: | |||
本模块将在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。 | |||
dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的数据, 并返回 `DataSet` , | |||
得到的 :class:`~fastNLP.DataSet` 对象可以直接传入 :class:`~fastNLP.Trainer` 和 :class:`~fastNLP.Tester`, 用于模型的训练和测试。 | |||
以SNLI数据集为例:: | |||
@@ -11,6 +15,7 @@ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的 | |||
# ... do stuff | |||
为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。 | |||
""" | |||
__all__ = [ | |||
'CSVLoader', | |||
@@ -114,25 +119,3 @@ def _cut_long_sentence(sent, max_sample_length=200): | |||
else: | |||
cutted_sentence.append(sent) | |||
return cutted_sentence | |||
def _add_seg_tag(data): | |||
""" | |||
:param data: list of ([word], [pos], [heads], [head_tags]) | |||
:return: list of ([word], [pos]) | |||
""" | |||
_processed = [] | |||
for word_list, pos_list, _, _ in data: | |||
new_sample = [] | |||
for word, pos in zip(word_list, pos_list): | |||
if len(word) == 1: | |||
new_sample.append((word, 'S-' + pos)) | |||
else: | |||
new_sample.append((word[0], 'B-' + pos)) | |||
for c in word[1:-1]: | |||
new_sample.append((c, 'M-' + pos)) | |||
new_sample.append((word[-1], 'E-' + pos)) | |||
_processed.append(list(map(list, zip(*new_sample)))) | |||
return _processed |
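# Standalone illustration of the B/M/E/S tagging scheme that _add_seg_tag (removed above)
# produced from segmented words: single-character words get 'S-', longer words get
# 'B-'/'M-'/'E-' on their first/middle/last characters.
def to_bmes(word, pos):
    if len(word) == 1:
        return [(word, 'S-' + pos)]
    return ([(word[0], 'B-' + pos)]
            + [(c, 'M-' + pos) for c in word[1:-1]]
            + [(word[-1], 'E-' + pos)])

print(to_bmes('上海', 'NR'))   # [('上', 'B-NR'), ('海', 'E-NR')]
print(to_bmes('电', 'NN'))     # [('电', 'S-NN')]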
@@ -2,7 +2,7 @@ | |||
此模块用于给其它模块提供读取文件的函数,没有为用户提供 API | |||
""" | |||
import json | |||
import warnings | |||
def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | |||
""" | |||
@@ -91,7 +91,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||
with open(path, 'r', encoding=encoding) as f: | |||
sample = [] | |||
start = next(f).strip() | |||
if '-DOCSTART-' not in start and start!='': | |||
if start!='': | |||
sample.append(start.split()) | |||
for line_idx, line in enumerate(f, 1): | |||
line = line.strip() | |||
@@ -103,13 +103,13 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||
yield line_idx, res | |||
except Exception as e: | |||
if dropna: | |||
warnings.warn('Invalid instance ends at line: {} has been dropped.'.format(line_idx)) | |||
continue | |||
raise ValueError('invalid instance ends at line: {}'.format(line_idx)) | |||
raise ValueError('Invalid instance ends at line: {}'.format(line_idx)) | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
if not line.startswith('-DOCSTART-'): | |||
sample.append(line.split()) | |||
sample.append(line.split()) | |||
if len(sample) > 0: | |||
try: | |||
res = parse_conll(sample) | |||
@@ -1,4 +1,3 @@ | |||
import os | |||
from pathlib import Path | |||
from urllib.parse import urlparse | |||
@@ -7,65 +6,124 @@ import requests | |||
import tempfile | |||
from tqdm import tqdm | |||
import shutil | |||
import hashlib | |||
from requests import HTTPError | |||
PRETRAINED_BERT_MODEL_DIR = { | |||
'en': 'bert-base-cased-f89bfe08.zip', | |||
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||
'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||
'en-large-uncased': 'bert-large-uncased-20939f45.zip', | |||
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | |||
'en-large-cased-wwm': 'bert-large-cased-wwm-a457f118.zip', | |||
'en-large-uncased-wwm': 'bert-large-uncased-wwm-92a50aeb.zip', | |||
'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc-c7099855.zip', | |||
'cn': 'bert-base-chinese-29d0a84a.zip', | |||
'cn-base': 'bert-base-chinese-29d0a84a.zip', | |||
'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', | |||
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
'en': 'bert-base-cased.zip', | |||
'en-large-cased-wwm': 'bert-large-cased-wwm.zip', | |||
'en-large-uncased-wwm': 'bert-large-uncased-wwm.zip', | |||
'en-large-uncased': 'bert-large-uncased.zip', | |||
'en-large-cased': 'bert-large-cased.zip', | |||
'en-base-uncased': 'bert-base-uncased.zip', | |||
'en-base-cased': 'bert-base-cased.zip', | |||
'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip', | |||
'multi-base-cased': 'bert-base-multilingual-cased.zip', | |||
'multi-base-uncased': 'bert-base-multilingual-uncased.zip', | |||
'cn': 'bert-chinese-wwm.zip', | |||
'cn-base': 'bert-base-chinese.zip', | |||
'cn-wwm': 'bert-chinese-wwm.zip', | |||
} | |||
PRETRAINED_ELMO_MODEL_DIR = { | |||
'en': 'elmo_en-d39843fe.tar.gz', | |||
'cn': 'elmo_cn-5e9b34e2.tar.gz' | |||
'en': 'elmo_en_Medium.zip', | |||
'en-small': "elmo_en_Small.zip", | |||
'en-original-5.5b': 'elmo_en_Original_5.5B.zip', | |||
'en-original': 'elmo_en_Original.zip', | |||
'en-medium': 'elmo_en_Medium.zip' | |||
} | |||
PRETRAIN_STATIC_FILES = { | |||
'en': 'glove.840B.300d-cc1ad5e1.tar.gz', | |||
'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', | |||
'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz", | |||
'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz", | |||
'en-fasttext': "cc.en.300.vec-d53187b2.gz", | |||
'cn': "tencent_cn-dab24577.tar.gz", | |||
'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz", | |||
'en': 'glove.840B.300d.zip', | |||
'en-glove-6b-50d': 'glove.6B.50d.zip', | |||
'en-glove-6b-100d': 'glove.6B.100d.zip', | |||
'en-glove-6b-200d': 'glove.6B.200d.zip', | |||
'en-glove-6b-300d': 'glove.6B.300d.zip', | |||
'en-glove-42b-300d': 'glove.42B.300d.zip', | |||
'en-glove-840b-300d': 'glove.840B.300d.zip', | |||
'en-glove-twitter-27b-25d': 'glove.twitter.27B.25d.zip', | |||
'en-glove-twitter-27b-50d': 'glove.twitter.27B.50d.zip', | |||
'en-glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip', | |||
'en-glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip', | |||
'en-word2vec-300': "GoogleNews-vectors-negative300.txt.gz", | |||
'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip", | |||
'en-fasttext-crawl': "crawl-300d-2M.vec.zip", | |||
'cn': "tencent_cn.txt.zip", | |||
'cn-tencent': "tencent_cn.txt.zip", | |||
'cn-fasttext': "cc.zh.300.vec.gz", | |||
'cn-sgns-literature-word': 'sgns.literature.word.txt.zip', | |||
} | |||
DATASET_DIR = { | |||
'aclImdb': "imdb.zip", | |||
"yelp-review-full": "yelp_review_full.tar.gz", | |||
"yelp-review-polarity": "yelp_review_polarity.tar.gz", | |||
"mnli": "MNLI.zip", | |||
"snli": "SNLI.zip", | |||
"qnli": "QNLI.zip", | |||
"sst-2": "SST-2.zip", | |||
"sst": "SST.zip", | |||
"rte": "RTE.zip" | |||
} | |||
PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, | |||
"bert": PRETRAINED_BERT_MODEL_DIR, | |||
"static": PRETRAIN_STATIC_FILES} | |||
# 用于扩展fastNLP的下载 | |||
FASTNLP_EXTEND_DATASET_URL = 'fastnlp_dataset_url.txt' | |||
FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt', | |||
'bert':'fastnlp_bert_url.txt', | |||
'static': 'fastnlp_static_url.txt' | |||
} | |||
def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: | |||
def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path: | |||
""" | |||
给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 | |||
将文件放入到cache_dir中 | |||
给定一个url,尝试通过url中的解析出来的文件名字filename到{cache_dir}/{name}/{filename}下寻找这个文件, | |||
1. 如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir | |||
2. 如果name=None, 则没有中间的{name}这一层结构;否者中间结构就为{name} | |||
如果有该文件,就直接返回路径 | |||
如果没有该文件,则尝试用传入的url下载 | |||
或者文件名(可以是具体的文件名,也可以是文件夹),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 | |||
将文件放入到cache_dir中. | |||
:param str url_or_filename: 文件的下载url或者文件名称。 | |||
:param str cache_dir: 文件的缓存文件夹。如果为None,将使用"~/.fastNLP"这个默认路径 | |||
:param str name: 中间一层的名称。如embedding, dataset | |||
:return: | |||
""" | |||
if cache_dir is None: | |||
dataset_cache = Path(get_defalt_path()) | |||
data_cache = Path(get_cache_path()) | |||
else: | |||
dataset_cache = cache_dir | |||
data_cache = cache_dir | |||
if name: | |||
data_cache = os.path.join(data_cache, name) | |||
parsed = urlparse(url_or_filename) | |||
if parsed.scheme in ("http", "https"): | |||
# URL, so get it from the cache (downloading if necessary) | |||
return get_from_cache(url_or_filename, dataset_cache) | |||
elif parsed.scheme == "" and Path(os.path.join(dataset_cache, url_or_filename)).exists(): | |||
return get_from_cache(url_or_filename, Path(data_cache)) | |||
elif parsed.scheme == "" and Path(os.path.join(data_cache, url_or_filename)).exists(): | |||
# File, and it exists. | |||
return Path(url_or_filename) | |||
return Path(os.path.join(data_cache, url_or_filename)) | |||
elif parsed.scheme == "": | |||
# File, but it doesn't exist. | |||
raise FileNotFoundError("file {} not found".format(url_or_filename)) | |||
raise FileNotFoundError("file {} not found in {}.".format(url_or_filename, data_cache)) | |||
else: | |||
# Something unknown | |||
raise ValueError( | |||
@@ -75,48 +133,143 @@ def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: | |||
def get_filepath(filepath): | |||
""" | |||
如果filepath中只有一个文件,则直接返回对应的全路径 | |||
:param filepath: | |||
如果filepath为文件夹, | |||
如果内含多个文件, 返回filepath | |||
如果只有一个文件, 返回filepath + filename | |||
如果filepath为文件 | |||
返回filepath | |||
:param str filepath: 路径 | |||
:return: | |||
""" | |||
if os.path.isdir(filepath): | |||
files = os.listdir(filepath) | |||
if len(files)==1: | |||
if len(files) == 1: | |||
return os.path.join(filepath, files[0]) | |||
else: | |||
return filepath | |||
return filepath | |||
elif os.path.isfile(filepath): | |||
return filepath | |||
else: | |||
raise FileNotFoundError(f"{filepath} is not a valid file or directory.") | |||
def get_defalt_path(): | |||
def get_cache_path(): | |||
""" | |||
获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 | |||
获取fastNLP默认cache的存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 | |||
:return: | |||
:return str: 存放路径 | |||
""" | |||
if 'FASTNLP_CACHE_DIR' in os.environ: | |||
fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR') | |||
if os.path.exists(fastnlp_cache_dir): | |||
if os.path.isdir(fastnlp_cache_dir): | |||
return fastnlp_cache_dir | |||
raise RuntimeError("Some errors happens on cache directory.") | |||
else: | |||
raise RuntimeError("There function is not available right now.") | |||
else: | |||
raise NotADirectoryError(f"{os.environ['FASTNLP_CACHE_DIR']} is not a directory.") | |||
fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP")) | |||
return fastnlp_cache_dir | |||
def _get_base_url(name): | |||
""" | |||
根据name返回下载的url地址。 | |||
:param str name: 支持dataset和embedding两种 | |||
:return: | |||
""" | |||
# 返回的URL结尾必须是/ | |||
if 'FASTNLP_BASE_URL' in os.environ: | |||
fastnlp_base_url = os.environ['FASTNLP_BASE_URL'] | |||
return fastnlp_base_url | |||
raise RuntimeError("There function is not available right now.") | |||
environ_name = "FASTNLP_{}_URL".format(name.upper()) | |||
if environ_name in os.environ: | |||
url = os.environ[environ_name] | |||
if url.endswith('/'): | |||
return url | |||
else: | |||
return url + '/' | |||
else: | |||
URLS = { | |||
'embedding': "http://dbcloud.irocn.cn:8989/api/public/dl/", | |||
"dataset": "http://dbcloud.irocn.cn:8989/api/public/dl/dataset/" | |||
} | |||
if name.lower() not in URLS: | |||
raise KeyError(f"{name} is not recognized.") | |||
return URLS[name.lower()] | |||
def _get_embedding_url(embed_type, name): | |||
""" | |||
给定embedding类似和名称,返回下载url | |||
:param str embed_type: 支持static, bert, elmo。即embedding的类型 | |||
:param str name: embedding的名称, 例如en, cn, based等 | |||
:return: str, 下载的url地址 | |||
""" | |||
# 从扩展中寻找下载的url | |||
_filename = FASTNLP_EXTEND_EMBEDDING_URL.get(embed_type, None) | |||
if _filename: | |||
url = _read_extend_url_file(_filename, name) | |||
if url: | |||
return url | |||
embed_map = PRETRAIN_MAP.get(embed_type, None) | |||
if embed_map: | |||
filename = embed_map.get(name, None) | |||
if filename: | |||
url = _get_base_url('embedding') + filename | |||
return url | |||
raise KeyError("There is no {}. Only supports {}.".format(name, list(embed_map.keys()))) | |||
else: | |||
raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static") | |||
def _read_extend_url_file(filename, name)->str: | |||
""" | |||
filename中的内容使用制表符隔开,第一列是名称,第二列是下载的url地址 | |||
:param str filename: 在默认的路径下寻找file这个文件 | |||
:param str name: 需要寻找的资源的名称 | |||
:return: str or None | |||
""" | |||
cache_dir = get_cache_path() | |||
filepath = os.path.join(cache_dir, filename) | |||
if os.path.exists(filepath): | |||
with open(filepath, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
if len(parts) == 2: | |||
if name == parts[0]: | |||
return parts[1] | |||
return None | |||
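# Hedged sketch of the extension mechanism described above, assuming the module path
# fastNLP.io.file_utils; the dataset name 'my-dataset' is made up. Each line of
# {cache_dir}/fastnlp_dataset_url.txt is "<name>\t<url>", and _get_dataset_url (below)
# consults this file before the built-in DATASET_DIR mapping.
import os
from fastNLP.io.file_utils import get_cache_path, _read_extend_url_file

extend_file = os.path.join(get_cache_path(), 'fastnlp_dataset_url.txt')
print(extend_file)                                                     # e.g. ~/.fastNLP/fastnlp_dataset_url.txt
print(_read_extend_url_file('fastnlp_dataset_url.txt', 'my-dataset'))  # a url string, or None if absent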
def _get_dataset_url(name): | |||
""" | |||
给定dataset的名称,返回下载url | |||
:param str name: 给定dataset的名称,比如imdb, sst-2等 | |||
:return: str | |||
""" | |||
# 从扩展中寻找下载的url | |||
url = _read_extend_url_file(FASTNLP_EXTEND_DATASET_URL, name) | |||
if url: | |||
return url | |||
filename = DATASET_DIR.get(name, None) | |||
if filename: | |||
url = _get_base_url('dataset') + filename | |||
return url | |||
else: | |||
raise KeyError(f"There is no {name}.") | |||
def split_filename_suffix(filepath): | |||
""" | |||
给定filepath返回对应的name和suffix | |||
:param filepath: | |||
给定filepath 返回对应的name和suffix. 如果后缀是多个点,仅支持.tar.gz类型 | |||
:param filepath: 文件路径 | |||
:return: filename, suffix | |||
""" | |||
filename = os.path.basename(filepath) | |||
@@ -127,21 +280,19 @@ def split_filename_suffix(filepath): | |||
def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||
""" | |||
尝试在cache_dir中寻找url定义的资源; 如果没有找到。则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。 | |||
如果从url中下载的资源解压后有多个文件,则返回directory的路径; 如果只有一个资源,则返回具体的路径。 | |||
尝试在cache_dir中寻找url定义的资源; 如果没有找到; 则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的 | |||
文件解压,将解压后的文件全部放在cache_dir文件夹中。 | |||
如果从url中下载的资源解压后有多个文件,则返回目录的路径; 如果只有一个资源文件,则返回具体的路径。 | |||
:param url: 资源的 url | |||
:param cache_dir: cache 目录 | |||
:return: 路径 | |||
""" | |||
cache_dir.mkdir(parents=True, exist_ok=True) | |||
filename = re.sub(r".+/", "", url) | |||
dir_name, suffix = split_filename_suffix(filename) | |||
sep_index = dir_name[::-1].index('-') | |||
if sep_index<0: | |||
check_sum = None | |||
else: | |||
check_sum = dir_name[-sep_index+1:] | |||
sep_index = len(dir_name) if sep_index==-1 else -sep_index-1 | |||
dir_name = dir_name[:sep_index] | |||
# 寻找与它名字匹配的内容, 而不关心后缀 | |||
match_dir_name = match_file(dir_name, cache_dir) | |||
@@ -154,11 +305,11 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||
return get_filepath(cache_path) | |||
# make HEAD request to check ETag TODO ETag可以用来判断资源是否已经更新了,之后需要加上 | |||
response = requests.head(url, headers={"User-Agent": "fastNLP"}) | |||
if response.status_code != 200: | |||
raise IOError( | |||
f"HEAD request failed for url {url} with status code {response.status_code}." | |||
) | |||
# response = requests.head(url, headers={"User-Agent": "fastNLP"}) | |||
# if response.status_code != 200: | |||
# raise IOError( | |||
# f"HEAD request failed for url {url} with status code {response.status_code}." | |||
# ) | |||
# add ETag to filename if it exists | |||
# etag = response.headers.get("ETag") | |||
@@ -166,74 +317,73 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||
if not cache_path.exists(): | |||
# Download to temporary file, then copy to cache dir once finished. | |||
# Otherwise you get corrupt cache entries if the download gets interrupted. | |||
fd, temp_filename = tempfile.mkstemp() | |||
print("%s not found in cache, downloading to %s"%(url, temp_filename)) | |||
# GET file object | |||
req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"}) | |||
content_length = req.headers.get("Content-Length") | |||
total = int(content_length) if content_length is not None else None | |||
progress = tqdm(unit="B", total=total) | |||
sha256 = hashlib.sha256() | |||
with open(temp_filename, "wb") as temp_file: | |||
for chunk in req.iter_content(chunk_size=1024): | |||
if chunk: # filter out keep-alive new chunks | |||
progress.update(len(chunk)) | |||
temp_file.write(chunk) | |||
sha256.update(chunk) | |||
# check sum | |||
digit = sha256.hexdigest()[:8] | |||
if not check_sum: | |||
assert digit == check_sum, "File corrupted when download." | |||
progress.close() | |||
print(f"Finish download from {url}.") | |||
# 开始解压 | |||
delete_temp_dir = None | |||
if suffix in ('.zip', '.tar.gz'): | |||
uncompress_temp_dir = tempfile.mkdtemp() | |||
delete_temp_dir = uncompress_temp_dir | |||
print(f"Start to uncompress file to {uncompress_temp_dir}.") | |||
if suffix == '.zip': | |||
unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) | |||
if req.status_code == 200: | |||
content_length = req.headers.get("Content-Length") | |||
total = int(content_length) if content_length is not None else None | |||
progress = tqdm(unit="B", total=total, unit_scale=1) | |||
fd, temp_filename = tempfile.mkstemp() | |||
print("%s not found in cache, downloading to %s" % (url, temp_filename)) | |||
with open(temp_filename, "wb") as temp_file: | |||
for chunk in req.iter_content(chunk_size=1024 * 16): | |||
if chunk: # filter out keep-alive new chunks | |||
progress.update(len(chunk)) | |||
temp_file.write(chunk) | |||
progress.close() | |||
print(f"Finish download from {url}.") | |||
# 开始解压 | |||
delete_temp_dir = None | |||
if suffix in ('.zip', '.tar.gz'): | |||
uncompress_temp_dir = tempfile.mkdtemp() | |||
delete_temp_dir = uncompress_temp_dir | |||
print(f"Start to uncompress file to {uncompress_temp_dir}") | |||
if suffix == '.zip': | |||
unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) | |||
else: | |||
untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) | |||
filenames = os.listdir(uncompress_temp_dir) | |||
if len(filenames) == 1: | |||
if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): | |||
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) | |||
cache_path.mkdir(parents=True, exist_ok=True) | |||
print("Finish un-compressing file.") | |||
else: | |||
untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) | |||
filenames = os.listdir(uncompress_temp_dir) | |||
if len(filenames)==1: | |||
if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): | |||
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) | |||
cache_path.mkdir(parents=True, exist_ok=True) | |||
print("Finish un-compressing file.") | |||
uncompress_temp_dir = temp_filename | |||
cache_path = str(cache_path) + suffix | |||
success = False | |||
try: | |||
# 复制到指定的位置 | |||
print(f"Copy file to {cache_path}") | |||
if os.path.isdir(uncompress_temp_dir): | |||
for filename in os.listdir(uncompress_temp_dir): | |||
if os.path.isdir(os.path.join(uncompress_temp_dir, filename)): | |||
shutil.copytree(os.path.join(uncompress_temp_dir, filename), cache_path / filename) | |||
else: | |||
shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path / filename) | |||
else: | |||
shutil.copyfile(uncompress_temp_dir, cache_path) | |||
success = True | |||
except Exception as e: | |||
print(e) | |||
raise e | |||
finally: | |||
if not success: | |||
if cache_path.exists(): | |||
if cache_path.is_file(): | |||
os.remove(cache_path) | |||
else: | |||
shutil.rmtree(cache_path) | |||
if delete_temp_dir: | |||
shutil.rmtree(delete_temp_dir) | |||
os.close(fd) | |||
os.remove(temp_filename) | |||
return get_filepath(cache_path) | |||
else: | |||
uncompress_temp_dir = temp_filename | |||
cache_path = str(cache_path) + suffix | |||
success = False | |||
try: | |||
# 复制到指定的位置 | |||
print(f"Copy file to {cache_path}.") | |||
if os.path.isdir(uncompress_temp_dir): | |||
for filename in os.listdir(uncompress_temp_dir): | |||
shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename) | |||
else: | |||
shutil.copyfile(uncompress_temp_dir, cache_path) | |||
success = True | |||
except Exception as e: | |||
print(e) | |||
raise e | |||
finally: | |||
if not success: | |||
if cache_path.exists(): | |||
if cache_path.is_file(): | |||
os.remove(cache_path) | |||
else: | |||
shutil.rmtree(cache_path) | |||
if delete_temp_dir: | |||
shutil.rmtree(delete_temp_dir) | |||
os.close(fd) | |||
os.remove(temp_filename) | |||
return get_filepath(cache_path) | |||
raise HTTPError(f"Status code:{req.status_code}. Fail to download from {url}.") | |||
def unzip_file(file: Path, to: Path): | |||
@@ -245,55 +395,30 @@ def unzip_file(file: Path, to: Path): | |||
zipObj.extractall(to) | |||
def untar_gz_file(file:Path, to:Path): | |||
def untar_gz_file(file: Path, to: Path): | |||
import tarfile | |||
with tarfile.open(file, 'r:gz') as tar: | |||
tar.extractall(to) | |||
def match_file(dir_name: str, cache_dir: str) -> str: | |||
def match_file(dir_name: str, cache_dir: Path) -> str: | |||
""" | |||
匹配的原则是,在cache_dir下的文件: (1) 与dir_name完全一致; (2) 除了后缀以外和dir_name完全一致。 | |||
匹配的原则是: 在cache_dir下的文件与dir_name完全一致, 或除了后缀以外和dir_name完全一致。 | |||
如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串 | |||
:param dir_name: 需要匹配的名称 | |||
:param cache_dir: 在该目录下找匹配dir_name是否存在 | |||
:return: str | |||
:return str: 做为匹配结果的字符串 | |||
""" | |||
files = os.listdir(cache_dir) | |||
matched_filenames = [] | |||
for file_name in files: | |||
if re.match(dir_name+'$', file_name) or re.match(dir_name+'\\..*', file_name): | |||
if re.match(dir_name + '$', file_name) or re.match(dir_name + '\\..*', file_name): | |||
matched_filenames.append(file_name) | |||
if len(matched_filenames)==0: | |||
if len(matched_filenames) == 0: | |||
return '' | |||
elif len(matched_filenames)==1: | |||
elif len(matched_filenames) == 1: | |||
return matched_filenames[-1] | |||
else: | |||
raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.") | |||
if __name__ == '__main__': | |||
cache_dir = Path('caches') | |||
cache_dir = None | |||
# 需要对cache_dir进行测试 | |||
base_url = 'http://0.0.0.0:8888/file/download' | |||
# if True: | |||
# for filename in os.listdir(cache_dir): | |||
# if os.path.isdir(os.path.join(cache_dir, filename)): | |||
# shutil.rmtree(os.path.join(cache_dir, filename)) | |||
# else: | |||
# os.remove(os.path.join(cache_dir, filename)) | |||
# 1. 测试.txt文件 | |||
print(cached_path(base_url + '/{}'.format('txt_test-bcb4fe65.txt'), cache_dir)) | |||
# 2. 测试.zip文件(只有一个文件) | |||
print(cached_path(base_url + '/{}'.format('zip_test-40966d39.zip'), cache_dir)) | |||
# 3. 测试.zip文件(有多个文件) | |||
print(cached_path(base_url + '/{}'.format('zip_pack_test-70c0b20d.zip'), cache_dir)) | |||
# 4. 测试.tar.gz文件 | |||
print(cached_path(base_url + '/{}'.format('tar_gz_test-3e2679cf.tar.gz'), cache_dir)) | |||
# 5. 测试.tar.gz多个文件 | |||
print(cached_path(base_url + '/{}'.format('tar_gz_pack_test-08dfdccd.tar.gz'), cache_dir)) | |||
# 6. 测试.pkl文件 |
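# Hedged usage sketch for the new cache layout (downloads on first use; the embedding
# key 'en-glove-6b-50d' comes from PRETRAIN_STATIC_FILES above):
from fastNLP.io.file_utils import cached_path, _get_embedding_url

url = _get_embedding_url('static', 'en-glove-6b-50d')
local_path = cached_path(url, name='embedding')   # cached under {cache_dir}/embedding/
print(local_path)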
@@ -0,0 +1,78 @@ | |||
""" | |||
Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的 | |||
三个方法: ``__init__`` , ``_load`` , ``load`` . 其中 ``__init__(...)`` 用于声明读取参数,以及说明该Loader支持的数据格式,
读取后 :class:`~fastNLP.DataSet` 中的 `field` ; ``_load(path)`` 方法传入文件路径读取单个文件,并返回 :class:`~fastNLP.DataSet` ; | |||
``load(paths)`` 用于读取文件夹下的文件,并返回 :class:`~fastNLP.io.DataBundle` 类型的对象 , load()方法支持以下几种类型的参数: | |||
0.传入None | |||
将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。 | |||
1.传入一个文件的 path | |||
返回的 `data_bundle` 包含一个名为 `train` 的 dataset ,可以通过 ``data_bundle.datasets['train']`` 获取 | |||
2.传入一个文件夹目录 | |||
将读取的是这个文件夹下文件名中包含 `train` , `test` , `dev` 的文件,其它文件会被忽略。假设某个目录下的文件为:: | |||
| | |||
+-train.txt | |||
+-dev.txt | |||
+-test.txt | |||
+-other.txt | |||
在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.datasets['train']`` , ``data_bundle.datasets['dev']`` , | |||
``data_bundle.datasets['test']`` 获取对应的 `dataset` ,其中 `other.txt` 的内容会被忽略。假设某个目录下的文件为:: | |||
| | |||
+-train.txt | |||
+-dev.txt | |||
在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.datasets['train']`` , | |||
``data_bundle.datasets['dev']`` 获取对应的 dataset。 | |||
3.传入一个字典 | |||
字典的 key 为 `dataset` 的名称,value 是该 `dataset` 的文件路径::
paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'} | |||
在 Loader().load(paths) 返回的 `data_bundle` 中可以用 ``data_bundle.datasets['train']`` , ``data_bundle.datasets['dev']`` , | |||
``data_bundle.datasets['test']`` 来获取对应的 `dataset` | |||
fastNLP 目前提供了如下的 Loader | |||
""" | |||
__all__ = [ | |||
'YelpLoader', | |||
'YelpFullLoader', | |||
'YelpPolarityLoader', | |||
'IMDBLoader', | |||
'SSTLoader', | |||
'SST2Loader', | |||
'ConllLoader', | |||
'Conll2003Loader', | |||
'Conll2003NERLoader', | |||
'OntoNotesNERLoader', | |||
'CTBLoader', | |||
'Loader', | |||
'CSVLoader', | |||
'JsonLoader', | |||
'CWSLoader', | |||
'MNLILoader', | |||
"QuoraLoader", | |||
"SNLILoader", | |||
"QNLILoader", | |||
"RTELoader" | |||
] | |||
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader | |||
from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader | |||
from .csv import CSVLoader | |||
from .cws import CWSLoader | |||
from .json import JsonLoader | |||
from .loader import Loader | |||
from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader |
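# Hedged sketch of the path forms accepted by Loader.load as documented above;
# every path below is hypothetical.
from fastNLP.io.loader import CWSLoader

loader = CWSLoader()
db_file = loader.load('/path/to/train.txt')            # -> datasets: {'train'}
db_dir = loader.load('/path/to/dir')                   # picks up files whose names contain train/dev/test
db_dict = loader.load({'train': '/path/to/tr.txt',
                       'dev': '/path/to/dev.txt'})     # explicit name -> path mapping
print(db_dir.datasets.keys())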
@@ -0,0 +1,369 @@ | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from .loader import Loader | |||
import warnings | |||
import os | |||
import random | |||
import shutil | |||
import numpy as np | |||
class YelpLoader(Loader): | |||
""" | |||
别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader` | |||
原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 | |||
Example:: | |||
"1","I got 'new' tires from the..." | |||
"1","Don't waste your time..." | |||
读取YelpFull, YelpPolarity的数据。可以通过xxx下载并预处理数据。 | |||
读取的DataSet将具备以下的数据结构 | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"I got 'new' tires from them and... ", "1" | |||
"Don't waste your time. We had two...", "1" | |||
"...", "..." | |||
""" | |||
def __init__(self): | |||
super(YelpLoader, self).__init__() | |||
def _load(self, path: str=None): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
sep_index = line.index(',') | |||
target = line[:sep_index] | |||
raw_words = line[sep_index + 1:] | |||
if target.startswith("\""): | |||
target = target[1:] | |||
if target.endswith("\""): | |||
target = target[:-1] | |||
if raw_words.endswith("\""): | |||
raw_words = raw_words[:-1] | |||
if raw_words.startswith('"'): | |||
raw_words = raw_words[1:] | |||
raw_words = raw_words.replace('""', '"') # 替换双引号 | |||
if raw_words: | |||
ds.append(Instance(raw_words=raw_words, target=target)) | |||
return ds | |||
class YelpFullLoader(YelpLoader): | |||
def download(self, dev_ratio: float = 0.1, seed: int = 0): | |||
""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | |||
in Neural Information Processing Systems 28 (NIPS 2015) | |||
根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后在output_dir中有train.csv, test.csv, | |||
dev.csv三个文件。 | |||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||
:param int seed: 划分dev时的随机数种子 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = 'yelp-review-full' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否需要重新下载 | |||
re_download = True | |||
if dev_ratio>0: | |||
dev_line_count = 0 | |||
tr_line_count = 0 | |||
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ | |||
open(os.path.join(data_dir, 'dev.csv'), 'r', encoding='utf-8') as f2: | |||
for line in f1: | |||
tr_line_count += 1 | |||
for line in f2: | |||
dev_line_count += 1 | |||
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||
re_download = True | |||
else: | |||
re_download = False | |||
if re_download: | |||
shutil.rmtree(data_dir) | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | |||
if dev_ratio > 0: | |||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
random.seed(int(seed)) | |||
try: | |||
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \ | |||
open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \ | |||
open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2: | |||
for line in f: | |||
if random.random() < dev_ratio: | |||
f2.write(line) | |||
else: | |||
f1.write(line) | |||
os.remove(os.path.join(data_dir, 'train.csv')) | |||
os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv')) | |||
finally: | |||
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): | |||
os.remove(os.path.join(data_dir, 'middle_file.csv')) | |||
return data_dir | |||
class YelpPolarityLoader(YelpLoader): | |||
def download(self, dev_ratio: float = 0.1, seed: int = 0): | |||
""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | |||
in Neural Information Processing Systems 28 (NIPS 2015) | |||
根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev | |||
:param float dev_ratio: 如果路径中不存在dev.csv, 从train划分多少作为dev的数据. 如果为0,则不划分dev | |||
:param int seed: 划分dev时的随机数种子 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = 'yelp-review-polarity' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否符合比例要求 | |||
re_download = True | |||
if dev_ratio>0: | |||
dev_line_count = 0 | |||
tr_line_count = 0 | |||
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ | |||
open(os.path.join(data_dir, 'dev.csv'), 'r', encoding='utf-8') as f2: | |||
for line in f1: | |||
tr_line_count += 1 | |||
for line in f2: | |||
dev_line_count += 1 | |||
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||
re_download = True | |||
else: | |||
re_download = False | |||
if re_download: | |||
shutil.rmtree(data_dir) | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | |||
if dev_ratio > 0: | |||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
random.seed(int(seed)) | |||
try: | |||
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \ | |||
open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \ | |||
open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2: | |||
for line in f: | |||
if random.random() < dev_ratio: | |||
f2.write(line) | |||
else: | |||
f1.write(line) | |||
os.remove(os.path.join(data_dir, 'train.csv')) | |||
os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv')) | |||
finally: | |||
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): | |||
os.remove(os.path.join(data_dir, 'middle_file.csv')) | |||
return data_dir | |||
class IMDBLoader(Loader): | |||
""" | |||
别名::class:`fastNLP.io.IMDBLoader` :class:`fastNLP.io.loader.IMDBLoader` | |||
IMDBLoader读取后的数据将具有以下两列内容: raw_words: str, 需要分类的文本; target: str, 文本的标签 | |||
DataSet具备以下的结构: | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"Bromwell High is a cartoon ... ", "pos" | |||
"Story of a man who has ...", "neg" | |||
"...", "..." | |||
""" | |||
def __init__(self): | |||
super(IMDBLoader, self).__init__() | |||
def _load(self, path: str): | |||
dataset = DataSet() | |||
with open(path, 'r', encoding="utf-8") as f: | |||
for line in f: | |||
line = line.strip() | |||
if not line: | |||
continue | |||
parts = line.split('\t') | |||
target = parts[0] | |||
words = parts[1] | |||
if words: | |||
dataset.append(Instance(raw_words=words, target=target)) | |||
if len(dataset) == 0: | |||
raise RuntimeError(f"{path} has no valid data.") | |||
return dataset | |||
def download(self, dev_ratio: float = 0.1, seed: int = 0): | |||
""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
http://www.aclweb.org/anthology/P11-1015 | |||
根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev | |||
:param float dev_ratio: 如果路径中没有dev.txt。从train划分多少作为dev的数据. 如果为0,则不划分dev | |||
:param int seed: 划分dev时的随机数种子 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = 'aclImdb' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求 | |||
re_download = True | |||
if dev_ratio>0: | |||
dev_line_count = 0 | |||
tr_line_count = 0 | |||
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f1, \ | |||
open(os.path.join(data_dir, 'dev.txt'), 'r', encoding='utf-8') as f2: | |||
for line in f1: | |||
tr_line_count += 1 | |||
for line in f2: | |||
dev_line_count += 1 | |||
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||
re_download = True | |||
else: | |||
re_download = False | |||
if re_download: | |||
shutil.rmtree(data_dir) | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if not os.path.exists(os.path.join(data_dir, 'dev.txt')):
if dev_ratio > 0: | |||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
random.seed(int(seed)) | |||
try: | |||
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ | |||
open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ | |||
open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: | |||
for line in f: | |||
if random.random() < dev_ratio: | |||
f2.write(line) | |||
else: | |||
f1.write(line) | |||
os.remove(os.path.join(data_dir, 'train.txt')) | |||
os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) | |||
finally: | |||
if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): | |||
os.remove(os.path.join(data_dir, 'middle_file.txt')) | |||
return data_dir | |||
class SSTLoader(Loader): | |||
""" | |||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.loader.SSTLoader` | |||
读取之后的DataSet具有以下的结构 | |||
.. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field | |||
:header: "raw_words" | |||
"(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..." | |||
"(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..." | |||
"..." | |||
raw_words列是str。 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path: str): | |||
""" | |||
从path读取SST文件 | |||
:param str path: 文件路径 | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
ds.append(Instance(raw_words=line)) | |||
return ds | |||
def download(self): | |||
""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf | |||
:return: str, 数据集的目录地址 | |||
""" | |||
output_dir = self._get_dataset_path(dataset_name='sst') | |||
return output_dir | |||
class SST2Loader(Loader): | |||
""" | |||
数据SST2的Loader | |||
读取之后DataSet将如下所示 | |||
.. csv-table:: 下面是使用SST2Loader读取的DataSet所具备的field
:header: "raw_words", "target"
"it 's a charming and often affecting...", "1"
"unflinchingly bleak and...", "0"
"...", "..."
test的DataSet没有target列。 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path: str): | |||
""" | |||
从path读取SST2文件 | |||
:param str path: 数据路径 | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() # 跳过header | |||
if 'test' in os.path.split(path)[1]: | |||
warnings.warn("SST2's test file has no target.") | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
sep_index = line.index('\t') | |||
raw_words = line[sep_index + 1:] | |||
if raw_words: | |||
ds.append(Instance(raw_words=raw_words)) | |||
else: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
raw_words = line[:-2] | |||
target = line[-1] | |||
if raw_words: | |||
ds.append(Instance(raw_words=raw_words, target=target)) | |||
return ds | |||
def download(self): | |||
""" | |||
自动下载数据集,如果你使用了该数据集,请引用以下的文章 | |||
https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path(dataset_name='sst-2') | |||
return output_dir |
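# Hedged usage sketch for the loaders defined above, with SST2Loader as the example:
# download() fetches and caches the dataset (when automatic download is supported) and
# returns its directory, after which load() reads the splits found there.
if __name__ == '__main__':
    loader = SST2Loader()
    data_dir = loader.download()          # cached under the fastNLP cache dir (presumably ~/.fastNLP/dataset/)
    data_bundle = loader.load(data_dir)   # typically yields 'train', 'dev' and 'test' datasets
    print(data_bundle.datasets.keys())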
@@ -0,0 +1,264 @@ | |||
from typing import Dict, Union | |||
from .loader import Loader | |||
from ...core.dataset import DataSet | |||
from ..file_reader import _read_conll | |||
from ...core.instance import Instance | |||
from .. import DataBundle | |||
from ..utils import check_loader_paths | |||
from ...core.const import Const | |||
class ConllLoader(Loader): | |||
""" | |||
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.loader.ConllLoader`
ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: | |||
Example:: | |||
# 文件中的内容 | |||
Nadim NNP B-NP B-PER | |||
Ladki NNP I-NP I-PER | |||
AL-AIN NNP B-NP B-LOC | |||
United NNP B-NP B-LOC | |||
Arab NNP I-NP I-LOC | |||
Emirates NNPS I-NP I-LOC | |||
1996-12-06 CD I-NP O | |||
... | |||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列 | |||
dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll') | |||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列 | |||
dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll') | |||
# 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field | |||
dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll') | |||
ConllLoader返回的DataSet的field由传入的headers确定。 | |||
数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 | |||
:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | |||
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | |||
""" | |||
def __init__(self, headers, indexes=None, dropna=True): | |||
super(ConllLoader, self).__init__() | |||
if not isinstance(headers, (list, tuple)): | |||
raise TypeError( | |||
'invalid headers: {}, should be list of strings'.format(headers)) | |||
self.headers = headers | |||
self.dropna = dropna | |||
if indexes is None: | |||
self.indexes = list(range(len(self.headers))) | |||
else: | |||
if len(indexes) != len(headers): | |||
raise ValueError | |||
self.indexes = indexes | |||
def _load(self, path): | |||
""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||
:param str path: 文件的路径 | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||
ds.append(Instance(**ins)) | |||
return ds | |||
class Conll2003Loader(ConllLoader): | |||
""" | |||
用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。 | |||
Example:: | |||
Nadim NNP B-NP B-PER | |||
Ladki NNP I-NP I-PER | |||
AL-AIN NNP B-NP B-LOC | |||
United NNP B-NP B-LOC | |||
Arab NNP I-NP I-LOC | |||
Emirates NNPS I-NP I-LOC | |||
1996-12-06 CD I-NP O | |||
... | |||
返回的DataSet的内容为 | |||
.. csv-table:: 下面是Conll2003Loader加载后数据具备的结构。 | |||
:header: "raw_words", "pos", "chunk", "ner" | |||
"[Nadim, Ladki]", "[NNP, NNP]", "[B-NP, I-NP]", "[B-PER, I-PER]" | |||
"[AL-AIN, United, Arab, ...]", "[NNP, NNP, NNP, ...]", "[B-NP, B-NP, I-NP, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | |||
"[...]", "[...]", "[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'raw_words', 'pos', 'chunk', 'ner', | |||
] | |||
super(Conll2003Loader, self).__init__(headers=headers) | |||
def _load(self, path): | |||
""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||
:param str path: 文件的路径 | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||
doc_start = False | |||
for i, h in enumerate(self.headers): | |||
field = data[i] | |||
if str(field[0]).startswith('-DOCSTART-'): | |||
doc_start = True | |||
break | |||
if doc_start: | |||
continue | |||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||
ds.append(Instance(**ins)) | |||
return ds | |||
def download(self, output_dir=None): | |||
raise RuntimeError("conll2003 cannot be downloaded automatically.") | |||
class Conll2003NERLoader(ConllLoader): | |||
""" | |||
用于读取conll2003任务的NER数据。 | |||
Example:: | |||
Nadim NNP B-NP B-PER | |||
Ladki NNP I-NP I-PER | |||
AL-AIN NNP B-NP B-LOC | |||
United NNP B-NP B-LOC | |||
Arab NNP I-NP I-LOC | |||
Emirates NNPS I-NP I-LOC | |||
1996-12-06 CD I-NP O | |||
... | |||
返回的DataSet的内容为 | |||
.. csv-table:: 下面是Conll2003NERLoader加载后数据具备的结构, target是BIO2编码
:header: "raw_words", "target" | |||
"[Nadim, Ladki]", "[B-PER, I-PER]" | |||
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | |||
"[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'raw_words', 'target', | |||
] | |||
super().__init__(headers=headers, indexes=[0, 3]) | |||
def _load(self, path): | |||
""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||
:param str path: 文件的路径 | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||
doc_start = False | |||
for i, h in enumerate(self.headers): | |||
field = data[i] | |||
if str(field[0]).startswith('-DOCSTART-'): | |||
doc_start = True | |||
break | |||
if doc_start: | |||
continue | |||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||
ds.append(Instance(**ins)) | |||
return ds | |||
def download(self): | |||
raise RuntimeError("conll2003 cannot be downloaded automatically.") | |||
class OntoNotesNERLoader(ConllLoader): | |||
""" | |||
用以读取OntoNotes的NER数据,同时也是Conll2012的NER任务数据。将OntoNote数据处理为conll格式的过程可以参考 | |||
https://github.com/yhcc/OntoNotes-5.0-NER。OntoNotesNERLoader将取第4列和第11列的内容。
返回的DataSet的内容为 | |||
.. csv-table:: 下面是使用OntoNotesNERLoader读取的DataSet所具备的结构, target列是BIO编码
:header: "raw_words", "target" | |||
"[Nadim, Ladki]", "[B-PER, I-PER]" | |||
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | |||
"[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
super().__init__(headers=[Const.RAW_WORD, Const.TARGET], indexes=[3, 10]) | |||
def _load(self, path:str): | |||
dataset = super()._load(path) | |||
def convert_to_bio(tags): | |||
bio_tags = [] | |||
flag = None | |||
for tag in tags: | |||
label = tag.strip("()*") | |||
if '(' in tag: | |||
bio_label = 'B-' + label | |||
flag = label | |||
elif flag: | |||
bio_label = 'I-' + flag | |||
else: | |||
bio_label = 'O' | |||
if ')' in tag: | |||
flag = None | |||
bio_tags.append(bio_label) | |||
return bio_tags | |||
def convert_word(words): | |||
converted_words = [] | |||
for word in words: | |||
word = word.replace('/.', '.') # 有些结尾的.是/.形式的 | |||
if not word.startswith('-'): | |||
converted_words.append(word) | |||
continue | |||
# 以下是由于这些符号被转义了,再转回来 | |||
tfrs = {'-LRB-':'(', | |||
'-RRB-': ')', | |||
'-LSB-': '[', | |||
'-RSB-': ']', | |||
'-LCB-': '{', | |||
'-RCB-': '}' | |||
} | |||
if word in tfrs: | |||
converted_words.append(tfrs[word]) | |||
else: | |||
converted_words.append(word) | |||
return converted_words | |||
dataset.apply_field(convert_word, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD) | |||
dataset.apply_field(convert_to_bio, field_name=Const.TARGET, new_field_name=Const.TARGET) | |||
return dataset | |||
def download(self): | |||
raise RuntimeError("Ontonotes cannot be downloaded automatically, you can refer " | |||
"https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.") | |||
class CTBLoader(Loader): | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path:str): | |||
pass |
@@ -0,0 +1,32 @@ | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ..file_reader import _read_csv | |||
from .loader import Loader | |||
class CSVLoader(Loader): | |||
""" | |||
Alias: :class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.loader.CSVLoader`
读取CSV格式的数据集, 返回 ``DataSet`` 。 | |||
:param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 | |||
若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` | |||
:param str sep: CSV文件中列与列之间的分隔符. Default: "," | |||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
Default: ``False`` | |||
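A minimal usage sketch (the file path and header names below are placeholders, not fixed by fastNLP):

Example::

    loader = CSVLoader(headers=['raw_words', 'target'], sep=',')
    data_bundle = loader.load('/path/to/train.csv')  # the resulting DataBundle holds a 'train' DataSet with the two fields above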
""" | |||
def __init__(self, headers=None, sep=",", dropna=False): | |||
super().__init__() | |||
self.headers = headers | |||
self.sep = sep | |||
self.dropna = dropna | |||
def _load(self, path): | |||
ds = DataSet() | |||
for idx, data in _read_csv(path, headers=self.headers, | |||
sep=self.sep, dropna=self.dropna): | |||
ds.append(Instance(**data)) | |||
return ds | |||
@@ -0,0 +1,41 @@ | |||
from .loader import Loader | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
class CWSLoader(Loader): | |||
""" | |||
分词任务数据加载器, | |||
SigHan2005的数据可以用xxx下载并预处理 | |||
CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: | |||
Example:: | |||
上海 浦东 开发 与 法制 建设 同步 | |||
新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) | |||
... | |||
该Loader读取后的DataSet具有如下的结构 | |||
.. csv-table:: | |||
:header: "raw_words" | |||
"上海 浦东 开发 与 法制 建设 同步" | |||
"新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" | |||
"..." | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path:str): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
ds.append(Instance(raw_words=line)) | |||
return ds | |||
def download(self, output_dir=None): | |||
raise RuntimeError("You can refer {} for sighan2005's data downloading.") |
@@ -0,0 +1,40 @@ | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ..file_reader import _read_json | |||
from .loader import Loader | |||
class JsonLoader(Loader): | |||
""" | |||
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader` | |||
读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 | |||
:param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name | |||
``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` , | |||
`value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名 | |||
``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` | |||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
Default: ``False`` | |||
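A minimal usage sketch (the field names and the path are placeholders chosen for illustration):

Example::

    loader = JsonLoader(fields={'sentence1': 'raw_words1', 'sentence2': 'raw_words2', 'gold_label': 'target'})
    data_bundle = loader.load('/path/to/data.jsonl')  # each json line becomes one Instance with the renamed fields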
""" | |||
def __init__(self, fields=None, dropna=False): | |||
super(JsonLoader, self).__init__() | |||
self.dropna = dropna | |||
self.fields = None | |||
self.fields_list = None | |||
if fields: | |||
self.fields = {} | |||
for k, v in fields.items(): | |||
self.fields[k] = k if v is None else v | |||
self.fields_list = list(self.fields.keys()) | |||
def _load(self, path): | |||
ds = DataSet() | |||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | |||
if self.fields: | |||
ins = {self.fields[k]: v for k, v in d.items()} | |||
else: | |||
ins = d | |||
ds.append(Instance(**ins)) | |||
return ds |
@@ -0,0 +1,75 @@ | |||
from ...core.dataset import DataSet | |||
from .. import DataBundle | |||
from ..utils import check_loader_paths | |||
from typing import Union, Dict | |||
import os | |||
from ..file_utils import _get_dataset_url, get_cache_path, cached_path | |||
class Loader: | |||
def __init__(self): | |||
pass | |||
def _load(self, path:str) -> DataSet: | |||
raise NotImplementedError | |||
def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: | |||
""" | |||
Read data from the file(s) at the given path(s) and return a :class:`~fastNLP.io.DataBundle`.
Which fields are read is determined by the concrete Loader subclass.
:param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式 | |||
(0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 | |||
(1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 | |||
名包含'train'、 'dev'、 'test'则会报错 | |||
Example:: | |||
data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、 | |||
# dev、 test等有所变化,可以通过以下的方式取出DataSet | |||
tr_data = data_bundle.datasets['train'] | |||
te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 | |||
(2) 传入文件路径 | |||
Example:: | |||
data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | |||
tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet | |||
(3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test | |||
Example:: | |||
paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | |||
data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | |||
dev_data = data_bundle.datasets['dev'] | |||
:return: 返回的:class:`~fastNLP.io.DataBundle` | |||
""" | |||
if paths is None: | |||
paths = self.download() | |||
paths = check_loader_paths(paths) | |||
datasets = {name: self._load(path) for name, path in paths.items()} | |||
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
def download(self): | |||
raise NotImplementedError(f"{self.__class__} cannot download data automatically.") | |||
def _get_dataset_path(self, dataset_name): | |||
""" | |||
传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 | |||
:param str dataset_name: 数据集的名称 | |||
:return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 | |||
""" | |||
default_cache_path = get_cache_path() | |||
url = _get_dataset_url(dataset_name) | |||
output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') | |||
return output_dir | |||
@@ -0,0 +1,302 @@ | |||
import warnings | |||
from .loader import Loader | |||
from .json import JsonLoader | |||
from ...core.const import Const | |||
from .. import DataBundle | |||
import os | |||
from typing import Union, Dict | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
class MNLILoader(Loader): | |||
""" | |||
Load data for the MNLI task. The resulting DataSet contains the following fields: raw_words1 is sentence1, raw_words2 is sentence2, and target is gold_label; the test
splits have no target column.
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target" | |||
"The new rights are...", "Everyone really likes..", "neutral" | |||
"This site includes a...", "The Government Executive...", "contradiction" | |||
"...", "...","." | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path:str): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() # 跳过header | |||
if 'test' in os.path.basename(path):  # test_matched.tsv / test_mismatched.tsv carry no gold label
    warnings.warn("MNLI's test file has no target.")
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[8] | |||
raw_words2 = parts[9] | |||
if raw_words1 and raw_words2: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) | |||
else: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[8] | |||
raw_words2 = parts[9] | |||
target = parts[-1] | |||
if raw_words1 and raw_words2 and target: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||
return ds | |||
def load(self, paths:str=None): | |||
""" | |||
:param str paths: directory containing the data; the loader looks for the files dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv,
test_mismatched.tsv and train.tsv inside it.
:return: DataBundle | |||
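A usage sketch (the directory below is a placeholder; it only needs to contain the tsv files listed above)::

    data_bundle = MNLILoader().load('/path/to/MNLI')
    train_data = data_bundle.datasets['train']
    dev_data = data_bundle.datasets['dev_matched']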
""" | |||
if paths: | |||
paths = os.path.abspath(os.path.expanduser(paths)) | |||
else: | |||
paths = self.download() | |||
if not os.path.isdir(paths): | |||
raise NotADirectoryError(f"{paths} is not a valid directory.") | |||
files = {'dev_matched':"dev_matched.tsv", | |||
"dev_mismatched":"dev_mismatched.tsv", | |||
"test_matched":"test_matched.tsv", | |||
"test_mismatched":"test_mismatched.tsv", | |||
"train":'train.tsv'} | |||
datasets = {} | |||
for name, filename in files.items(): | |||
filepath = os.path.join(paths, filename) | |||
if not os.path.isfile(filepath):
    if 'test' not in name:
        raise FileNotFoundError(f"{filename} is not found in {paths}.")
    continue  # missing test splits are skipped instead of being loaded
datasets[name] = self._load(filepath)
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
def download(self): | |||
""" | |||
如果你使用了这个数据,请引用 | |||
https://www.nyu.edu/projects/bowman/multinli/paper.pdf | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('mnli') | |||
return output_dir | |||
class SNLILoader(JsonLoader): | |||
""" | |||
读取之后的DataSet中的field情况为 | |||
.. csv-table:: 下面是使用SNLILoader加载的DataSet所具备的field | |||
:header: "raw_words1", "raw_words2", "target" | |||
"The new rights are...", "Everyone really likes..", "neutral" | |||
"This site includes a...", "The Government Executive...", "entailment" | |||
"...", "...", "." | |||
""" | |||
def __init__(self): | |||
super().__init__(fields={ | |||
'sentence1': Const.RAW_WORDS(0), | |||
'sentence2': Const.RAW_WORDS(1), | |||
'gold_label': Const.TARGET, | |||
}) | |||
def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: | |||
""" | |||
Read data from the files under the given directory and return a :class:`~fastNLP.io.DataBundle`.
Which fields are read is determined by the ``fields`` mapping passed when the SNLILoader was initialized.
:param str paths: 传入一个目录, 将在该目录下寻找snli_1.0_train.jsonl, snli_1.0_dev.jsonl | |||
和snli_1.0_test.jsonl三个文件。 | |||
:return: 返回的:class:`~fastNLP.io.DataBundle` | |||
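A usage sketch (the directory is a placeholder; it must contain the three jsonl files named above)::

    data_bundle = SNLILoader().load('/path/to/snli_1.0')
    train_data = data_bundle.datasets['train']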
""" | |||
_paths = {} | |||
if paths is None: | |||
paths = self.download() | |||
if paths: | |||
if os.path.isdir(paths): | |||
if not os.path.isfile(os.path.join(paths, 'snli_1.0_train.jsonl')): | |||
raise FileNotFoundError(f"snli_1.0_train.jsonl is not found in {paths}") | |||
_paths['train'] = os.path.join(paths, 'snli_1.0_train.jsonl') | |||
for filename in ['snli_1.0_dev.jsonl', 'snli_1.0_test.jsonl']: | |||
filepath = os.path.join(paths, filename) | |||
_paths[filename.split('_')[-1].split('.')[0]] = filepath | |||
paths = _paths | |||
else: | |||
raise NotADirectoryError(f"{paths} is not a valid directory.") | |||
datasets = {name: self._load(path) for name, path in paths.items()} | |||
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
def download(self): | |||
""" | |||
如果您的文章使用了这份数据,请引用 | |||
http://nlp.stanford.edu/pubs/snli_paper.pdf | |||
:return: str | |||
""" | |||
return self._get_dataset_path('snli') | |||
class QNLILoader(JsonLoader): | |||
""" | |||
QNLI数据集的Loader, | |||
加载的DataSet将具备以下的field, raw_words1是question, raw_words2是sentence, target是label | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target" | |||
"What came into force after the new...", "As of that day...", "entailment" | |||
"What is the first major...", "The most important tributaries", "not_entailment" | |||
"...","." | |||
test数据集没有target列 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() # 跳过header | |||
if path.endswith("test.tsv"): | |||
warnings.warn("QNLI's test file has no target.") | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[1] | |||
raw_words2 = parts[2] | |||
if raw_words1 and raw_words2: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) | |||
else: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[1] | |||
raw_words2 = parts[2] | |||
target = parts[-1] | |||
if raw_words1 and raw_words2 and target: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||
return ds | |||
def download(self): | |||
""" | |||
如果您的实验使用到了该数据,请引用 | |||
TODO 补充 | |||
:return: | |||
""" | |||
return self._get_dataset_path('qnli') | |||
class RTELoader(Loader): | |||
""" | |||
RTE数据的loader | |||
加载的DataSet将具备以下的field, raw_words1是sentence0,raw_words2是sentence1, target是label | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target" | |||
"Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment" | |||
"Yet, we now are discovering that...", "Bacteria is winning...", "entailment" | |||
"...","." | |||
test数据集没有target列 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path:str): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() # 跳过header | |||
if path.endswith("test.tsv"): | |||
warnings.warn("RTE's test file has no target.") | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[1] | |||
raw_words2 = parts[2] | |||
if raw_words1 and raw_words2: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) | |||
else: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[1] | |||
raw_words2 = parts[2] | |||
target = parts[-1] | |||
if raw_words1 and raw_words2 and target: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||
return ds | |||
def download(self): | |||
return self._get_dataset_path('rte') | |||
class QuoraLoader(Loader): | |||
""" | |||
Quora matching任务的数据集Loader | |||
支持读取的文件中的内容,应该有以下的形式, 以制表符分隔,且前三列的内容必须是:第一列是label,第二列和第三列是句子 | |||
Example:: | |||
1 How do I get funding for my web based startup idea ? How do I get seed funding pre product ? 327970 | |||
1 How can I stop my depression ? What can I do to stop being depressed ? 339556 | |||
... | |||
加载的DataSet将具备以下的field | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target" | |||
"What should I do to avoid...", "1" | |||
"How do I not sleep in a boring class...", "0" | |||
"...","." | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path:str): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[1] | |||
raw_words2 = parts[2] | |||
target = parts[0] | |||
if raw_words1 and raw_words2 and target: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||
return ds | |||
def download(self): | |||
raise RuntimeError("Quora cannot be downloaded automatically.") |
@@ -0,0 +1,37 @@ | |||
""" | |||
Pipe用于处理通过 Loader 读取的数据,所有的 Pipe 都包含 ``process`` 和 ``process_from_file`` 两种方法。 | |||
``process(data_bundle)`` 传入一个 :class:`~fastNLP.io.DataBundle` 类型的对象, 在传入的 `data_bundle` 上进行原位修改,并将其返回; | |||
``process_from_file(paths)`` 传入的文件路径,返回一个 :class:`~fastNLP.io.DataBundle` 类型的对象。 | |||
``process(data_bundle)`` 或者 ``process_from_file(paths)`` 的返回 `data_bundle` 中的 :class:`~fastNLP.DataSet` | |||
一般都包含原文与转换为index的输入以及转换为index的target;除了 :class:`~fastNLP.DataSet` 之外, | |||
`data_bundle` 还会包含将field转为index时所建立的词表。 | |||
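A sketch of the typical Loader/Pipe workflow (the path is a placeholder, and the imports assume the classes are re-exported by the
``fastNLP.io.loader`` and ``fastNLP.io.pipe`` packages, as the ``__all__`` list below suggests for the pipes):

Example::

    from fastNLP.io.loader import Conll2003NERLoader
    from fastNLP.io.pipe import Conll2003NERPipe

    data_bundle = Conll2003NERLoader().load('/path/to/conll2003')  # raw files -> DataBundle
    data_bundle = Conll2003NERPipe().process(data_bundle)          # tokenize/index the bundle in place
    # or, equivalently, in one step:
    data_bundle = Conll2003NERPipe().process_from_file('/path/to/conll2003')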
""" | |||
__all__ = [ | |||
"YelpFullPipe", | |||
"YelpPolarityPipe", | |||
"SSTPipe", | |||
"SST2Pipe", | |||
"IMDBPipe", | |||
"Conll2003NERPipe", | |||
"OntoNotesNERPipe", | |||
"MatchingBertPipe", | |||
"RTEBertPipe", | |||
"SNLIBertPipe", | |||
"QuoraBertPipe", | |||
"QNLIBertPipe", | |||
"MNLIBertPipe", | |||
"MatchingPipe", | |||
"RTEPipe", | |||
"SNLIPipe", | |||
"QuoraPipe", | |||
"QNLIPipe", | |||
"MNLIPipe", | |||
] | |||
from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe | |||
from .conll import Conll2003NERPipe, OntoNotesNERPipe | |||
from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ | |||
MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe |
@@ -0,0 +1,444 @@ | |||
from nltk import Tree | |||
from ..base_loader import DataBundle | |||
from ...core.vocabulary import Vocabulary | |||
from ...core.const import Const | |||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance | |||
from .pipe import Pipe | |||
import re | |||
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | |||
from ...core.utils import cache_results | |||
class _CLSPipe(Pipe): | |||
""" | |||
分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列 | |||
""" | |||
def __init__(self, tokenizer:str='spacy', lang='en'): | |||
self.tokenizer = get_tokenizer(tokenizer, lang=lang) | |||
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): | |||
""" | |||
将DataBundle中的数据进行tokenize | |||
:param DataBundle data_bundle: | |||
:param str field_name: | |||
:param str new_field_name: | |||
:return: 传入的DataBundle对象 | |||
""" | |||
new_field_name = new_field_name or field_name | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) | |||
return data_bundle | |||
def _granularize(self, data_bundle, tag_map): | |||
""" | |||
该函数对data_bundle中'target'列中的内容进行转换。 | |||
:param data_bundle: | |||
:param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance, | |||
且将"1"认为是第0类。 | |||
:return: 传入的data_bundle | |||
""" | |||
for name in list(data_bundle.datasets.keys()): | |||
dataset = data_bundle.get_dataset(name) | |||
dataset.apply_field(lambda target:tag_map.get(target, -100), field_name=Const.TARGET, | |||
new_field_name=Const.TARGET) | |||
dataset.drop(lambda ins:ins[Const.TARGET] == -100) | |||
data_bundle.set_dataset(dataset, name) | |||
return data_bundle | |||
def _clean_str(words): | |||
""" | |||
heavily borrowed from github | |||
https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb | |||
:param words: a list of tokens
:return: the cleaned list of tokens
""" | |||
words_collection = [] | |||
for word in words: | |||
if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']: | |||
continue | |||
tt = nonalpnum.split(word) | |||
t = ''.join(tt) | |||
if t != '': | |||
words_collection.append(t) | |||
return words_collection | |||
class YelpFullPipe(_CLSPipe): | |||
""" | |||
处理YelpFull的数据, 处理之后DataSet中的内容如下 | |||
.. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field | |||
:header: "raw_words", "words", "target", "seq_len" | |||
"It 's a ...", "[4, 2, 10, ...]", 0, 10 | |||
"Offers that ...", "[20, 40, ...]", 1, 21 | |||
"...", "[...]", ., . | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param int granularity: one of 2, 3 or 5. With 2 the task becomes binary classification: labels 1 and 2 are mapped to class 0, labels 4 and 5 to class 1,
and label 3 is dropped. With 3 it is a 3-class problem: 1 and 2 form one class, 3 another, and 4 and 5 the third. With 5 all five classes are kept.
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
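A usage sketch (the data directory is a placeholder; 'raw' tokenization is used here so the example does not depend on spacy):

Example::

    pipe = YelpFullPipe(lower=True, granularity=2, tokenizer='raw')
    data_bundle = pipe.process_from_file('/path/to/yelp_full')
    train_data = data_bundle.datasets['train']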
""" | |||
def __init__(self, lower:bool=False, granularity=5, tokenizer:str='spacy'): | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
assert granularity in (2, 3, 5), "granularity can only be 2,3,5." | |||
self.granularity = granularity | |||
if granularity==2: | |||
self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1} | |||
elif granularity==3: | |||
self.tag_map = {"1": 0, "2": 0, "3":1, "4": 2, "5": 2} | |||
else: | |||
self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4} | |||
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): | |||
""" | |||
将DataBundle中的数据进行tokenize | |||
:param DataBundle data_bundle: | |||
:param str field_name: | |||
:param str new_field_name: | |||
:return: 传入的DataBundle对象 | |||
""" | |||
new_field_name = new_field_name or field_name | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) | |||
dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name) | |||
return data_bundle | |||
def process(self, data_bundle): | |||
""" | |||
传入的DataSet应该具备如下的结构 | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"I got 'new' tires from them and... ", "1" | |||
"Don't waste your time. We had two...", "1" | |||
"...", "..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
# 复制一列words | |||
data_bundle = _add_words_field(data_bundle, lower=self.lower) | |||
# 进行tokenize | |||
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) | |||
# 根据granularity设置tag | |||
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) | |||
# 删除空行 | |||
data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT) | |||
# index | |||
data_bundle = _indexize(data_bundle=data_bundle) | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||
data_bundle.set_target(Const.TARGET) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
:param paths: | |||
:return: DataBundle | |||
""" | |||
data_bundle = YelpFullLoader().load(paths) | |||
return self.process(data_bundle=data_bundle) | |||
class YelpPolarityPipe(_CLSPipe): | |||
""" | |||
处理YelpPolarity的数据, 处理之后DataSet中的内容如下 | |||
.. csv-table:: Fields of the DataSet produced by YelpPolarityPipe
:header: "raw_words", "words", "target", "seq_len" | |||
"It 's a ...", "[4, 2, 10, ...]", 0, 10 | |||
"Offers that ...", "[20, 40, ...]", 1, 21 | |||
"...", "[...]", ., . | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
""" | |||
def __init__(self, lower:bool=False, tokenizer:str='spacy'): | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
def process(self, data_bundle): | |||
# 复制一列words | |||
data_bundle = _add_words_field(data_bundle, lower=self.lower) | |||
# 进行tokenize | |||
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) | |||
# index | |||
data_bundle = _indexize(data_bundle=data_bundle) | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||
data_bundle.set_target(Const.TARGET) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
:param str paths: | |||
:return: DataBundle | |||
""" | |||
data_bundle = YelpPolarityLoader().load(paths) | |||
return self.process(data_bundle=data_bundle) | |||
class SSTPipe(_CLSPipe): | |||
""" | |||
别名::class:`fastNLP.io.SSTPipe` :class:`fastNLP.io.pipe.SSTPipe` | |||
经过该Pipe之后,DataSet中具备的field如下所示 | |||
.. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field | |||
:header: "raw_words", "words", "target", "seq_len" | |||
"It 's a ...", "[4, 2, 10, ...]", 0, 16 | |||
"Offers that ...", "[20, 40, ...]", 1, 18 | |||
"...", "[...]", ., . | |||
:param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False`` | |||
:param bool train_subtree: 是否将train集通过子树扩展数据。 | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将0、1归为1类,3、4归为一类,丢掉2;若为3, 则有3分类问题,将 | |||
0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
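A usage sketch (the path is a placeholder; the loaded files are expected to contain the bracketed tree format shown in ``process``):

Example::

    pipe = SSTPipe(subtree=False, granularity=5, tokenizer='raw')
    data_bundle = pipe.process_from_file('/path/to/SST')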
""" | |||
def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'): | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.subtree = subtree | |||
self.train_tree = train_subtree | |||
self.lower = lower | |||
assert granularity in (2, 3, 5), "granularity can only be 2,3,5." | |||
self.granularity = granularity | |||
if granularity==2: | |||
self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1} | |||
elif granularity==3: | |||
self.tag_map = {"0": 0, "1": 0, "2":1, "3": 2, "4": 2} | |||
else: | |||
self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} | |||
def process(self, data_bundle:DataBundle): | |||
""" | |||
对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与 | |||
.. csv-table:: | |||
:header: "raw_words" | |||
"(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..." | |||
"(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..." | |||
"..." | |||
:param ~fastNLP.io.DataBundle data_bundle: 需要处理的DataBundle对象 | |||
:return: | |||
""" | |||
# 先取出subtree | |||
for name in list(data_bundle.datasets.keys()): | |||
dataset = data_bundle.get_dataset(name) | |||
ds = DataSet() | |||
use_subtree = self.subtree or (name == 'train' and self.train_tree) | |||
for ins in dataset: | |||
raw_words = ins['raw_words'] | |||
tree = Tree.fromstring(raw_words) | |||
if use_subtree: | |||
for t in tree.subtrees(): | |||
raw_words = " ".join(t.leaves()) | |||
instance = Instance(raw_words=raw_words, target=t.label()) | |||
ds.append(instance) | |||
else: | |||
instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label()) | |||
ds.append(instance) | |||
data_bundle.set_dataset(ds, name) | |||
_add_words_field(data_bundle, lower=self.lower) | |||
# 进行tokenize | |||
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) | |||
# 根据granularity设置tag | |||
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) | |||
# index | |||
data_bundle = _indexize(data_bundle=data_bundle) | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||
data_bundle.set_target(Const.TARGET) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
data_bundle = SSTLoader().load(paths) | |||
return self.process(data_bundle=data_bundle) | |||
class SST2Pipe(_CLSPipe): | |||
""" | |||
加载SST2的数据, 处理完成之后DataSet将拥有以下的field | |||
.. csv-table:: | |||
:header: "raw_words", "words", "target", "seq_len" | |||
"it 's a charming and... ", "[3, 4, 5, 6, 7,...]", 1, 43 | |||
"unflinchingly bleak and...", "[10, 11, 7,...]", 1, 21 | |||
"...", "...", ., . | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
""" | |||
def __init__(self, lower=False, tokenizer='spacy'): | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
def process(self, data_bundle:DataBundle): | |||
""" | |||
可以处理的DataSet应该具备如下的结构 | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"it 's a charming and... ", 1 | |||
"unflinchingly bleak and...", 1 | |||
"...", "..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
_add_words_field(data_bundle, self.lower) | |||
data_bundle = self._tokenize(data_bundle=data_bundle) | |||
src_vocab = Vocabulary() | |||
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT, | |||
no_create_entry_dataset=[dataset for name,dataset in data_bundle.datasets.items() if | |||
name != 'train']) | |||
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) | |||
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | |||
datasets = [] | |||
for name, dataset in data_bundle.datasets.items(): | |||
if dataset.has_field(Const.TARGET): | |||
datasets.append(dataset) | |||
tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET) | |||
data_bundle.set_vocab(src_vocab, Const.INPUT) | |||
data_bundle.set_vocab(tgt_vocab, Const.TARGET) | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||
data_bundle.set_target(Const.TARGET) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
:param str paths: 如果为None,则自动下载并缓存到fastNLP的缓存地址。 | |||
:return: DataBundle | |||
""" | |||
data_bundle = SST2Loader().load(paths) | |||
return self.process(data_bundle) | |||
class IMDBPipe(_CLSPipe): | |||
""" | |||
经过本Pipe处理后DataSet将如下 | |||
.. csv-table:: 输出DataSet的field | |||
:header: "raw_words", "words", "target", "seq_len" | |||
"Bromwell High is a cartoon ... ", "[3, 5, 6, 9, ...]", 0, 20 | |||
"Story of a man who has ...", "[20, 43, 9, 10, ...]", 1, 31 | |||
"...", "[...]", ., . | |||
其中raw_words为str类型,是原文; words是转换为index的输入; target是转换为index的目标值; | |||
words列被设置为input; target列被设置为target。 | |||
:param bool lower: 是否将words列的数据小写。 | |||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||
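A usage sketch (the directory is a placeholder for wherever the IMDB data lives):

Example::

    data_bundle = IMDBPipe(lower=True, tokenizer='raw').process_from_file('/path/to/IMDB')
    words_vocab = data_bundle.get_vocab('words')  # vocabulary built over the words field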
""" | |||
def __init__(self, lower:bool=False, tokenizer:str='spacy'): | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
def process(self, data_bundle:DataBundle): | |||
""" | |||
The DataSets inside the incoming DataBundle are expected to look like the following, with two fields, raw_words and target, both of type str
.. csv-table:: 输入DataSet的field | |||
:header: "raw_words", "target" | |||
"Bromwell High is a cartoon ... ", "pos" | |||
"Story of a man who has ...", "neg" | |||
"...", "..." | |||
:param DataBundle data_bundle: the DataSets inside the DataBundle must contain the two fields raw_words and target, and both
should be of type str.
:return: DataBundle | |||
""" | |||
# 替换<br /> | |||
def replace_br(raw_words): | |||
raw_words = raw_words.replace("<br />", ' ') | |||
return raw_words | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD) | |||
_add_words_field(data_bundle, lower=self.lower) | |||
self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
_indexize(data_bundle) | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
dataset.set_input(Const.INPUT, Const.INPUT_LEN) | |||
dataset.set_target(Const.TARGET) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | |||
:return: DataBundle | |||
""" | |||
# 读取数据 | |||
data_bundle = IMDBLoader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
@@ -0,0 +1,148 @@ | |||
from .pipe import Pipe | |||
from .. import DataBundle | |||
from .utils import iob2, iob2bioes | |||
from ...core.const import Const | |||
from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader | |||
from .utils import _indexize, _add_words_field | |||
class _NERPipe(Pipe): | |||
""" | |||
NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 | |||
(创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 | |||
Vocabulary转换为index。 | |||
raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | |||
target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target, seq_len。 | |||
:param str encoding_type: encoding scheme used for the target column; 'bio' and 'bioes' are supported.
:param bool lower: whether to lowercase the words before building the vocabulary; this is rarely needed.
:param int target_pad_val: padding value used for the target column. Default: 0.
""" | |||
def __init__(self, encoding_type: str = 'bio', lower: bool = False, target_pad_val=0): | |||
if encoding_type == 'bio': | |||
self.convert_tag = iob2 | |||
else: | |||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | |||
self.lower = lower | |||
self.target_pad_val = int(target_pad_val) | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
""" | |||
支持的DataSet的field为 | |||
.. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader | |||
:header: "raw_words", "target" | |||
"[Nadim, Ladki]", "[B-PER, I-PER]" | |||
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | |||
"[...]", "[...]" | |||
:param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 | |||
在传入DataBundle基础上原位修改。 | |||
:return: DataBundle | |||
Example:: | |||
data_bundle = Conll2003Loader().load('/path/to/conll2003/') | |||
data_bundle = Conll2003NERPipe().process(data_bundle) | |||
# 获取train | |||
tr_data = data_bundle.get_dataset('train') | |||
# 获取target这个field的词表 | |||
target_vocab = data_bundle.get_vocab('target') | |||
# 获取words这个field的词表 | |||
word_vocab = data_bundle.get_vocab('words') | |||
""" | |||
# 转换tag | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) | |||
_add_words_field(data_bundle, lower=self.lower) | |||
# index | |||
_indexize(data_bundle) | |||
input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] | |||
target_fields = [Const.TARGET, Const.INPUT_LEN] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.set_pad_val(Const.TARGET, self.target_pad_val) | |||
dataset.add_seq_len(Const.INPUT) | |||
data_bundle.set_input(*input_fields) | |||
data_bundle.set_target(*target_fields) | |||
return data_bundle | |||
def process_from_file(self, paths) -> DataBundle: | |||
""" | |||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。 | |||
:return: DataBundle | |||
""" | |||
# 读取数据 | |||
data_bundle = Conll2003NERLoader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
class Conll2003NERPipe(_NERPipe): | |||
""" | |||
Conll2003的NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 | |||
(创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 | |||
Vocabulary转换为index。 | |||
经过该Pipe过后,DataSet中的内容如下所示 | |||
.. csv-table:: Following is a demo layout of the DataSet returned by Conll2003NERPipe
:header: "raw_words", "words", "target", "seq_len" | |||
"[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2 | |||
"[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 10 | |||
"[...]", "[...]", "[...]", . | |||
raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | |||
target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。 | |||
:param str encoding_type: encoding scheme used for the target column; 'bio' and 'bioes' are supported.
:param bool lower: whether to lowercase the words before building the vocabulary; this is rarely needed.
:param int target_pad_val: padding value used for the target column. Default: 0.
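A usage sketch of the one-step entry point (the path is a placeholder):

Example::

    data_bundle = Conll2003NERPipe(encoding_type='bioes').process_from_file('/path/to/conll2003')
    tr_data = data_bundle.get_dataset('train')
    target_vocab = data_bundle.get_vocab('target')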
""" | |||
def process_from_file(self, paths) -> DataBundle: | |||
""" | |||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。 | |||
:return: DataBundle | |||
""" | |||
# 读取数据 | |||
data_bundle = Conll2003NERLoader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
class OntoNotesNERPipe(_NERPipe): | |||
""" | |||
处理OntoNotes的NER数据,处理之后DataSet中的field情况为 | |||
.. csv-table:: Following is a demo layout of the DataSet returned by OntoNotesNERPipe
:header: "raw_words", "words", "target", "seq_len" | |||
"[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2 | |||
"[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 6 | |||
"[...]", "[...]", "[...]", . | |||
:param str encoding_type: encoding scheme used for the target column; 'bio' and 'bioes' are supported.
:param bool lower: whether to lowercase the words before building the vocabulary; this is rarely needed.
:param int target_pad_val: padding value used for the target column. Default: 0.
""" | |||
def process_from_file(self, paths): | |||
data_bundle = OntoNotesNERLoader().load(paths) | |||
return self.process(data_bundle) |
@@ -0,0 +1,252 @@ | |||
from .pipe import Pipe | |||
from .utils import get_tokenizer | |||
from ...core.const import Const | |||
from ...core.vocabulary import Vocabulary | |||
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader | |||
class MatchingBertPipe(Pipe): | |||
""" | |||
Matching任务的Bert pipe,输出的DataSet将包含以下的field | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "words", "target", "seq_len" | |||
"The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", 1, 10 | |||
"This site includes a...", "The Government Executive...", "[11, 12, 13,...]", 0, 5 | |||
"...", "...", "[...]", ., . | |||
The words column is built by joining raw_words1 (the premise) and raw_words2 (the hypothesis) with "[SEP]" and converting the result to indices.
words列被设置为input,target列被设置为target和input(设置为input以方便在forward函数中计算loss, | |||
如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数的形参名进行传参). | |||
:param bool lower: 是否将word小写化。 | |||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||
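A usage sketch via one of the concrete subclasses defined below (the path is a placeholder; RTEBertPipe simply wires this pipe to RTELoader):

Example::

    data_bundle = RTEBertPipe(lower=True, tokenizer='raw').process_from_file('/path/to/RTE')
    words_vocab = data_bundle.get_vocab('words')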
""" | |||
def __init__(self, lower=False, tokenizer: str='raw'): | |||
super().__init__() | |||
self.lower = bool(lower) | |||
self.tokenizer = get_tokenizer(tokenizer=tokenizer) | |||
def _tokenize(self, data_bundle, field_names, new_field_names): | |||
""" | |||
:param DataBundle data_bundle: DataBundle. | |||
:param list field_names: List[str], 需要tokenize的field名称 | |||
:param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。 | |||
:return: 输入的DataBundle对象 | |||
""" | |||
for name, dataset in data_bundle.datasets.items(): | |||
for field_name, new_field_name in zip(field_names, new_field_names): | |||
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, | |||
new_field_name=new_field_name) | |||
return data_bundle | |||
def process(self, data_bundle): | |||
for dataset in data_bundle.datasets.values(): | |||
if dataset.has_field(Const.TARGET): | |||
dataset.drop(lambda x: x[Const.TARGET] == '-') | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0)) | |||
dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1)) | |||
if self.lower: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset[Const.INPUTS(0)].lower() | |||
dataset[Const.INPUTS(1)].lower() | |||
data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)], | |||
[Const.INPUTS(0), Const.INPUTS(1)]) | |||
# concat两个words | |||
def concat(ins): | |||
words0 = ins[Const.INPUTS(0)] | |||
words1 = ins[Const.INPUTS(1)] | |||
words = words0 + ['[SEP]'] + words1 | |||
return words | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply(concat, new_field_name=Const.INPUT) | |||
dataset.delete_field(Const.INPUTS(0)) | |||
dataset.delete_field(Const.INPUTS(1)) | |||
word_vocab = Vocabulary() | |||
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], | |||
field_name=Const.INPUT, | |||
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | |||
'train' not in name]) | |||
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) | |||
target_vocab = Vocabulary(padding=None, unknown=None) | |||
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | |||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | |||
dataset.has_field(Const.TARGET)] | |||
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) | |||
data_bundle.set_vocab(word_vocab, Const.INPUT) | |||
data_bundle.set_vocab(target_vocab, Const.TARGET) | |||
input_fields = [Const.INPUT, Const.INPUT_LEN, Const.TARGET] | |||
target_fields = [Const.TARGET] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
dataset.set_input(*input_fields, flag=True) | |||
dataset.set_target(*target_fields, flag=True) | |||
return data_bundle | |||
class RTEBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = RTELoader().load(paths) | |||
return self.process(data_bundle) | |||
class SNLIBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = SNLILoader().load(paths) | |||
return self.process(data_bundle) | |||
class QuoraBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths): | |||
data_bundle = QuoraLoader().load(paths) | |||
return self.process(data_bundle) | |||
class QNLIBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = QNLILoader().load(paths) | |||
return self.process(data_bundle) | |||
class MNLIBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = MNLILoader().load(paths) | |||
return self.process(data_bundle) | |||
class MatchingPipe(Pipe): | |||
""" | |||
Matching任务的Pipe。输出的DataSet将包含以下的field | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "words1", "words2", "target", "seq_len1", "seq_len2" | |||
"The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", "[10, 20, 6]", 1, 10, 13 | |||
"This site includes a...", "The Government Executive...", "[11, 12, 13,...]", "[2, 7, ...]", 0, 6, 7 | |||
"...", "...", "[...]", "[...]", ., ., . | |||
words1是premise,words2是hypothesis。其中words1,words2,seq_len1,seq_len2被设置为input;target被设置为target | |||
和input(设置为input以方便在forward函数中计算loss,如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数 | |||
的形参名进行传参)。 | |||
:param bool lower: 是否将所有raw_words转为小写。 | |||
:param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. spacy是使用spacy切分,raw就是用空格切分。 | |||
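A usage sketch via one of the concrete subclasses defined below (the path is a placeholder):

Example::

    data_bundle = SNLIPipe(lower=True, tokenizer='raw').process_from_file('/path/to/snli_1.0')
    train_data = data_bundle.datasets['train']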
""" | |||
def __init__(self, lower=False, tokenizer: str='raw'): | |||
super().__init__() | |||
self.lower = bool(lower) | |||
self.tokenizer = get_tokenizer(tokenizer=tokenizer) | |||
def _tokenize(self, data_bundle, field_names, new_field_names): | |||
""" | |||
:param DataBundle data_bundle: DataBundle. | |||
:param list field_names: List[str], 需要tokenize的field名称 | |||
:param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。 | |||
:return: 输入的DataBundle对象 | |||
""" | |||
for name, dataset in data_bundle.datasets.items(): | |||
for field_name, new_field_name in zip(field_names, new_field_names): | |||
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, | |||
new_field_name=new_field_name) | |||
return data_bundle | |||
def process(self, data_bundle): | |||
""" | |||
接受的DataBundle中的DataSet应该具有以下的field, target列可以没有 | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target" | |||
"The new rights are...", "Everyone really likes..", "entailment" | |||
"This site includes a...", "The Government Executive...", "not_entailment" | |||
"...", "..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)], | |||
[Const.INPUTS(0), Const.INPUTS(1)]) | |||
for dataset in data_bundle.datasets.values(): | |||
if dataset.has_field(Const.TARGET): | |||
dataset.drop(lambda x: x[Const.TARGET] == '-') | |||
if self.lower: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset[Const.INPUTS(0)].lower() | |||
dataset[Const.INPUTS(1)].lower() | |||
word_vocab = Vocabulary() | |||
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], | |||
field_name=[Const.INPUTS(0), Const.INPUTS(1)], | |||
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | |||
'train' not in name]) | |||
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)]) | |||
target_vocab = Vocabulary(padding=None, unknown=None) | |||
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | |||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | |||
dataset.has_field(Const.TARGET)] | |||
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) | |||
data_bundle.set_vocab(word_vocab, Const.INPUTS(0)) | |||
data_bundle.set_vocab(target_vocab, Const.TARGET) | |||
input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET] | |||
target_fields = [Const.TARGET] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0)) | |||
dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1)) | |||
dataset.set_input(*input_fields, flag=True) | |||
dataset.set_target(*target_fields, flag=True) | |||
return data_bundle | |||
class RTEPipe(MatchingPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = RTELoader().load(paths) | |||
return self.process(data_bundle) | |||
class SNLIPipe(MatchingPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = SNLILoader().load(paths) | |||
return self.process(data_bundle) | |||
class QuoraPipe(MatchingPipe): | |||
def process_from_file(self, paths): | |||
data_bundle = QuoraLoader().load(paths) | |||
return self.process(data_bundle) | |||
class QNLIPipe(MatchingPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = QNLILoader().load(paths) | |||
return self.process(data_bundle) | |||
class MNLIPipe(MatchingPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = MNLILoader().load(paths) | |||
return self.process(data_bundle) | |||
@@ -0,0 +1,9 @@ | |||
from .. import DataBundle | |||
class Pipe: | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
raise NotImplementedError | |||
def process_from_file(self, paths) -> DataBundle: | |||
raise NotImplementedError |
@@ -0,0 +1,142 @@ | |||
from typing import List | |||
from ...core.vocabulary import Vocabulary | |||
from ...core.const import Const | |||
def iob2(tags:List[str])->List[str]: | |||
""" | |||
检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format | |||
:param tags: 需要转换的tags | |||
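A small worked example: an IOB1 sequence is rewritten in place to IOB2.

Example::

    iob2(["I-PER", "I-PER", "O", "I-LOC"])  # -> ["B-PER", "I-PER", "O", "B-LOC"]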
""" | |||
for i, tag in enumerate(tags): | |||
if tag == "O": | |||
continue | |||
split = tag.split("-") | |||
if len(split) != 2 or split[0] not in ["I", "B"]: | |||
raise TypeError("The encoding schema is not a valid IOB type.") | |||
if split[0] == "B": | |||
continue | |||
elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2 | |||
tags[i] = "B" + tag[1:] | |||
elif tags[i - 1][1:] == tag[1:]: | |||
continue | |||
else: # conversion IOB1 to IOB2 | |||
tags[i] = "B" + tag[1:] | |||
return tags | |||
def iob2bioes(tags:List[str])->List[str]: | |||
""" | |||
将iob的tag转换为bioes编码 | |||
:param tags: | |||
:return: | |||
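A small worked example.

Example::

    iob2bioes(["B-PER", "I-PER", "O", "B-LOC"])  # -> ["B-PER", "E-PER", "O", "S-LOC"]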
""" | |||
new_tags = [] | |||
for i, tag in enumerate(tags): | |||
if tag == 'O': | |||
new_tags.append(tag) | |||
else: | |||
split = tag.split('-')[0] | |||
if split == 'B': | |||
if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I': | |||
new_tags.append(tag) | |||
else: | |||
new_tags.append(tag.replace('B-', 'S-')) | |||
elif split == 'I': | |||
if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I': | |||
new_tags.append(tag) | |||
else: | |||
new_tags.append(tag.replace('I-', 'E-')) | |||
else: | |||
raise TypeError("Invalid IOB format.") | |||
return new_tags | |||
def get_tokenizer(tokenizer:str, lang='en'): | |||
""" | |||
:param str tokenizer: 获取tokenzier方法 | |||
:param str lang: 语言,当前仅支持en | |||
:return: 返回tokenize函数 | |||
""" | |||
if tokenizer == 'spacy': | |||
import spacy | |||
spacy.prefer_gpu() | |||
if lang != 'en': | |||
raise RuntimeError("Spacy only supports en right right.") | |||
en = spacy.load(lang) | |||
tokenizer = lambda x: [w.text for w in en.tokenizer(x)] | |||
elif tokenizer == 'raw': | |||
tokenizer = _raw_split | |||
else: | |||
raise RuntimeError("Only support `spacy`, `raw` tokenizer.") | |||
return tokenizer | |||
def _raw_split(sent): | |||
return sent.split() | |||
def _indexize(data_bundle): | |||
""" | |||
在dataset中的"words"列建立词表,"target"列建立词表,并把词表加入到data_bundle中。 | |||
:param data_bundle: | |||
:return: | |||
""" | |||
src_vocab = Vocabulary() | |||
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT, | |||
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | |||
name != 'train']) | |||
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) | |||
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | |||
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.TARGET) | |||
data_bundle.set_vocab(src_vocab, Const.INPUT) | |||
data_bundle.set_vocab(tgt_vocab, Const.TARGET) | |||
return data_bundle | |||
def _add_words_field(data_bundle, lower=False): | |||
""" | |||
给data_bundle中的dataset中复制一列words. 并根据lower参数判断是否需要小写化 | |||
:param data_bundle: | |||
:param bool lower:是否要小写化 | |||
:return: 传入的DataBundle | |||
""" | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.copy_field(field_name=Const.RAW_WORD, new_field_name=Const.INPUT) | |||
if lower: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset[Const.INPUT].lower() | |||
return data_bundle | |||
def _drop_empty_instance(data_bundle, field_name): | |||
""" | |||
删除data_bundle的DataSet中存在的某个field为空的情况 | |||
:param data_bundle: DataBundle | |||
:param str field_name: 对哪个field进行检查,如果为None,则任意field为空都会删掉 | |||
:return: 传入的DataBundle | |||
""" | |||
def empty_instance(ins): | |||
if field_name: | |||
field_value = ins[field_name] | |||
if field_value in ((), {}, [], ''): | |||
return True | |||
return False | |||
for _, field_value in ins.items(): | |||
if field_value in ((), {}, [], ''): | |||
return True | |||
return False | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.drop(empty_instance) | |||
return data_bundle | |||
@@ -1,23 +1,27 @@ | |||
import os | |||
from typing import Union, Dict | |||
from pathlib import Path | |||
def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||
def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||
""" | |||
检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果 | |||
{ | |||
'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 | |||
'test': 'xxx' # 可能有,也可能没有 | |||
... | |||
} | |||
如果paths为不合法的,将直接进行raise相应的错误 | |||
检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: | |||
:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 | |||
{ | |||
'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 | |||
'test': 'xxx' # 可能有,也可能没有 | |||
... | |||
} | |||
如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。 | |||
:param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 | |||
中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 | |||
:return: | |||
""" | |||
if isinstance(paths, str): | |||
if isinstance(paths, (str, Path)): | |||
paths = os.path.abspath(os.path.expanduser(paths)) | |||
if os.path.isfile(paths): | |||
return {'train': paths} | |||
elif os.path.isdir(paths): | |||
@@ -37,6 +41,8 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||
path_pair = ('test', filename) | |||
if path_pair: | |||
files[path_pair[0]] = os.path.join(paths, path_pair[1]) | |||
if 'train' not in files: | |||
raise KeyError(f"There is no train file in {paths}.") | |||
return files | |||
else: | |||
raise FileNotFoundError(f"{paths} is not a valid file path.") | |||
@@ -47,8 +53,10 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||
raise KeyError("You have to include `train` in your dict.") | |||
for key, value in paths.items(): | |||
if isinstance(key, str) and isinstance(value, str): | |||
value = os.path.abspath(os.path.expanduser(value)) | |||
if not os.path.isfile(value): | |||
raise TypeError(f"{value} is not a valid file.") | |||
paths[key] = value | |||
else: | |||
raise TypeError("All keys and values in paths should be str.") | |||
return paths | |||
@@ -2,13 +2,14 @@ | |||
bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. | |||
""" | |||
import os | |||
import torch | |||
from torch import nn | |||
from .base_model import BaseModel | |||
from ..core.const import Const | |||
from ..modules.encoder import BertModel | |||
from ..modules.encoder.bert import BertConfig | |||
from ..modules.encoder.bert import BertConfig, CONFIG_FILE | |||
class BertForSequenceClassification(BaseModel): | |||
@@ -54,6 +55,7 @@ class BertForSequenceClassification(BaseModel): | |||
self.num_labels = num_labels | |||
if bert_dir is not None: | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
config = BertConfig(os.path.join(bert_dir, CONFIG_FILE)) | |||
else: | |||
if config is None: | |||
config = BertConfig(30522) | |||
@@ -67,20 +69,20 @@ class BertForSequenceClassification(BaseModel): | |||
model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) | |||
return model | |||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | |||
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | |||
def forward(self, words, seq_len=None, target=None): | |||
_, pooled_output = self.bert(words, attention_mask=seq_len, output_all_encoded_layers=False) | |||
pooled_output = self.dropout(pooled_output) | |||
logits = self.classifier(pooled_output) | |||
if labels is not None: | |||
if target is not None: | |||
loss_fct = nn.CrossEntropyLoss() | |||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) | |||
loss = loss_fct(logits, target) | |||
return {Const.OUTPUT: logits, Const.LOSS: loss} | |||
else: | |||
return {Const.OUTPUT: logits} | |||
def predict(self, input_ids, token_type_ids=None, attention_mask=None): | |||
logits = self.forward(input_ids, token_type_ids, attention_mask) | |||
def predict(self, words, seq_len=None): | |||
logits = self.forward(words, seq_len=seq_len)[Const.OUTPUT] | |||
return {Const.OUTPUT: torch.argmax(logits, dim=-1)} | |||
@@ -140,7 +142,8 @@ class BertForMultipleChoice(BaseModel): | |||
model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir) | |||
return model | |||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | |||
def forward(self, words, seq_len1=None, seq_len2=None, target=None): | |||
input_ids, token_type_ids, attention_mask = words, seq_len1, seq_len2 | |||
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) | |||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) | |||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) | |||
@@ -149,15 +152,15 @@ class BertForMultipleChoice(BaseModel): | |||
logits = self.classifier(pooled_output) | |||
reshaped_logits = logits.view(-1, self.num_choices) | |||
if labels is not None: | |||
if target is not None: | |||
loss_fct = nn.CrossEntropyLoss() | |||
loss = loss_fct(reshaped_logits, labels) | |||
loss = loss_fct(reshaped_logits, target) | |||
return {Const.OUTPUT: reshaped_logits, Const.LOSS: loss} | |||
else: | |||
return {Const.OUTPUT: reshaped_logits} | |||
def predict(self, input_ids, token_type_ids=None, attention_mask=None): | |||
logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT] | |||
def predict(self, words, seq_len1=None, seq_len2=None,): | |||
logits = self.forward(words, seq_len1=seq_len1, seq_len2=seq_len2)[Const.OUTPUT] | |||
return {Const.OUTPUT: torch.argmax(logits, dim=-1)} | |||
@@ -219,27 +222,27 @@ class BertForTokenClassification(BaseModel): | |||
model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) | |||
return model | |||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | |||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | |||
def forward(self, words, seq_len1=None, seq_len2=None, target=None): | |||
sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False) | |||
sequence_output = self.dropout(sequence_output) | |||
logits = self.classifier(sequence_output) | |||
if labels is not None: | |||
if target is not None: | |||
loss_fct = nn.CrossEntropyLoss() | |||
# Only keep active parts of the loss | |||
if attention_mask is not None: | |||
active_loss = attention_mask.view(-1) == 1 | |||
if seq_len2 is not None: | |||
active_loss = seq_len2.view(-1) == 1 | |||
active_logits = logits.view(-1, self.num_labels)[active_loss] | |||
active_labels = labels.view(-1)[active_loss] | |||
active_labels = target.view(-1)[active_loss] | |||
loss = loss_fct(active_logits, active_labels) | |||
else: | |||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) | |||
loss = loss_fct(logits.view(-1, self.num_labels), target.view(-1)) | |||
return {Const.OUTPUT: logits, Const.LOSS: loss} | |||
else: | |||
return {Const.OUTPUT: logits} | |||
def predict(self, input_ids, token_type_ids=None, attention_mask=None): | |||
logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT] | |||
def predict(self, words, seq_len1=None, seq_len2=None): | |||
logits = self.forward(words, seq_len1, seq_len2)[Const.OUTPUT] | |||
return {Const.OUTPUT: torch.argmax(logits, dim=-1)} | |||
@@ -304,34 +307,34 @@ class BertForQuestionAnswering(BaseModel): | |||
model = cls(config=config, bert_dir=pretrained_model_dir) | |||
return model | |||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): | |||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | |||
def forward(self, words, seq_len1=None, seq_len2=None, target1=None, target2=None): | |||
sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False) | |||
logits = self.qa_outputs(sequence_output) | |||
start_logits, end_logits = logits.split(1, dim=-1) | |||
start_logits = start_logits.squeeze(-1) | |||
end_logits = end_logits.squeeze(-1) | |||
if start_positions is not None and end_positions is not None: | |||
if target1 is not None and target2 is not None: | |||
# If we are on multi-GPU, split adds a dimension | |||
if len(start_positions.size()) > 1: | |||
start_positions = start_positions.squeeze(-1) | |||
if len(end_positions.size()) > 1: | |||
end_positions = end_positions.squeeze(-1) | |||
if len(target1.size()) > 1: | |||
target1 = target1.squeeze(-1) | |||
if len(target2.size()) > 1: | |||
target2 = target2.squeeze(-1) | |||
# sometimes the start/end positions are outside our model inputs, so we ignore these terms | |||
ignored_index = start_logits.size(1) | |||
start_positions.clamp_(0, ignored_index) | |||
end_positions.clamp_(0, ignored_index) | |||
target1.clamp_(0, ignored_index) | |||
target2.clamp_(0, ignored_index) | |||
loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) | |||
start_loss = loss_fct(start_logits, start_positions) | |||
end_loss = loss_fct(end_logits, end_positions) | |||
start_loss = loss_fct(start_logits, target1) | |||
end_loss = loss_fct(end_logits, target2) | |||
total_loss = (start_loss + end_loss) / 2 | |||
return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits, Const.LOSS: total_loss} | |||
else: | |||
return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits} | |||
def predict(self, input_ids, token_type_ids=None, attention_mask=None, **kwargs): | |||
logits = self.forward(input_ids, token_type_ids, attention_mask) | |||
def predict(self, words, seq_len1=None, seq_len2=None): | |||
logits = self.forward(words, seq_len1, seq_len2) | |||
start_logits = logits[Const.OUTPUTS(0)] | |||
end_logits = logits[Const.OUTPUTS(1)] | |||
return {Const.OUTPUTS(0): torch.argmax(start_logits, dim=-1), | |||
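The clamping logic in the hunk above relies on a property of nn.CrossEntropyLoss: targets equal to ignore_index are skipped before any class-range check, so answer positions that fall outside the model input can simply be clamped to the sequence length and ignored. A small, self-contained illustration with made-up values:

import torch
import torch.nn as nn

start_logits = torch.randn(3, 8)                  # batch of 3, sequence length 8
target1 = torch.tensor([2, 15, 7])                # 15 lies outside the model input
ignored_index = start_logits.size(1)              # 8
target1 = target1.clamp(0, ignored_index)         # -> [2, 8, 7]
loss = nn.CrossEntropyLoss(ignore_index=ignored_index)(start_logits, target1)
# only the first and third positions contribute to the loss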
@@ -9,7 +9,7 @@ import torch.nn.functional as F | |||
from torch.nn import CrossEntropyLoss | |||
from .base_model import BaseModel | |||
from ..embeddings.embedding import TokenEmbedding | |||
from ..embeddings.embedding import TokenEmbedding, Embedding | |||
from ..core.const import Const | |||
from ..core.utils import seq_len_to_mask | |||
@@ -21,18 +21,21 @@ class ESIM(BaseModel): | |||
A PyTorch implementation of the ESIM model. | |||
See the paper: https://arxiv.org/pdf/1609.06038.pdf | |||
:param fastNLP.TokenEmbedding init_embedding: the TokenEmbedding used for initialization | |||
:param init_embedding: the Embedding used for initialization | |||
:param int hidden_size: size of the hidden layer; defaults to the Embedding dimension | |||
:param int num_labels: number of target label classes; defaults to 3 | |||
:param float dropout_rate: dropout rate; defaults to 0.3 | |||
:param float dropout_embed: dropout rate applied to the Embedding; defaults to 0.1 | |||
""" | |||
def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3, | |||
def __init__(self, init_embedding, hidden_size=None, num_labels=3, dropout_rate=0.3, | |||
dropout_embed=0.1): | |||
super(ESIM, self).__init__() | |||
self.embedding = init_embedding | |||
if isinstance(init_embedding, TokenEmbedding) or isinstance(init_embedding, Embedding): | |||
self.embedding = init_embedding | |||
else: | |||
self.embedding = Embedding(init_embedding) | |||
self.dropout_embed = EmbedDropout(p=dropout_embed) | |||
if hidden_size is None: | |||
hidden_size = self.embedding.embed_size | |||
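With this change ESIM accepts anything that fastNLP's Embedding can wrap, not only a TokenEmbedding. A minimal construction sketch, mirroring the new model test at the end of this diff; the vocabulary and sizes are arbitrary:

from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.models.snli import ESIM

# a plain (vocab_size, embed_dim) tuple is wrapped by Embedding internally
model = ESIM((100, 10), num_labels=3, dropout_rate=0.0)

# the previous style, passing a TokenEmbedding, still works
vocab = Vocabulary().add_word_lst("a tiny example sentence".split())
model = ESIM(StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10))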
@@ -563,6 +563,8 @@ class WordpieceTokenizer(object): | |||
output_tokens.append(self.unk_token) | |||
else: | |||
output_tokens.extend(sub_tokens) | |||
if len(output_tokens) == 0:  # guard against input that is nothing but spaces or newline characters | |||
return [self.unk_token] | |||
return output_tokens | |||
@@ -848,7 +850,7 @@ class _WordPieceBertModel(nn.Module): | |||
""" | |||
def __init__(self, model_dir: str, layers: str = '-1'): | |||
def __init__(self, model_dir: str, layers: str = '-1', pooled_cls:bool=False): | |||
super().__init__() | |||
self.tokenzier = BertTokenizer.from_pretrained(model_dir) | |||
@@ -867,8 +869,9 @@ class _WordPieceBertModel(nn.Module): | |||
self._cls_index = self.tokenzier.vocab['[CLS]'] | |||
self._sep_index = self.tokenzier.vocab['[SEP]'] | |||
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed when generating word_pieces | |||
self.pooled_cls = pooled_cls | |||
def index_dataset(self, *datasets, field_name): | |||
def index_dataset(self, *datasets, field_name, add_cls_sep=True): | |||
""" | |||
Use BERT's tokenizer to generate a new word_pieces column, add it to the datasets and set it as input. If the sequence does not already | |||
start with [CLS] and end with [SEP], they are added at the front and back, and the pad value of the word_pieces column is set to BERT's pad value. | |||
@@ -884,10 +887,11 @@ class _WordPieceBertModel(nn.Module): | |||
tokens = self.tokenzier.wordpiece_tokenizer.tokenize(word) | |||
word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens) | |||
word_pieces.extend(word_piece_ids) | |||
if word_pieces[0] != self._cls_index: | |||
word_pieces.insert(0, self._cls_index) | |||
if word_pieces[-1] != self._sep_index: | |||
word_pieces.insert(-1, self._sep_index) | |||
if add_cls_sep: | |||
if word_pieces[0] != self._cls_index: | |||
word_pieces.insert(0, self._cls_index) | |||
if word_pieces[-1] != self._sep_index: | |||
word_pieces.insert(-1, self._sep_index) | |||
return word_pieces | |||
for index, dataset in enumerate(datasets): | |||
@@ -909,10 +913,13 @@ class _WordPieceBertModel(nn.Module): | |||
batch_size, max_len = word_pieces.size() | |||
attn_masks = word_pieces.ne(self._wordpiece_pad_index) | |||
bert_outputs, _ = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, | |||
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, | |||
output_all_encoded_layers=True) | |||
# output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size | |||
outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1))) | |||
for l_index, l in enumerate(self.layers): | |||
outputs[l_index] = bert_outputs[l] | |||
bert_output = bert_outputs[l] | |||
if l in (len(bert_outputs) - 1, -1) and self.pooled_cls:  # the last layer, whether indexed as len-1 or -1 | |||
bert_output[:, 0] = pooled_cls | |||
outputs[l_index] = bert_output | |||
return outputs |
@@ -3,7 +3,8 @@ from functools import reduce | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.init as init | |||
import glob | |||
import os | |||
def initial_parameter(net, initial_method=None): | |||
"""A method used to initialize the weights of PyTorch models. | |||
@@ -111,7 +112,7 @@ def get_dropout_mask(drop_p: float, tensor: torch.Tensor): | |||
Generate a dropout mask with the same shape as the given tensor. | |||
:param drop_p: float, the probability with which each entry is set to 0. | |||
:param tensor:torch.Tensor | |||
:param tensor: torch.Tensor | |||
:return: torch.FloatTensor with the same shape as tensor | |||
""" | |||
mask_x = torch.ones_like(tensor) | |||
@@ -119,7 +120,6 @@ def get_dropout_mask(drop_p: float, tensor: torch.Tensor): | |||
training=False, inplace=True) | |||
return mask_x | |||
import glob | |||
def _get_file_name_base_on_postfix(dir_path, postfix): | |||
""" | |||
@@ -128,9 +128,9 @@ def _get_file_name_base_on_postfix(dir_path, postfix): | |||
:param postfix: a suffix such as ".bin" or ".json" | |||
:return: str, the path of the matching file | |||
""" | |||
files = glob.glob(os.path.join(dir_path, '*' + postfix)) | |||
files = [filename for filename in os.listdir(dir_path) if filename.endswith(postfix)] | |||
if len(files) == 0: | |||
raise FileNotFoundError(f"There is no file endswith *.{postfix} file in {dir_path}") | |||
raise FileNotFoundError(f"There is no file ending with *{postfix} in {dir_path}") | |||
elif len(files) > 1: | |||
raise FileExistsError(f"There are multiple *.{postfix} files in {dir_path}") | |||
raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}") | |||
return os.path.join(dir_path, files[0]) |
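A small usage sketch of the helper above (it is module-internal, so the calls are only indicated in comments): it returns the path of the single file in dir_path whose name ends with the given postfix, and raises if there are zero or several matches. The temporary directory and file name are made up for illustration.

import os
import tempfile

tmp_dir = tempfile.mkdtemp()
open(os.path.join(tmp_dir, 'pytorch_model.bin'), 'w').close()
# _get_file_name_base_on_postfix(tmp_dir, '.bin')  -> os.path.join(tmp_dir, 'pytorch_model.bin')
# _get_file_name_base_on_postfix(tmp_dir, '.json') -> raises FileNotFoundError (no *.json present)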
@@ -224,11 +224,11 @@ class CharBiaffineParser(BiaffineParser): | |||
batch_size, seq_len, _ = arc_pred.shape | |||
flip_mask = (mask == 0) | |||
_arc_pred = arc_pred.clone() | |||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) | |||
# _arc_pred = arc_pred.clone() | |||
_arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf')) | |||
arc_true[:, 0].fill_(-1) | |||
label_true[:, 0].fill_(-1) | |||
arc_true.data[:, 0].fill_(-1) | |||
label_true.data[:, 0].fill_(-1) | |||
arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1) | |||
label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1) | |||
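The change above swaps the in-place masked_fill_ on a clone for an out-of-place masked_fill, which leaves arc_pred itself untouched, and writes the ignore markers for the ROOT column through .data. A self-contained sketch of the same operations with toy shapes:

import torch

arc_pred = torch.randn(2, 4, 4, requires_grad=True)   # (batch, seq_len, seq_len) arc scores
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])     # 0 marks padding positions
flip_mask = (mask == 0)
_arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))
assert not torch.isinf(arc_pred).any()                # the original scores are unchanged

arc_true = torch.randint(0, 4, (2, 4))
arc_true.data[:, 0].fill_(-1)                         # ROOT column is then skipped via ignore_index=-1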
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import StepLR | |||
from fastNLP import Tester | |||
from fastNLP import GradientClipCallback, LRScheduler | |||
import os | |||
from fastNLP import cache_results | |||
def set_random_seed(random_seed=666): | |||
import random, numpy, torch | |||
@@ -39,43 +40,42 @@ label_mlp_size = 100 | |||
batch_size = 32 | |||
update_every = 4 | |||
n_epochs = 100 | |||
data_folder = ''  # set this to the folder containing the data; it should hold the train, dev and test files | |||
vector_folder = ''  # pretrained vectors; the folder should contain three files: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt | |||
data_name = 'new_ctb7' | |||
#################################################### | |||
data_folder = f'/remote-home/hyan01/exps/JointCwsPosParser/data/{data_name}/output'  # folder containing the data; it should hold the train, dev and test files | |||
vector_folder = '/remote-home/hyan01/exps/CWS/pretrain/vectors'  # pretrained vectors; the folder should contain three files: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt | |||
set_random_seed(1234) | |||
device = 0 | |||
# @cache_results('caches/{}.pkl'.format(data_name)) | |||
# def get_data(): | |||
data = CTBxJointLoader().process(data_folder) | |||
char_labels_vocab = data.vocabs['char_labels'] | |||
pre_chars_vocab = data.vocabs['pre_chars'] | |||
pre_bigrams_vocab = data.vocabs['pre_bigrams'] | |||
pre_trigrams_vocab = data.vocabs['pre_trigrams'] | |||
chars_vocab = data.vocabs['chars'] | |||
bigrams_vocab = data.vocabs['bigrams'] | |||
trigrams_vocab = data.vocabs['trigrams'] | |||
pre_chars_embed = StaticEmbedding(pre_chars_vocab, | |||
model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'), | |||
init_method=uniform_init, normalize=False) | |||
pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data/pre_chars_embed.embedding.weight.data.std() | |||
pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab, | |||
model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'), | |||
init_method=uniform_init, normalize=False) | |||
pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data/pre_bigrams_embed.embedding.weight.data.std() | |||
pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab, | |||
model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'), | |||
init_method=uniform_init, normalize=False) | |||
pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data/pre_trigrams_embed.embedding.weight.data.std() | |||
# return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data | |||
# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data() | |||
@cache_results('caches/{}.pkl'.format(data_name)) | |||
def get_data(): | |||
data = CTBxJointLoader().process(data_folder) | |||
char_labels_vocab = data.vocabs['char_labels'] | |||
pre_chars_vocab = data.vocabs['pre_chars'] | |||
pre_bigrams_vocab = data.vocabs['pre_bigrams'] | |||
pre_trigrams_vocab = data.vocabs['pre_trigrams'] | |||
chars_vocab = data.vocabs['chars'] | |||
bigrams_vocab = data.vocabs['bigrams'] | |||
trigrams_vocab = data.vocabs['trigrams'] | |||
pre_chars_embed = StaticEmbedding(pre_chars_vocab, | |||
model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'), | |||
init_method=uniform_init, normalize=False) | |||
pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data / pre_chars_embed.embedding.weight.data.std() | |||
pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab, | |||
model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'), | |||
init_method=uniform_init, normalize=False) | |||
pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data / pre_bigrams_embed.embedding.weight.data.std() | |||
pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab, | |||
model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'), | |||
init_method=uniform_init, normalize=False) | |||
pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data / pre_trigrams_embed.embedding.weight.data.std() | |||
return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data | |||
chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data() | |||
print(data) | |||
model = CharParser(char_vocab_size=len(chars_vocab), | |||
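The data preparation above is now wrapped in @cache_results, so the vocabularies and pretrained embeddings are built once and reloaded from a pickle on later runs. A minimal sketch of the pattern; the cache path and the returned payload are made up:

from fastNLP import cache_results

@cache_results('caches/demo.pkl')
def prepare():
    # stands in for the expensive CTBxJointLoader / StaticEmbedding setup above
    return {'n_train': 3}

result = prepare()   # computed and written to the cache on the first call, loaded from it afterwards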
@@ -104,11 +104,24 @@ optimizer = optim.Adam([param for param in model.parameters() if param.requires_ | |||
sampler = BucketSampler(seq_len_field_name='seq_lens') | |||
callbacks = [] | |||
from fastNLP.core.callback import Callback | |||
from torch.optim.lr_scheduler import LambdaLR | |||
class SchedulerCallback(Callback): | |||
def __init__(self, scheduler): | |||
super().__init__() | |||
self.scheduler = scheduler | |||
def on_backward_end(self): | |||
if self.step % self.update_every==0: | |||
self.scheduler.step() | |||
scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000)) | |||
# scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000)) | |||
scheduler = StepLR(optimizer, step_size=18, gamma=0.75) | |||
# optim_callback = OptimizerCallback(optimizer, scheduler, update_every) | |||
# scheduler = StepLR(optimizer, step_size=18, gamma=0.75) | |||
scheduler_callback = SchedulerCallback(scheduler) | |||
# callbacks.append(optim_callback) | |||
scheduler_callback = LRScheduler(scheduler) | |||
# scheduler_callback = LRScheduler(scheduler) | |||
callbacks.append(scheduler_callback) | |||
callbacks.append(GradientClipCallback(clip_type='value', clip_value=5)) | |||
@@ -119,6 +132,6 @@ callbacks.append(dev_callback) | |||
trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3, | |||
validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer, | |||
check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True, | |||
check_code_level=0, metric_key='u_f1', sampler=sampler, num_workers=2, use_tqdm=True, | |||
device=device, callbacks=callbacks, update_every=update_every) | |||
trainer.train() |
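For reference, the LambdaLR schedule wired in through SchedulerCallback above multiplies the learning rate by 0.75 every 5000 scheduler steps, and the callback advances the scheduler only once per update_every backward passes. A standalone sketch of the schedule itself; the parameter and learning rate are dummies:

import torch
from torch.optim.lr_scheduler import LambdaLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=2e-3)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 0.75 ** (step // 5000))
for _ in range(5000):
    scheduler.step()
# optimizer.param_groups[0]['lr'] is now 2e-3 * 0.75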
@@ -51,7 +51,7 @@ class ChineseNERLoader(DataSetLoader): | |||
:param paths: | |||
:param bool, bigrams: whether to generate the bigram feature, [a, b, c, d] -> [ab, bc, cd, d<eos>] | |||
:param bool, trigrams: whether to generate the trigram feature, [a, b, c, d] -> [abc, bcd, cd<eos>, d<eos><eos>] | |||
:return: DataBundle | |||
:return: ~fastNLP.io.DataBundle | |||
containing the following fields: | |||
raw_chars: List[str] | |||
chars: List[int] | |||
@@ -11,6 +11,13 @@ LSTM+self_attention: paper [A Structured Self-attentive Sentence Embedding] | |||
AWD-LSTM: paper [Regularizing and Optimizing LSTM Language Models](https://arxiv.org/pdf/1708.02182.pdf) | |||
# Dataset sources | |||
IMDB:http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz | |||
SST-2:https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8 | |||
SST:https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip | |||
yelp_full:https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M | |||
yelp_polarity:https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M | |||
# Datasets and reproduced results | |||
Results reproduced with fastNLP vs. results reported in the papers (the number before the / is the fastNLP implementation, the number after it is the paper's; - means the paper reports no result on that dataset) | |||
@@ -203,7 +203,7 @@ callbacks.append( | |||
def train(model,datainfo,loss,metrics,optimizer,num_epochs=100): | |||
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),batch_size=ops.batch_size, | |||
metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=[0,1,2], check_code_level=-1, | |||
n_epochs=num_epochs) | |||
n_epochs=num_epochs,callbacks=callbacks) | |||
print(trainer.train()) | |||
@@ -1,4 +1,5 @@ | |||
import os | |||
import sys | |||
import unittest | |||
from fastNLP import DataSet | |||
@@ -79,6 +80,16 @@ class TestDataSetMethods(unittest.TestCase): | |||
self.assertFalse("x" in dd.field_arrays) | |||
self.assertTrue("y" in dd.field_arrays) | |||
def test_delete_instance(self): | |||
dd = DataSet() | |||
old_length = 2 | |||
dd.add_field("x", [[1, 2, 3]] * old_length) | |||
dd.add_field("y", [[1, 2, 3, 4]] * old_length) | |||
dd.delete_instance(0) | |||
self.assertEqual(len(dd), old_length-1) | |||
dd.delete_instance(0) | |||
self.assertEqual(len(dd), old_length-2) | |||
def test_getitem(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
ins_1, ins_0 = ds[0], ds[1] | |||
@@ -0,0 +1,167 @@ | |||
import unittest | |||
import numpy as np | |||
import torch.cuda | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import CrossEntropyLoss, BCELoss | |||
from fastNLP import SGD | |||
from fastNLP.core.dist_trainer import DistTrainer, get_local_rank | |||
from fastNLP.models.base_model import NaiveClassifier | |||
import shutil | |||
import os | |||
import subprocess | |||
from argparse import ArgumentParser | |||
from fastNLP.core.callback import EchoCallback | |||
from fastNLP import AccuracyMetric | |||
def prepare_fake_dataset(): | |||
mean = np.array([-3, -3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
mean = np.array([3, 3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=0) for item in class_A] + | |||
[Instance(x=[float(item[0]), float(item[1])], y=1) for item in class_B]) | |||
return data_set | |||
def prepare_fake_dataset2(*args, size=100): | |||
ys = np.random.randint(4, size=size, dtype=np.int64) | |||
data = {'y': ys} | |||
for arg in args: | |||
data[arg] = np.random.randn(size, 5) | |||
return DataSet(data=data) | |||
def set_rng_seed(seed): | |||
np.random.seed(seed) | |||
def prepare_env(): | |||
def prepare_fake_dataset(): | |||
mean = np.array([-3, -3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
mean = np.array([3, 3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||
return data_set | |||
data_set = prepare_fake_dataset() | |||
data_set.set_input("x") | |||
data_set.set_target("y") | |||
model = NaiveClassifier(2, 1) | |||
return data_set, model | |||
class TestDistTrainer(unittest.TestCase): | |||
save_path = './save_cp' | |||
def run1(self): | |||
# test distributed training | |||
print('local rank', get_local_rank()) | |||
set_rng_seed(100) | |||
data_set = prepare_fake_dataset() | |||
data_set.set_input("x", flag=True) | |||
data_set.set_target("y", flag=True) | |||
model = NaiveClassifier(2, 2) | |||
trainer = DistTrainer( | |||
model=model, train_data=data_set, optimizer=SGD(lr=0.1), | |||
loss=CrossEntropyLoss(pred="predict", target="y"), | |||
batch_size_per_gpu=8, n_epochs=3, print_every=50, save_path=self.save_path, | |||
) | |||
trainer.train() | |||
""" | |||
# should run through without errors | |||
""" | |||
if trainer.is_master and os.path.exists(self.save_path): | |||
shutil.rmtree(self.save_path) | |||
def run2(self): | |||
# test fp16 with distributed training | |||
print('local rank', get_local_rank()) | |||
set_rng_seed(100) | |||
data_set = prepare_fake_dataset() | |||
data_set.set_input("x", flag=True) | |||
data_set.set_target("y", flag=True) | |||
model = NaiveClassifier(2, 2) | |||
trainer = DistTrainer( | |||
model=model, train_data=data_set, optimizer=SGD(lr=0.1), | |||
loss=CrossEntropyLoss(pred="predict", target="y"), | |||
batch_size_per_gpu=8, n_epochs=3, print_every=50, save_path=self.save_path, | |||
fp16='O1' | |||
) | |||
trainer.train() | |||
""" | |||
# should run through without errors | |||
""" | |||
if trainer.is_master and os.path.exists(self.save_path): | |||
shutil.rmtree(self.save_path) | |||
def run3(self): | |||
set_rng_seed(100) | |||
data_set, model = prepare_env() | |||
trainer = DistTrainer( | |||
data_set, model, optimizer=None, | |||
loss=BCELoss(pred="predict", target="y"), | |||
n_epochs=3, print_every=50, | |||
callbacks_all=[EchoCallback('callbacks_all')], | |||
callbacks_master=[EchoCallback('callbacks_master')] | |||
) | |||
trainer.train() | |||
def run4(self): | |||
set_rng_seed(100) | |||
data_set, model = prepare_env() | |||
train_set, dev_set = data_set.split(0.3) | |||
model = NaiveClassifier(2, 1) | |||
trainer = DistTrainer( | |||
train_set, model, optimizer=SGD(lr=0.1), | |||
loss=BCELoss(pred="predict", target="y"), | |||
batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None, | |||
) | |||
trainer.train() | |||
""" | |||
# should run through without errors | |||
""" | |||
def run_dist(self, run_id): | |||
if torch.cuda.is_available(): | |||
ngpu = min(2, torch.cuda.device_count()) | |||
path = __file__ | |||
cmd = ['python', '-m', 'torch.distributed.launch', | |||
'--nproc_per_node', str(ngpu), path, '--test', str(run_id)] | |||
print(' '.join(cmd)) | |||
subprocess.check_call(cmd) | |||
def test_normal_run(self): | |||
self.run_dist(1) | |||
def no_test_fp16(self): | |||
self.run_dist(2) | |||
def test_callback(self): | |||
self.run_dist(3) | |||
def test_dev_data(self): | |||
self.run_dist(4) | |||
if __name__ == '__main__': | |||
runner = TestDistTrainer() | |||
parser = ArgumentParser() | |||
parser.add_argument('--test', type=int) | |||
args, _ = parser.parse_known_args() | |||
if args.test and hasattr(runner, 'run%s'%args.test): | |||
getattr(runner, 'run%s'%args.test)() |
@@ -170,22 +170,22 @@ class TestFieldArray(unittest.TestCase): | |||
def test_append(self): | |||
with self.assertRaises(Exception): | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True, use_1st_ins_infer_dim_type=False) | |||
fa.append(0) | |||
with self.assertRaises(Exception): | |||
fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True) | |||
fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True, use_1st_ins_infer_dim_type=False) | |||
fa.append([1, 2, 3, 4, 5]) | |||
with self.assertRaises(Exception): | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True, use_1st_ins_infer_dim_type=False) | |||
fa.append([]) | |||
with self.assertRaises(Exception): | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True, use_1st_ins_infer_dim_type=False) | |||
fa.append(["str", 0, 0, 0, 1.89]) | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True) | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True, use_1st_ins_infer_dim_type=False) | |||
fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) | |||
self.assertEqual(len(fa), 3) | |||
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) | |||
@@ -7,7 +7,7 @@ from fastNLP import AccuracyMetric | |||
from fastNLP.core.metrics import _pred_topk, _accuracy_topk | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from collections import Counter | |||
from fastNLP.core.metrics import SpanFPreRecMetric | |||
from fastNLP.core.metrics import SpanFPreRecMetric, ExtractiveQAMetric | |||
def _generate_tags(encoding_type, number_labels=4): | |||
@@ -347,3 +347,46 @@ class TestUsefulFunctions(unittest.TestCase): | |||
_ = _pred_topk(np.random.randint(0, 3, size=(10, 1))) | |||
# passing just means it runs through without errors | |||
class TestExtractiveQAMetric(unittest.TestCase): | |||
def test_cast_1(self): | |||
qa_prediction = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | |||
-0.3782, 0.8240], | |||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, -1.1563, | |||
-0.3562, -1.4116], | |||
[-1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, | |||
-2.0023, 0.0075], | |||
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, | |||
0.3832, -0.1540], | |||
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, | |||
-1.3508, -0.9513], | |||
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | |||
-0.0842, -0.4294]], | |||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||
-1.4138, -0.8853], | |||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||
-1.0726, 0.0364], | |||
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||
-0.8836, -0.9320], | |||
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||
-1.6857, 1.1571], | |||
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||
3.5837, 1.0184], | |||
[1.6495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||
-0.9025, 0.0864]]]) | |||
qa_prediction = qa_prediction.permute(1, 2, 0) | |||
pred1, pred2 = qa_prediction.split(1, dim=-1) | |||
pred1 = pred1.squeeze(-1) | |||
pred2 = pred2.squeeze(-1) | |||
target1 = torch.LongTensor([3, 0, 2, 4, 4, 0]) | |||
target2 = torch.LongTensor([4, 1, 6, 8, 7, 1]) | |||
metric = ExtractiveQAMetric() | |||
metric.evaluate(pred1, pred2, target1, target2) | |||
result = metric.get_metric() | |||
truth = {'EM': 62.5, 'f_1': 72.5, 'noAns-f_1': 50.0, 'noAns-EM': 50.0, 'hasAns-f_1': 95.0, 'hasAns-EM': 75.0} | |||
for k, v in truth.items(): | |||
self.assertTrue(k in result) | |||
self.assertEqual(v, result[k]) |
@@ -0,0 +1,12 @@ | |||
index promptID pairID genre sentence1_binary_parse sentence2_binary_parse sentence1_parse sentence2_parse sentence1 sentence2 label1 label2 label3 label4 label5 gold_label | |||
0 63735 63735n slate ( ( The ( new rights ) ) ( are ( nice enough ) ) ) ( Everyone ( really ( likes ( the ( newest benefits ) ) ) ) ) (ROOT (S (NP (DT The) (JJ new) (NNS rights)) (VP (VBP are) (ADJP (JJ nice) (RB enough))))) (ROOT (S (NP (NN Everyone)) (VP (ADVP (RB really)) (VBZ likes) (NP (DT the) (JJS newest) (NNS benefits))))) The new rights are nice enough Everyone really likes the newest benefits neutral entailment neutral neutral neutral neutral | |||
1 91383 91383c government ( ( This site ) ( ( includes ( ( ( ( a list ) ( of ( all ( award winners ) ) ) ) and ) ( ( a ( searchable database ) ) ( of ( Government ( Executive articles ) ) ) ) ) ) . ) ) ( ( ( The ( Government ( Executive articles ) ) ) ( housed ( on ( the website ) ) ) ) ( ( ( are not ) ( able ( to ( be searched ) ) ) ) . ) ) (ROOT (S (NP (DT This) (NN site)) (VP (VBZ includes) (NP (NP (NP (DT a) (NN list)) (PP (IN of) (NP (DT all) (NN award) (NNS winners)))) (CC and) (NP (NP (DT a) (JJ searchable) (NN database)) (PP (IN of) (NP (NNP Government) (NNP Executive) (NNS articles)))))) (. .))) (ROOT (S (NP (NP (DT The) (NNP Government) (NNP Executive) (NNS articles)) (VP (VBN housed) (PP (IN on) (NP (DT the) (NN website))))) (VP (VBP are) (RB not) (ADJP (JJ able) (S (VP (TO to) (VP (VB be) (ADJP (JJ searched))))))) (. .))) This site includes a list of all award winners and a searchable database of Government Executive articles. The Government Executive articles housed on the website are not able to be searched. contradiction contradiction contradiction contradiction contradiction contradiction | |||
2 755 755e telephone ( ( ( ( uh ( i ( ( do n't ) ( know ( ( i i ) ( have ( ( mixed emotions ) ( about ( him ( ( uh sometimes ) ( i ( like him ) ) ) ) ) ) ) ) ) ) ) ) but ) ( ( at ( the ( same times ) ) ) ( i ( love ( to ( see somebody ) ) ) ) ) ) ( beat him ) ) ( I ( ( ( ( ( ( like him ) ( for ( the ( most part ) ) ) ) , ) but ) ( ( would still ) ( enjoy ( seeing ( someone ( beat him ) ) ) ) ) ) . ) ) (ROOT (SINV (S (S (INTJ (UH uh)) (NP (FW i)) (VP (VBP do) (RB n't) (VP (VB know) (NP (NP (FW i) (FW i)) (SBAR (S (VP (VBP have) (VP (VBN mixed) (NP (NNS emotions)) (PP (IN about) (S (NP (PRP him)) (VP (VBG uh) (ADVP (RB sometimes)) (NP (NP (FW i)) (PP (IN like) (NP (PRP him))))))))))))))) (CC but) (S (PP (IN at) (NP (DT the) (JJ same) (NNS times))) (NP (FW i)) (VP (VBP love) (S (VP (TO to) (VP (VB see) (NP (NN somebody)))))))) (VP (VBD beat)) (NP (PRP him)))) (ROOT (S (NP (PRP I)) (VP (VP (VBP like) (NP (PRP him)) (PP (IN for) (NP (DT the) (JJS most) (NN part)))) (, ,) (CC but) (VP (MD would) (ADVP (RB still)) (VP (VB enjoy) (S (VP (VBG seeing) (S (NP (NN someone)) (VP (VB beat) (NP (PRP him))))))))) (. .))) uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him I like him for the most part, but would still enjoy seeing someone beat him. entailment entailment entailment entailment entailment entailment | |||
3 78013 78013c telephone ( yeah ( ( i i ) ( think ( ( my ( favorite restaurant ) ) ( ( is always ) ( been ( ( the ( one closest ) ) ( you ( ( know ( the closest ) ) ( ( as long ) ( as ( it ( 's ( it ( meets ( ( the ( minimum criteria ) ) ( you ( know ( of ( good food ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ( ( My ( favorite restaurants ) ) ( ( ( ( are always ) ( ( ( ( ( at least ) a ) hundred ) miles ) away ) ) ( from ( my house ) ) ) . ) ) (ROOT (S (VP (VB yeah) (NP (NP (FW i) (FW i)) (SBAR (S (VP (VBP think) (SBAR (S (NP (PRP$ my) (JJ favorite) (NN restaurant)) (VP (VBZ is) (ADVP (RB always)) (VP (VBN been) (NP (NP (DT the) (CD one) (JJS closest)) (SBAR (S (NP (PRP you)) (VP (VBP know) (NP (DT the) (JJS closest)) (ADVP (ADVP (RB as) (RB long)) (SBAR (IN as) (S (NP (PRP it)) (VP (VBZ 's) (SBAR (S (NP (PRP it)) (VP (VBZ meets) (NP (NP (DT the) (JJ minimum) (NNS criteria)) (SBAR (S (NP (PRP you)) (VP (VBP know) (PP (IN of) (NP (JJ good) (NN food))))))))))))))))))))))))))))) (ROOT (S (NP (PRP$ My) (JJ favorite) (NNS restaurants)) (VP (VBP are) (ADVP (RB always)) (ADVP (NP (QP (IN at) (JJS least) (DT a) (CD hundred)) (NNS miles)) (RB away)) (PP (IN from) (NP (PRP$ my) (NN house)))) (. .))) yeah i i think my favorite restaurant is always been the one closest you know the closest as long as it's it meets the minimum criteria you know of good food My favorite restaurants are always at least a hundred miles away from my house. contradiction contradiction contradiction contradiction contradiction contradiction | |||
4 96377 96377c telephone ( i ( ( do n't ) ( know ( um ( do ( you ( do ( ( a lot ) ( of camping ) ) ) ) ) ) ) ) ) ( I ( ( know exactly ) . ) ) (ROOT (S (NP (FW i)) (VP (VBP do) (RB n't) (VP (VB know) (SBAR (S (NP (NN um)) (VP (VBP do) (SBAR (S (NP (PRP you)) (VP (VBP do) (NP (NP (DT a) (NN lot)) (PP (IN of) (NP (NN camping)))))))))))))) (ROOT (S (NP (PRP I)) (VP (VBP know) (ADVP (RB exactly))) (. .))) i don't know um do you do a lot of camping I know exactly. contradiction contradiction contradiction contradiction contradiction contradiction | |||
5 139749 139749c telephone ( well ( that ( would ( be ( ( a help ) ( i ( wish ( they ( would ( do ( that ( ( ( here ( we ( have ( got ( so ( ( little ( landfill space ) ) ( left ( that ( we ( 're ( going ( to ( ( run out ) ( before ( ( the end ) ( of ( this decade ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) and ) ( it ( ( 's really ) ( going ( to be ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ( We ( ( have ( plenty ( of ( space ( in ( the landfill ) ) ) ) ) ) . ) ) (ROOT (FRAG (ADVP (RB well)) (SBAR (WHNP (WDT that)) (S (VP (MD would) (VP (VB be) (NP (NP (DT a) (NN help)) (SBAR (S (NP (FW i)) (VP (VBP wish) (SBAR (S (NP (PRP they)) (VP (MD would) (VP (VB do) (SBAR (IN that) (S (S (ADVP (RB here)) (NP (PRP we)) (VP (VBP have) (VP (VBN got) (SBAR (IN so) (S (NP (JJ little) (NN landfill) (NN space)) (VP (VBD left) (SBAR (IN that) (S (NP (PRP we)) (VP (VBP 're) (VP (VBG going) (S (VP (TO to) (VP (VB run) (PRT (RP out)) (PP (IN before) (NP (NP (DT the) (NN end)) (PP (IN of) (NP (DT this) (NN decade)))))))))))))))))) (CC and) (S (NP (PRP it)) (VP (VBZ 's) (ADVP (RB really)) (VP (VBG going) (S (VP (TO to) (VP (VB be))))))))))))))))))))))) (ROOT (S (NP (PRP We)) (VP (VBP have) (NP (NP (RB plenty)) (PP (IN of) (NP (NP (NN space)) (PP (IN in) (NP (DT the) (NN landfill))))))) (. .))) well that would be a help i wish they would do that here we have got so little landfill space left that we're going to run out before the end of this decade and it's really going to be We have plenty of space in the landfill. contradiction contradiction contradiction contradiction contradiction contradiction | |||
6 101415 101415c telephone ( yeah ( ( ( i know ) and ) ( i ( did ( that ( ( ( all ( through college ) ) and ) ( it ( worked too ) ) ) ) ) ) ) ) ( I ( ( ( did ( that all ) ) ( through college ) ) ( but ( it ( never worked ) ) ) ) ) (ROOT (S (VP (VB yeah) (S (S (NP (FW i)) (VP (VBP know))) (CC and) (S (NP (FW i)) (VP (VBD did) (SBAR (IN that) (S (S (NP (DT all)) (PP (IN through) (NP (NN college)))) (CC and) (S (NP (PRP it)) (VP (VBD worked) (ADVP (RB too)))))))))))) (ROOT (S (NP (PRP I)) (VP (VBD did) (ADVP (IN that) (DT all)) (PP (IN through) (NP (NN college))) (SBAR (CC but) (S (NP (PRP it)) (ADVP (RB never)) (VP (VBD worked))))))) yeah i know and i did that all through college and it worked too I did that all through college but it never worked contradiction contradiction contradiction contradiction contradiction contradiction | |||
7 93958 93958n travel ( ( ( ( ( Calcutta ( seems ( to ( be ( ( the ( only ( other ( production center ) ) ) ) ( ( having ( any pretensions ) ) ( to ( ( artistic creativity ) ( at all ) ) ) ) ) ) ) ) ) , ) but ) ( ironically ( you ( ( 're actually ) ( ( more ( likely ( to ( see ( ( the works ) ( of ( ( ( Satyajit Ray ) or ) ( ( Mrinal Sen ) ( shown ( in ( Europe ( or ( North America ) ) ) ) ) ) ) ) ) ) ) ) ) ( than ( in ( India itself ) ) ) ) ) ) ) ) . ) ( ( Most ( of ( ( Mrinal ( Sen 's ) ) work ) ) ) ( ( can ( be ( found ( in ( European collections ) ) ) ) ) . ) ) (ROOT (S (S (NP (NNP Calcutta)) (VP (VBZ seems) (S (VP (TO to) (VP (VB be) (NP (NP (DT the) (JJ only) (JJ other) (NN production) (NN center)) (VP (VBG having) (NP (DT any) (NNS pretensions)) (PP (TO to) (NP (NP (JJ artistic) (NN creativity)) (ADVP (IN at) (DT all))))))))))) (, ,) (CC but) (S (ADVP (RB ironically)) (NP (PRP you)) (VP (VBP 're) (ADVP (RB actually)) (ADJP (ADJP (RBR more) (JJ likely) (S (VP (TO to) (VP (VB see) (NP (NP (DT the) (NNS works)) (PP (IN of) (NP (NP (NNP Satyajit) (NNP Ray)) (CC or) (NP (NP (NNP Mrinal) (NNP Sen)) (VP (VBN shown) (PP (IN in) (NP (NNP Europe) (CC or) (NNP North) (NNP America)))))))))))) (ADVP (IN than) (PP (IN in) (S (VP (VBG India) (NP (PRP itself))))))))) (. .))) (ROOT (S (NP (NP (JJS Most)) (PP (IN of) (NP (NP (NNP Mrinal) (NNP Sen) (POS 's)) (NN work)))) (VP (MD can) (VP (VB be) (VP (VBN found) (PP (IN in) (NP (JJ European) (NNS collections)))))) (. .))) Calcutta seems to be the only other production center having any pretensions to artistic creativity at all, but ironically you're actually more likely to see the works of Satyajit Ray or Mrinal Sen shown in Europe or North America than in India itself. Most of Mrinal Sen's work can be found in European collections. neutral neutral entailment neutral neutral neutral | |||
8 12567 12567c slate ( ( If ( ( that investor ) ( were ( willing ( to ( pay ( extra ( for ( ( the security ) ( of ( limited downside ) ) ) ) ) ) ) ) ) ) ) ( , ( she ( ( could ( ( buy ( put options ) ) ( with ( ( a ( strike price ) ) ( of ( ( ( $ 98 ) , ) ( which ( would ( ( ( lock ( in ( ( her profit ) ( on ( ( the shares ) ( at ( $ 18 ) ) ) ) ) ) ) , ) ( less ( whatever ( ( the options ) cost ) ) ) ) ) ) ) ) ) ) ) ) . ) ) ) ) ( ( THe ( strike price ) ) ( ( could ( be ( $ 8 ) ) ) . ) ) (ROOT (S (SBAR (IN If) (S (NP (DT that) (NN investor)) (VP (VBD were) (ADJP (JJ willing) (S (VP (TO to) (VP (VB pay) (NP (NP (JJ extra)) (PP (IN for) (NP (NP (DT the) (NN security)) (PP (IN of) (NP (JJ limited) (NN downside))))))))))))) (, ,) (NP (PRP she)) (VP (MD could) (VP (VB buy) (NP (NN put) (NNS options)) (PP (IN with) (NP (NP (DT a) (NN strike) (NN price)) (PP (IN of) (NP (NP ($ $) (CD 98)) (, ,) (SBAR (WHNP (WDT which)) (S (VP (MD would) (VP (VB lock) (PP (IN in) (NP (NP (PRP$ her) (NN profit)) (PP (IN on) (NP (NP (DT the) (NNS shares)) (PP (IN at) (NP ($ $) (CD 18))))))) (, ,) (ADVP (ADVP (RBR less)) (SBAR (WHNP (WDT whatever)) (S (NP (DT the) (NNS options)) (VP (VBD cost))))))))))))))) (. .))) (ROOT (S (NP (NNP THe) (NN strike) (NN price)) (VP (MD could) (VP (VB be) (NP ($ $) (CD 8)))) (. .))) If that investor were willing to pay extra for the security of limited downside, she could buy put options with a strike price of $98, which would lock in her profit on the shares at $18, less whatever the options cost. THe strike price could be $8. contradiction contradiction contradiction contradiction contradiction contradiction | |||
9 117487 117487n slate ( ( 3 -RRB- ) ( ( Dare ( you ( ( ( rise ( to ( ( ( ( the occasion ) , ) ( like Raskolnikov ) ) , ) ) ) and ) ( reject ( ( the ( petty rules ) ) ( that ( govern ( lesser men ) ) ) ) ) ) ) ) ? ) ) ( ( ( Would you ) ( ( ( rise up ) and ) ( defeaat ( ( all ( evil lords ) ) ( in ( the town ) ) ) ) ) ) ? ) (ROOT (S (LST (LS 3) (-RRB- -RRB-)) (VP (VB Dare) (S (NP (PRP you)) (VP (VP (VB rise) (PP (TO to) (NP (NP (DT the) (NN occasion)) (, ,) (PP (IN like) (NP (NNP Raskolnikov))) (, ,)))) (CC and) (VP (VB reject) (NP (NP (DT the) (JJ petty) (NNS rules)) (SBAR (WHNP (WDT that)) (S (VP (VBP govern) (NP (JJR lesser) (NNS men)))))))))) (. ?))) (ROOT (SQ (MD Would) (NP (PRP you)) (VP (VP (VB rise) (PRT (RP up))) (CC and) (VP (VB defeaat) (NP (NP (DT all) (JJ evil) (NNS lords)) (PP (IN in) (NP (DT the) (NN town)))))) (. ?))) 3) Dare you rise to the occasion, like Raskolnikov, and reject the petty rules that govern lesser men? Would you rise up and defeaat all evil lords in the town? neutral neutral neutral neutral neutral neutral | |||
10 9616 9616c travel ( ( The ( ( most important ) directions ) ) ( ( ( are ( simply ( ( up and ) up ) ) ) ( ( ( ( ( ( ( ( leads eventually ) ( to ( the cathedral ) ) ) and ) ( fortress ( commanding ( the hilltop ) ) ) ) , ) and ) down ) ( inevitably ( ( leads ( to ( one ( of ( three gates ) ) ) ) ) ( through ( ( the wall ) ( to ( the ( new town ) ) ) ) ) ) ) ) ) . ) ) ( Go ( ( downwards ( to ( one ( of ( ( ( the gates ) , ) ( ( all ( of which ) ) ( will ( ( lead you ) ( into ( the cathedral ) ) ) ) ) ) ) ) ) ) . ) ) (ROOT (S (NP (DT The) (ADJP (RBS most) (JJ important)) (NNS directions)) (VP (VBP are) (PRN (ADVP (RB simply)) (ADVP (RB up) (CC and) (RB up))) (VP (VP (VBZ leads) (ADVP (RB eventually)) (PP (TO to) (NP (DT the) (NN cathedral)))) (CC and) (VP (VBZ fortress) (NP (JJ commanding) (DT the) (NN hilltop))) (, ,) (CC and) (ADVP (RB down)) (VP (ADVP (RB inevitably)) (VBZ leads) (PP (TO to) (NP (NP (CD one)) (PP (IN of) (NP (CD three) (NNS gates))))) (PP (IN through) (NP (NP (DT the) (NN wall)) (PP (TO to) (NP (DT the) (JJ new) (NN town)))))))) (. .))) (ROOT (S (NP (NNP Go)) (VP (VBZ downwards) (PP (TO to) (NP (NP (CD one)) (PP (IN of) (NP (NP (DT the) (NNS gates)) (, ,) (SBAR (WHNP (DT all) (WHPP (IN of) (WHNP (WDT which)))) (S (VP (MD will) (VP (VB lead) (NP (PRP you)) (PP (IN into) (NP (DT the) (NN cathedral)))))))))))) (. .))) The most important directions are simply up and up leads eventually to the cathedral and fortress commanding the hilltop, and down inevitably leads to one of three gates through the wall to the new town. Go downwards to one of the gates, all of which will lead you into the cathedral. contradiction contradiction entailment contradiction contradiction contradiction |
@@ -0,0 +1,21 @@ | |||
import unittest | |||
from fastNLP import Vocabulary | |||
from fastNLP.embeddings import ElmoEmbedding | |||
import torch | |||
import os | |||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||
class TestDownload(unittest.TestCase): | |||
def test_download_small(self): | |||
# import os | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='en-small') | |||
words = torch.LongTensor([[0, 1, 2]]) | |||
print(elmo_embed(words).size()) | |||
# first make sure all the weights can be loaded; then upload the weights; then verify they can be downloaded | |||
@@ -3,13 +3,110 @@ import unittest | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP import Vocabulary | |||
import torch | |||
import os | |||
class TestRandomSameEntry(unittest.TestCase): | |||
def test_same_vector(self): | |||
vocab = Vocabulary().add_word_lst(["The", "the", "THE"]) | |||
vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"]) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True) | |||
words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE"]]]) | |||
words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]]) | |||
words = embed(words) | |||
embed_0 = words[0, 0] | |||
for i in range(1, words.size(1)): | |||
for i in range(1, 3): | |||
assert torch.sum(embed_0==words[0, i]).eq(len(embed_0)) | |||
embed_0 = words[0, 3] | |||
for i in range(3, 5): | |||
assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0)) | |||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||
def test_same_vector2(self): | |||
vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"]) | |||
embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.6B.100d.txt', | |||
lower=True) | |||
words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]]) | |||
words = embed(words) | |||
embed_0 = words[0, 0] | |||
for i in range(1, 3): | |||
assert torch.sum(embed_0==words[0, i]).eq(len(embed_0)) | |||
embed_0 = words[0, 3] | |||
for i in range(3, 5): | |||
assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0)) | |||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||
def test_same_vector3(self): | |||
# verify the lower option | |||
word_lst = ["The", "the"] | |||
no_create_word_lst = ['of', 'Of', 'With', 'with'] | |||
vocab = Vocabulary().add_word_lst(word_lst) | |||
vocab.add_word_lst(no_create_word_lst, no_create_entry=True) | |||
embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', | |||
lower=True) | |||
words = torch.LongTensor([[vocab.to_index(word) for word in word_lst+no_create_word_lst]]) | |||
words = embed(words) | |||
lowered_word_lst = [word.lower() for word in word_lst] | |||
lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst] | |||
lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst) | |||
lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True) | |||
lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', | |||
lower=False) | |||
lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst+lowered_no_create_word_lst]]) | |||
lowered_words = lowered_embed(lowered_words) | |||
all_words = word_lst + no_create_word_lst | |||
for idx, (word_i, word_j) in enumerate(zip(words[0], lowered_words[0])): | |||
with self.subTest(idx=idx, word=all_words[idx]): | |||
assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size) | |||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||
def test_same_vector4(self): | |||
# verify the lower option when min_freq is set | |||
word_lst = ["The", "the", "the", "The", "a", "A"] | |||
no_create_word_lst = ['of', 'Of', "Of", "of", 'With', 'with'] | |||
all_words = word_lst[:-2] + no_create_word_lst[:-2] | |||
vocab = Vocabulary(min_freq=2).add_word_lst(word_lst) | |||
vocab.add_word_lst(no_create_word_lst, no_create_entry=True) | |||
embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', | |||
lower=True) | |||
words = torch.LongTensor([[vocab.to_index(word) for word in all_words]]) | |||
words = embed(words) | |||
lowered_word_lst = [word.lower() for word in word_lst] | |||
lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst] | |||
lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst) | |||
lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True) | |||
lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', | |||
lower=False) | |||
lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]]) | |||
lowered_words = lowered_embed(lowered_words) | |||
for idx in range(len(all_words)): | |||
word_i, word_j = words[0, idx], lowered_words[0, idx] | |||
with self.subTest(idx=idx, word=all_words[idx]): | |||
assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size) | |||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||
def test_same_vector5(self): | |||
# check that word vectors stay identical after applying min_freq | |||
word_lst = ["they", "the", "they", "the", 'he', 'he', "a", "A"] | |||
no_create_word_lst = ['of', "of", "she", "she", 'With', 'with'] | |||
all_words = word_lst[:-2] + no_create_word_lst[:-2] | |||
vocab = Vocabulary().add_word_lst(word_lst) | |||
vocab.add_word_lst(no_create_word_lst, no_create_entry=True) | |||
embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', | |||
lower=False, min_freq=2) | |||
words = torch.LongTensor([[vocab.to_index(word) for word in all_words]]) | |||
words = embed(words) | |||
min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst) | |||
min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True) | |||
min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', | |||
lower=False) | |||
min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word.lower()) for word in all_words]]) | |||
min_freq_words = min_freq_embed(min_freq_words) | |||
for idx in range(len(all_words)): | |||
word_i, word_j = words[0, idx], min_freq_words[0, idx] | |||
with self.subTest(idx=idx, word=all_words[idx]): | |||
assert torch.sum(word_i == word_j).eq(min_freq_embed.embed_size) |
@@ -0,0 +1,19 @@ | |||
import unittest | |||
from fastNLP.io.loader.classification import YelpFullLoader | |||
from fastNLP.io.loader.classification import YelpPolarityLoader | |||
from fastNLP.io.loader.classification import IMDBLoader | |||
from fastNLP.io.loader.classification import SST2Loader | |||
from fastNLP.io.loader.classification import SSTLoader | |||
import os | |||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||
class TestDownload(unittest.TestCase): | |||
def test_download(self): | |||
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]: | |||
loader().download() | |||
def test_load(self): | |||
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]: | |||
data_bundle = loader().load() | |||
print(data_bundle) |
@@ -0,0 +1,22 @@ | |||
import unittest | |||
from fastNLP.io.loader.matching import RTELoader | |||
from fastNLP.io.loader.matching import QNLILoader | |||
from fastNLP.io.loader.matching import SNLILoader | |||
from fastNLP.io.loader.matching import QuoraLoader | |||
from fastNLP.io.loader.matching import MNLILoader | |||
import os | |||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||
class TestDownload(unittest.TestCase): | |||
def test_download(self): | |||
for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]: | |||
loader().download() | |||
with self.assertRaises(Exception): | |||
QuoraLoader().load() | |||
def test_load(self): | |||
for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]: | |||
data_bundle = loader().load() | |||
print(data_bundle) | |||
@@ -0,0 +1,13 @@ | |||
import unittest | |||
import os | |||
from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe | |||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||
class TestPipe(unittest.TestCase): | |||
def test_process_from_file(self): | |||
for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: | |||
with self.subTest(pipe=pipe): | |||
print(pipe) | |||
data_bundle = pipe(tokenizer='raw').process_from_file() | |||
print(data_bundle) |
@@ -0,0 +1,26 @@ | |||
import unittest | |||
import os | |||
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe | |||
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe | |||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||
class TestPipe(unittest.TestCase): | |||
def test_process_from_file(self): | |||
for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]: | |||
with self.subTest(pipe=pipe): | |||
print(pipe) | |||
data_bundle = pipe(tokenizer='raw').process_from_file() | |||
print(data_bundle) | |||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||
class TestBertPipe(unittest.TestCase): | |||
def test_process_from_file(self): | |||
for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]: | |||
with self.subTest(pipe=pipe): | |||
print(pipe) | |||
data_bundle = pipe(tokenizer='raw').process_from_file() | |||
print(data_bundle) |
@@ -0,0 +1,15 @@ | |||
import unittest | |||
from fastNLP.core.const import Const | |||
from fastNLP.io.data_loader import MNLILoader | |||
class TestDataLoader(unittest.TestCase): | |||
def test_mnli_loader(self): | |||
ds = MNLILoader().process('test/data_for_tests/sample_mnli.tsv', | |||
to_lower=True, get_index=True, seq_len_type='mask') | |||
self.assertTrue('train' in ds.datasets) | |||
self.assertTrue(len(ds.datasets) == 1) | |||
self.assertTrue(len(ds.datasets['train']) == 11) | |||
self.assertTrue(isinstance(ds.datasets['train'][0][Const.INPUT_LENS(0)], list)) |
@@ -0,0 +1,9 @@ | |||
import unittest | |||
from .model_runner import * | |||
from fastNLP.models.snli import ESIM | |||
class TestSNLIModel(unittest.TestCase): | |||
def test_snli(self): | |||
model = ESIM((VOCAB_SIZE, 10), num_labels=NUM_CLS, dropout_rate=0) | |||
RUNNER.run_model_with_task(NLI, model) |