diff --git a/.travis.yml b/.travis.yml index 210d158a..0d63417a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,9 @@ language: python python: - "3.6" + +env: + - TRAVIS=1 # command to install dependencies install: - pip install --quiet -r requirements.txt diff --git a/docs/Makefile b/docs/Makefile index 2b4de2d8..b9f1cf95 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -14,7 +14,7 @@ help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) apidoc: - $(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) + $(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) && python3 format.py server: cd build/html && python -m http.server diff --git a/docs/format.py b/docs/format.py new file mode 100644 index 00000000..7cc341c2 --- /dev/null +++ b/docs/format.py @@ -0,0 +1,65 @@ +import os + + +def shorten(file, to_delete, cut=False): + if file.endswith("index.rst") or file.endswith("conf.py"): + return + res = [] + with open(file, "r") as fin: + lines = fin.readlines() + for line in lines: + if cut and line.rstrip() == "Submodules": + break + else: + res.append(line.rstrip()) + for i, line in enumerate(res): + if line.endswith(" package"): + res[i] = res[i][:-len(" package")] + res[i + 1] = res[i + 1][:-len(" package")] + elif line.endswith(" module"): + res[i] = res[i][:-len(" module")] + res[i + 1] = res[i + 1][:-len(" module")] + else: + for name in to_delete: + if line.endswith(name): + res[i] = "del" + + with open(file, "w") as fout: + for line in res: + if line != "del": + print(line, file=fout) + + +def clear(path='./source/'): + files = os.listdir(path) + to_delete = [ + "fastNLP.core.dist_trainer", + "fastNLP.core.predictor", + + "fastNLP.io.file_reader", + "fastNLP.io.config_io", + + "fastNLP.embeddings.contextual_embedding", + + "fastNLP.modules.dropout", + "fastNLP.models.base_model", + "fastNLP.models.bert", + "fastNLP.models.enas_utils", + "fastNLP.models.enas_controller", + "fastNLP.models.enas_model", + "fastNLP.models.enas_trainer", + ] + for file in files: + if not os.path.isdir(path + file): + res = file.split('.') + if len(res) > 4: + to_delete.append(file[:-4]) + elif len(res) == 4: + shorten(path + file, to_delete, True) + else: + shorten(path + file, to_delete) + for file in to_delete: + os.remove(path + file + ".rst") + + +clear() diff --git a/docs/source/fastNLP.core.rst b/docs/source/fastNLP.core.rst index cacc6622..08d161b7 100644 --- a/docs/source/fastNLP.core.rst +++ b/docs/source/fastNLP.core.rst @@ -6,11 +6,10 @@ fastNLP.core :undoc-members: :show-inheritance: -子模块 +Submodules ---------- .. toctree:: - :maxdepth: 1 fastNLP.core.batch fastNLP.core.callback diff --git a/docs/source/fastNLP.embeddings.rst b/docs/source/fastNLP.embeddings.rst index 6b168906..6872e91d 100644 --- a/docs/source/fastNLP.embeddings.rst +++ b/docs/source/fastNLP.embeddings.rst @@ -6,11 +6,10 @@ fastNLP.embeddings :undoc-members: :show-inheritance: -子模块 +Submodules ---------- .. toctree:: - :maxdepth: 1 fastNLP.embeddings.bert_embedding fastNLP.embeddings.char_embedding diff --git a/docs/source/fastNLP.io.data_loader.rst b/docs/source/fastNLP.io.data_loader.rst index 8f990102..0b4f5d0b 100644 --- a/docs/source/fastNLP.io.data_loader.rst +++ b/docs/source/fastNLP.io.data_loader.rst @@ -1,7 +1,8 @@ fastNLP.io.data\_loader -========================== +======================= .. automodule:: fastNLP.io.data_loader :members: :undoc-members: - :show-inheritance: \ No newline at end of file + :show-inheritance: + diff --git a/docs/source/fastNLP.io.file_utils.rst b/docs/source/fastNLP.io.file_utils.rst new file mode 100644 index 00000000..944550d7 --- /dev/null +++ b/docs/source/fastNLP.io.file_utils.rst @@ -0,0 +1,7 @@ +fastNLP.io.file\_utils +====================== + +.. automodule:: fastNLP.io.file_utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.io.loader.rst b/docs/source/fastNLP.io.loader.rst new file mode 100644 index 00000000..bbdc1d7a --- /dev/null +++ b/docs/source/fastNLP.io.loader.rst @@ -0,0 +1,8 @@ +fastNLP.io.loader +================= + +.. automodule:: fastNLP.io.loader + :members: + :undoc-members: + :show-inheritance: + diff --git a/docs/source/fastNLP.io.pipe.rst b/docs/source/fastNLP.io.pipe.rst new file mode 100644 index 00000000..bf126585 --- /dev/null +++ b/docs/source/fastNLP.io.pipe.rst @@ -0,0 +1,8 @@ +fastNLP.io.pipe +=============== + +.. automodule:: fastNLP.io.pipe + :members: + :undoc-members: + :show-inheritance: + diff --git a/docs/source/fastNLP.io.rst b/docs/source/fastNLP.io.rst index a97ed67d..0a006709 100644 --- a/docs/source/fastNLP.io.rst +++ b/docs/source/fastNLP.io.rst @@ -6,14 +6,23 @@ fastNLP.io :undoc-members: :show-inheritance: -子模块 +Subpackages +----------- + +.. toctree:: + + fastNLP.io.data_loader + fastNLP.io.loader + fastNLP.io.pipe + +Submodules ---------- .. toctree:: - :maxdepth: 1 fastNLP.io.base_loader - fastNLP.io.embed_loader fastNLP.io.dataset_loader - fastNLP.io.data_loader + fastNLP.io.embed_loader + fastNLP.io.file_utils fastNLP.io.model_io + fastNLP.io.utils diff --git a/docs/source/fastNLP.io.utils.rst b/docs/source/fastNLP.io.utils.rst new file mode 100644 index 00000000..0b3f3938 --- /dev/null +++ b/docs/source/fastNLP.io.utils.rst @@ -0,0 +1,7 @@ +fastNLP.io.utils +================ + +.. automodule:: fastNLP.io.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.models.rst b/docs/source/fastNLP.models.rst index 2ea546e2..36875b85 100644 --- a/docs/source/fastNLP.models.rst +++ b/docs/source/fastNLP.models.rst @@ -6,11 +6,10 @@ fastNLP.models :undoc-members: :show-inheritance: -子模块 +Submodules ---------- .. toctree:: - :maxdepth: 1 fastNLP.models.biaffine_parser fastNLP.models.cnn_text_classification diff --git a/docs/source/fastNLP.modules.encoder.rst b/docs/source/fastNLP.modules.encoder.rst index 0562f12d..e60f9fa4 100644 --- a/docs/source/fastNLP.modules.encoder.rst +++ b/docs/source/fastNLP.modules.encoder.rst @@ -5,3 +5,4 @@ fastNLP.modules.encoder :members: :undoc-members: :show-inheritance: + diff --git a/docs/source/fastNLP.modules.rst b/docs/source/fastNLP.modules.rst index 646ef2d3..06494b53 100644 --- a/docs/source/fastNLP.modules.rst +++ b/docs/source/fastNLP.modules.rst @@ -6,12 +6,17 @@ fastNLP.modules :undoc-members: :show-inheritance: -子模块 +Subpackages ----------- .. toctree:: - :titlesonly: - :maxdepth: 1 fastNLP.modules.decoder - fastNLP.modules.encoder \ No newline at end of file + fastNLP.modules.encoder + +Submodules +---------- + +.. toctree:: + + fastNLP.modules.utils diff --git a/docs/source/fastNLP.modules.utils.rst b/docs/source/fastNLP.modules.utils.rst new file mode 100644 index 00000000..c0219435 --- /dev/null +++ b/docs/source/fastNLP.modules.utils.rst @@ -0,0 +1,7 @@ +fastNLP.modules.utils +===================== + +.. automodule:: fastNLP.modules.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.rst b/docs/source/fastNLP.rst index 0057a184..e3ba429d 100644 --- a/docs/source/fastNLP.rst +++ b/docs/source/fastNLP.rst @@ -1,16 +1,15 @@ -API 文档 -=============== +fastNLP +======= .. automodule:: fastNLP :members: :undoc-members: :show-inheritance: -内部模块 +Subpackages ----------- .. toctree:: - :maxdepth: 1 fastNLP.core fastNLP.embeddings diff --git a/docs/source/modules.rst b/docs/source/modules.rst index 9ca3c7f3..e9a92cb7 100644 --- a/docs/source/modules.rst +++ b/docs/source/modules.rst @@ -2,7 +2,6 @@ fastNLP ======= .. toctree:: - :titlesonly: :maxdepth: 4 fastNLP diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index b246c6a0..eeabda35 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -14,6 +14,7 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa """ from .batch import DataSetIter, BatchIter, TorchLoaderIter from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC +from .callback import EvaluateCallback, FitlogCallback, SaveModelCallback from .const import Const from .dataset import DataSet from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder @@ -24,5 +25,5 @@ from .optimizer import Optimizer, SGD, Adam from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler from .tester import Tester from .trainer import Trainer -from .utils import cache_results, seq_len_to_mask +from .utils import cache_results, seq_len_to_mask, get_seq_len from .vocabulary import Vocabulary diff --git a/fastNLP/core/_parallel_utils.py b/fastNLP/core/_parallel_utils.py index 4a7757d3..6b24d9f9 100644 --- a/fastNLP/core/_parallel_utils.py +++ b/fastNLP/core/_parallel_utils.py @@ -1,6 +1,7 @@ import threading import torch +from torch import nn from torch.nn.parallel.parallel_apply import get_a_var from torch.nn.parallel.scatter_gather import scatter_kwargs, gather @@ -86,3 +87,16 @@ def _data_parallel_wrapper(func_name, device_ids, output_device): outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)]) return gather(outputs, output_device) return wrapper + + +def _model_contains_inner_module(model): + """ + + :param nn.Module model: 模型文件,判断是否内部包含model.module, 多用于check模型是否是nn.DataParallel, + nn.parallel.DistributedDataParallel。主要是在做形参匹配的时候需要使用最内部的model的function。 + :return: bool + """ + if isinstance(model, nn.Module): + if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)): + return True + return False \ No newline at end of file diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 64c5f48e..8d97783e 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -48,6 +48,11 @@ class DataSetGetter: return len(self.dataset) def collate_fn(self, batch: list): + """ + + :param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]] + :return: + """ # TODO 支持在DataSet中定义collate_fn,因为有时候可能需要不同的field之间融合,比如BERT的场景 batch_x = {n:[] for n in self.inputs.keys()} batch_y = {n:[] for n in self.targets.keys()} @@ -93,9 +98,13 @@ class DataSetGetter: class SamplerAdapter(torch.utils.data.Sampler): def __init__(self, sampler, dataset): + super().__init__(dataset) self.sampler = sampler self.dataset = dataset + def __len__(self): + return len(self.dataset) + def __iter__(self): return iter(self.sampler(self.dataset)) @@ -165,15 +174,19 @@ class DataSetIter(BatchIter): timeout=0, worker_init_fn=None): super().__init__() assert isinstance(dataset, DataSet) - sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) + if not isinstance(sampler, torch.utils.data.Sampler): + self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) + else: + self.sampler = sampler dataset = DataSetGetter(dataset, as_numpy) collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None self.dataiter = torch.utils.data.DataLoader( - dataset=dataset, batch_size=batch_size, sampler=sampler, + dataset=dataset, batch_size=batch_size, sampler=self.sampler, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn) - self.num_batches = self.get_num_batches(len(dataset), batch_size, drop_last) + # 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了 + self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last) self.batch_size = batch_size @@ -182,7 +195,7 @@ class TorchLoaderIter(BatchIter): super().__init__() assert isinstance(dataset, torch.utils.data.DataLoader) self.dataiter = dataset - self.num_batches = self.get_num_batches(len(dataset), dataset.batch_size, dataset.drop_last) + self.num_batches = self.get_num_batches(len(dataset.sampler), dataset.batch_size, dataset.drop_last) self.batch_size = dataset.batch_size @@ -200,6 +213,13 @@ class OnlineDataIter(BatchIter): def _to_tensor(batch, field_dtype): + """ + + :param batch: np.array() + :param field_dtype: 数据类型 + :return: batch, flag. 如果传入的数据支持转为tensor,返回的batch就是tensor,且flag为True;如果传入的数据不支持转为tensor, + 返回的batch就是原来的数据,且flag为False + """ try: if field_dtype is not None and isinstance(field_dtype, type)\ and issubclass(field_dtype, Number) \ diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 6f855397..633c6f45 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -57,6 +57,7 @@ __all__ = [ "FitlogCallback", "LRScheduler", "ControlC", + "EvaluateCallback", "CallbackException", "EarlyStopError" @@ -79,6 +80,7 @@ except: from ..io.model_io import ModelSaver, ModelLoader from .dataset import DataSet from .tester import Tester +import logging try: import fitlog @@ -100,7 +102,8 @@ class Callback(object): def __init__(self): super(Callback, self).__init__() self._trainer = None # 在Trainer内部被重新赋值 - + self._disabled = False + @property def trainer(self): """ @@ -158,7 +161,19 @@ class Callback(object): def batch_per_epoch(self): """每个epoch一共有多少个batch,只有在on_epoch_begin之后才能调用该属性。""" return self._trainer.batch_per_epoch - + + @property + def is_master(self): + return self._trainer.is_master() + + @property + def disabled(self): + return self._disabled + + @property + def logger(self): + return getattr(self._trainer, 'logger', logging) + def on_train_begin(self): """ 在Train过程开始之前调用。 @@ -250,6 +265,14 @@ class Callback(object): :return: """ pass + + def on_validation(self): + """ + 如果Trainer中设置了验证,则会在每次需要验证时调用该函数 + + :return: + """ + pass def on_epoch_end(self): """ @@ -281,6 +304,8 @@ def _transfer(func): def wrapper(manager, *arg): returns = [] for callback in manager.callbacks: + if callback.disabled: + continue returns.append(getattr(callback, func.__name__)(*arg)) return returns @@ -297,22 +322,28 @@ class CallbackManager(Callback): """ super(CallbackManager, self).__init__() # set attribute of trainer environment - + self._env = env self.callbacks = [] - if callbacks is not None: - if isinstance(callbacks, list): - if all([isinstance(cb, Callback) for cb in callbacks]) is True: - self.callbacks.extend(callbacks) - else: - obj = [not isinstance(cb, Callback) for cb in callbacks][0] - raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}") + if callbacks: + self.callbacks = self.prepare_callbacks(callbacks) + + def prepare_callbacks(self, callbacks): + if not callbacks: + return [] + if isinstance(callbacks, list): + if all([isinstance(cb, Callback) for cb in callbacks]) is True: + pass else: - raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") - - for env_name, env_val in env.items(): - for callback in self.callbacks: + obj = [not isinstance(cb, Callback) for cb in callbacks][0] + raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}") + else: + raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") + + for env_name, env_val in self._env.items(): + for callback in callbacks: setattr(callback, '_' + env_name, env_val) # Callback.trainer - + return callbacks + @_transfer def on_train_begin(self): pass @@ -352,6 +383,10 @@ class CallbackManager(Callback): @_transfer def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): pass + + @_transfer + def on_validation(self): + pass @_transfer def on_epoch_end(self): @@ -366,6 +401,25 @@ class CallbackManager(Callback): pass +class DistCallbackManager(CallbackManager): + def __init__(self, env, callbacks_all=None, callbacks_master=None): + super(DistCallbackManager, self).__init__(env) + assert 'trainer' in env + is_master = env['trainer'].is_master + self.patch_callback(callbacks_master, disabled=not is_master) + self.callbacks_all = self.prepare_callbacks(callbacks_all) + self.callbacks_master = self.prepare_callbacks(callbacks_master) + self.callbacks = self.callbacks_all + self.callbacks_master + + def patch_callback(self, callbacks, disabled): + if not callbacks: + return + if not isinstance(callbacks, (list, tuple)): + callbacks = [callbacks] + for cb in callbacks: + cb._disabled = disabled + + class GradientClipCallback(Callback): """ 别名::class:`fastNLP.GradientClipCallback` :class:`fastNLP.core.callback.GradientClipCallback` @@ -403,6 +457,9 @@ class GradientClipCallback(Callback): def on_backward_end(self): if self.step%self.update_every==0: if self.parameters is None: + if getattr(self.trainer, 'fp16', ''): + from apex import amp + self.clip_fun(amp.master_params(self.optimizer), self.clip_value) self.clip_fun(self.model.parameters(), self.clip_value) else: self.clip_fun(self.parameters, self.clip_value) @@ -448,10 +505,9 @@ class FitlogCallback(Callback): 并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 - :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 - DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过 - dict的方式传入。如果仅传入DataSet, 则被命名为test - :param ~fastNLP.Tester tester: Tester对象,将在on_valid_end时调用。tester中的DataSet会被称为为`test` + :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要 + 传入多个DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。data的结果的名称以'data'开头。 + :param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: Tester对象,将在on_valid_end时调用。tester的结果的名称以'tester'开头 :param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得 大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。 :param int verbose: 是否在终端打印evaluation的结果,0不打印。 @@ -465,21 +521,24 @@ class FitlogCallback(Callback): self._log_exception = log_exception assert isinstance(log_loss_every, int) and log_loss_every>=0 if tester is not None: - assert isinstance(tester, Tester), "Only fastNLP.Tester allowed." - assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data." - if data is not None: - assert 'test' not in data, "Cannot use `test` as DataSet key, when tester is passed." - setattr(tester, 'verbose', 0) - self.testers['test'] = tester - + if isinstance(tester, dict): + for name, test in tester.items(): + if not isinstance(test, Tester): + raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.") + self.testers['tester-' + name] = test + if isinstance(tester, Tester): + self.testers['tester-test'] = tester + for tester in self.testers.values(): + setattr(tester, 'verbose', 0) + if isinstance(data, dict): for key, value in data.items(): assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." for key, value in data.items(): - self.datasets[key] = value + self.datasets['data-' + key] = value elif isinstance(data, DataSet): - self.datasets['test'] = data - else: + self.datasets['data-test'] = data + elif data is not None: raise TypeError("data receives dict[DataSet] or DataSet object.") self.verbose = verbose @@ -492,8 +551,11 @@ class FitlogCallback(Callback): if len(self.datasets) > 0: for key, data in self.datasets.items(): - tester = Tester(data=data, model=self.model, batch_size=self.batch_size, metrics=self.trainer.metrics, - verbose=0) + tester = Tester(data=data, model=self.model, + batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), + metrics=self.trainer.metrics, + verbose=0, + use_tqdm=self.trainer.use_tqdm) self.testers[key] = tester fitlog.add_progress(total_steps=self.n_steps) @@ -533,6 +595,65 @@ class FitlogCallback(Callback): fitlog.add_other(repr(exception), name='except_info') +class EvaluateCallback(Callback): + """ + 别名: :class:`fastNLP.EvaluateCallback` :class:`fastNLP.core.callback.EvaluateCallback` + + 该callback用于扩展Trainer训练过程中只能对dev数据进行验证的问题。 + + :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 + DataSet请通过dict的方式传入。 + :param ~fastNLP.Tester,Dict[~fastNLP.DataSet] tester: Tester对象,将在on_valid_end时调用。 + """ + + def __init__(self, data=None, tester=None): + super().__init__() + self.datasets = {} + self.testers = {} + if tester is not None: + if isinstance(tester, dict): + for name, test in tester.items(): + if not isinstance(test, Tester): + raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.") + self.testers['tester-' + name] = test + if isinstance(tester, Tester): + self.testers['tester-test'] = tester + for tester in self.testers.values(): + setattr(tester, 'verbose', 0) + + if isinstance(data, dict): + for key, value in data.items(): + assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." + for key, value in data.items(): + self.datasets['data-' + key] = value + elif isinstance(data, DataSet): + self.datasets['data-test'] = data + elif data is not None: + raise TypeError("data receives dict[DataSet] or DataSet object.") + + def on_train_begin(self): + if len(self.datasets) > 0and self.trainer.dev_data is None: + raise RuntimeError("Trainer has no dev data, you cannot pass extra DataSet to do evaluation.") + + if len(self.datasets) > 0: + for key, data in self.datasets.items(): + tester = Tester(data=data, model=self.model, + batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), + metrics=self.trainer.metrics, verbose=0, + use_tqdm=self.trainer.use_tqdm) + self.testers[key] = tester + + def on_valid_end(self, eval_result, metric_key, optimizer, better_result): + if len(self.testers) > 0: + for key, tester in self.testers.items(): + try: + eval_result = tester.test() + self.pbar.write("Evaluation on {}:".format(key)) + self.pbar.write(tester._format_eval_results(eval_result)) + except Exception: + self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) + + class LRScheduler(Callback): """ 别名::class:`fastNLP.LRScheduler` :class:`fastNLP.core.callback.LRScheduler` @@ -884,3 +1005,59 @@ class EarlyStopError(CallbackException): def __init__(self, msg): super(EarlyStopError, self).__init__(msg) + + +class EchoCallback(Callback): + def __init__(self, name, out=sys.stdout): + super(EchoCallback, self).__init__() + self.name = name + self.out = out + + def __getattribute__(self, item): + if item.startswith('on_'): + print('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()), + file=self.out) + return super(EchoCallback, self).__getattribute__(item) + + +class TesterCallback(Callback): + def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None): + super(TesterCallback, self).__init__() + self.tester = Tester(data, model, + metrics=metrics, batch_size=batch_size, + num_workers=num_workers, verbose=0) + # parse metric_key + # increase_better is True. It means the exp result gets better if the indicator increases. + # It is true by default. + self.increase_better = True + if metric_key is not None: + self.increase_better = False if metric_key[0] == "-" else True + self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key + else: + self.metric_key = None + self.score = None + + def on_validation(self): + cur_score = self.tester.test() + eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format( + self.epoch, self.n_epochs, self.step, self.n_steps, + self.tester._format_eval_results(cur_score)) + self.logger.info(eval_str) + is_better = self.compare_better(cur_score) + if is_better: + self.score = cur_score + return cur_score, is_better + + def compare_better(self, a): + if self.score is None: + return True + k = self.metric_key + is_increase = self.score[k] <= a[k] # if equal, prefer more recent results + if self.increase_better: + return is_increase + else: + return not is_increase + + def on_train_end(self): + self.logger.info('Evaluate on training ends.') + self.on_validation() diff --git a/fastNLP/core/const.py b/fastNLP/core/const.py index 89ff51a2..27e8d1cb 100644 --- a/fastNLP/core/const.py +++ b/fastNLP/core/const.py @@ -7,12 +7,14 @@ class Const: 具体列表:: - INPUT 模型的序列输入 words(复数words1, words2) - CHAR_INPUT 模型character输入 chars(复数chars1, chars2) - INPUT_LEN 序列长度 seq_len(复数seq_len1,seq_len2) - OUTPUT 模型输出 pred(复数pred1, pred2) - TARGET 真实目标 target(复数target1,target2) - LOSS 损失函数 loss (复数loss1,loss2) + INPUT 模型的序列输入 words(具有多列words时,依次使用words1, words2, ) + CHAR_INPUT 模型character输入 chars(具有多列chars时,依次使用chars1, chars2) + INPUT_LEN 序列长度 seq_len(具有多列seq_len时,依次使用seq_len1,seq_len2) + OUTPUT 模型输出 pred(具有多列pred时,依次使用pred1, pred2) + TARGET 真实目标 target(具有多列target时,依次使用target1,target2) + LOSS 损失函数 loss (具有多列loss时,依次使用loss1,loss2) + RAW_WORD 原文的词 raw_words (具有多列raw_words时,依次使用raw_words1, raw_words2) + RAW_CHAR 原文的字 raw_chars (具有多列raw_chars时,依次使用raw_chars1, raw_chars2) """ INPUT = 'words' @@ -21,6 +23,8 @@ class Const: OUTPUT = 'pred' TARGET = 'target' LOSS = 'loss' + RAW_WORD = 'raw_words' + RAW_CHAR = 'raw_chars' @staticmethod def INPUTS(i): @@ -34,6 +38,16 @@ class Const: i = int(i) + 1 return Const.CHAR_INPUT + str(i) + @staticmethod + def RAW_WORDS(i): + i = int(i) + 1 + return Const.RAW_WORD + str(i) + + @staticmethod + def RAW_CHARS(i): + i = int(i) + 1 + return Const.RAW_CHAR + str(i) + @staticmethod def INPUT_LENS(i): """得到第 i 个 ``INPUT_LEN`` 的命名""" diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 7b7fa87a..0f98ed1f 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -291,6 +291,7 @@ import _pickle as pickle import warnings import numpy as np +from copy import deepcopy from .field import AutoPadder from .field import FieldArray @@ -298,6 +299,7 @@ from .instance import Instance from .utils import _get_func_signature from .field import AppendToTargetOrInputException from .field import SetInputOrTargetException +from .const import Const class DataSet(object): """ @@ -349,7 +351,11 @@ class DataSet(object): self.idx]) assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) return self.dataset.field_arrays[item][self.idx] - + + def items(self): + ins = self.dataset[self.idx] + return ins.items() + def __repr__(self): return self.dataset[self.idx].__repr__() @@ -487,7 +493,7 @@ class DataSet(object): """ 删除第index个instance - :param int index: 需要删除的instance的index,从0开始 + :param int index: 需要删除的instance的index,序号从0开始。 """ assert isinstance(index, int), "Only integer supported." if len(self) <= index: @@ -497,6 +503,7 @@ class DataSet(object): else: for field in self.field_arrays.values(): field.pop(index) + return self def delete_field(self, field_name): """ @@ -505,7 +512,22 @@ class DataSet(object): :param str field_name: 需要删除的field的名称. """ self.field_arrays.pop(field_name) - + return self + + def copy_field(self, field_name, new_field_name): + """ + 深度copy名为field_name的field到new_field_name + + :param str field_name: 需要copy的field。 + :param str new_field_name: copy生成的field名称 + :return: self + """ + if not self.has_field(field_name): + raise KeyError(f"Field:{field_name} not found in DataSet.") + fieldarray = deepcopy(self.get_field(field_name)) + self.add_fieldarray(field_name=new_field_name, fieldarray=fieldarray) + return self + def has_field(self, field_name): """ 判断DataSet中是否有名为field_name这个field @@ -566,7 +588,7 @@ class DataSet(object): raise KeyError("DataSet has no field named {}.".format(old_name)) return self - def set_target(self, *field_names, flag=True): + def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): """ 将field_names的field设置为target @@ -577,11 +599,14 @@ class DataSet(object): :param str field_names: field的名称 :param bool flag: 将field_name的target状态设置为flag + :param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一 + 行的数据进行类型和维度推断本列的数据的类型和维度。 """ assert isinstance(flag, bool), "Only bool type supported." for name in field_names: if name in self.field_arrays: try: + self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) self.field_arrays[name].is_target = flag except SetInputOrTargetException as e: print(f"Cannot set field:{name} as target.") @@ -589,7 +614,7 @@ class DataSet(object): else: raise KeyError("{} is not a valid field name.".format(name)) - def set_input(self, *field_names, flag=True): + def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): """ 将field_names的field设置为input:: @@ -598,10 +623,13 @@ class DataSet(object): :param str field_names: field的名称 :param bool flag: 将field_name的input状态设置为flag + :param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一 + 行的数据进行类型和维度推断本列的数据的类型和维度。 """ for name in field_names: if name in self.field_arrays: try: + self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) self.field_arrays[name].is_input = flag except SetInputOrTargetException as e: print(f"Cannot set field:{name} as input, exception happens at the {e.index} value.") @@ -695,7 +723,7 @@ class DataSet(object): results.append(func(ins[field_name])) except Exception as e: if idx != -1: - print("Exception happens at the `{}`th instance.".format(idx)) + print("Exception happens at the `{}`th(from 1) instance.".format(idx+1)) raise e if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None raise ValueError("{} always return None.".format(_get_func_signature(func=func))) @@ -760,10 +788,11 @@ class DataSet(object): results = [] for idx, ins in enumerate(self._inner_iter()): results.append(func(ins)) - except Exception as e: + except BaseException as e: if idx != -1: print("Exception happens at the `{}`th instance.".format(idx)) raise e + # results = [func(ins) for ins in self._inner_iter()] if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None raise ValueError("{} always return None.".format(_get_func_signature(func=func))) @@ -773,7 +802,7 @@ class DataSet(object): return results - def add_seq_len(self, field_name:str, new_field_name='seq_len'): + def add_seq_len(self, field_name:str, new_field_name=Const.INPUT_LEN): """ 将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。 diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py new file mode 100644 index 00000000..00db6361 --- /dev/null +++ b/fastNLP/core/dist_trainer.py @@ -0,0 +1,355 @@ +""" +正在开发中的分布式训练代码 +""" +import torch +import torch.cuda +import torch.optim +import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP +import os +from tqdm import tqdm +import logging +import time +from datetime import datetime, timedelta +from functools import partial + +from .batch import DataSetIter, BatchIter +from .callback import DistCallbackManager, CallbackException, TesterCallback +from .dataset import DataSet +from .losses import _prepare_losser +from .optimizer import Optimizer +from .utils import _build_args +from .utils import _move_dict_value_to_device +from .utils import _get_func_signature +from pkg_resources import parse_version + +__all__ = [ + 'get_local_rank', + 'DistTrainer', +] + + +def get_local_rank(): + if 'LOCAL_RANK' in os.environ: + return int(os.environ['LOCAL_RANK']) + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument('--local_rank', type=int) + args, _ = parser.parse_known_args() + if 'local_rank' in args and args.local_rank: + os.environ['LOCAL_RANK'] = str(args.local_rank) # for multiple calls for this function + return args.local_rank + raise RuntimeError('Please use "python -m torch.distributed.launch train_script.py') + + +class DistTrainer(): + """ + Distributed Trainer that support distributed and mixed precision training + """ + def __init__(self, train_data, model, optimizer=None, loss=None, + callbacks_all=None, callbacks_master=None, + batch_size_per_gpu=8, n_epochs=1, + num_data_workers=1, drop_last=False, + dev_data=None, metrics=None, metric_key=None, + update_every=1, print_every=10, validate_every=-1, + log_path=None, + save_every=-1, save_path=None, device='auto', + fp16='', backend=None, init_method=None): + + assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']" + if device == 'auto': + device = 'cuda' if torch.cuda.is_available() else 'cpu' + if backend is None: + backend = 'nccl' if device == 'cuda' else 'gloo' + + # init distributed + if device == 'cuda': + torch.cuda.set_device(get_local_rank()) + self.device = torch.device("cuda", get_local_rank()) + else: + self.device = torch.device(device) + + dist.init_process_group(backend=backend, init_method=init_method) + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() # unique id for each process + + self.model = model + self.train_data = train_data + self.batch_size_per_gpu = int(batch_size_per_gpu) + self.n_epochs = int(n_epochs) + self.num_data_workers = int(num_data_workers) + self.drop_last = drop_last + self.update_every = int(update_every) + self.print_every = int(print_every) + self.validate_every = int(validate_every) + self.save_every = int(save_every) + self.save_path = save_path + self.losser = _prepare_losser(loss) + self.fp16 = fp16 + self.init_method = init_method + self.backend = backend + self.local_rank = get_local_rank() + self._forward_func = model.forward + self.callback_manager = DistCallbackManager( + env={"trainer": self}, callbacks_all=callbacks_all, + callbacks_master=callbacks_master) + self.metric_key = metric_key + + model.to(self.device) + optimizer = self._get_optimizer(optimizer) + + # init fp16, must before DataParallel init + if len(self.fp16): + assert isinstance(self.fp16, str), "Please set Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']" + try: + from apex import amp + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." + assert device == 'cuda', "Amp requires cuda device" + model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16) + + # init DataParallel + if parse_version(torch.__version__)>=parse_version('1.1'): + self.model = DDP(model, device_ids=[self.local_rank], + output_device=self.local_rank, find_unused_parameters=True) + else: + self.model = DDP(model, device_ids=[self.local_rank], + output_device=self.local_rank) + + self.optimizer = optimizer + self.sampler = DistributedSampler(self.train_data) + self.data_iterator = self._get_data_iter(self.train_data) + self.n_steps = self._get_n_steps() + + # for evaluation, only run eval on master proc + if dev_data and metrics: + cb = TesterCallback( + dev_data, model, metrics, + batch_size=batch_size_per_gpu, num_workers=num_data_workers) + self.callback_manager.callbacks_master += \ + self.callback_manager.prepare_callbacks([cb]) + + # Setup logging + dist.barrier() + self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M') + if self.save_path: + self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time) + else: + self.cp_save_path = None + + # use INFO in the master, WARN for others + logging.basicConfig(filename=log_path, + format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO if self.is_master else logging.WARN) + self.logger = logging.getLogger(__name__) + self.logger.info("Setup Distributed Trainer") + self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format( + os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False)) + self.logger.info("Num of processes: {}".format(self.world_size)) + self.logger.info("Use device: {}".format(device)) + self.logger.info("Training with fp16: {}, optimization level: {}".format( + len(self.fp16) > 0, self.fp16 if self.fp16 else None)) + + def _get_n_steps(self): + batch_size = self.world_size * self.batch_size_per_gpu + return (len(self.train_data) // batch_size + int( + len(self.train_data) % batch_size != 0)) * int(self.drop_last == 0) * self.n_epochs + + def _get_data_iter(self, dataset): + if isinstance(dataset, DataSet): + return DataSetIter( + dataset=dataset, batch_size=self.batch_size_per_gpu, + num_workers=self.num_data_workers, sampler=self.sampler, + drop_last=self.drop_last + ) + elif isinstance(dataset, BatchIter): + return dataset + else: + raise TypeError("train_data type {} not support".format(type(dataset))) + + def _get_optimizer(self, optimizer): + if isinstance(optimizer, torch.optim.Optimizer): + return optimizer + elif isinstance(optimizer, Optimizer): + return optimizer.construct_from_pytorch(self.model.parameters()) + elif optimizer is None: + return torch.optim.Adam(self.model.parameters(), lr=4e-3) + else: + raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) + + @property + def is_master(self): + return self.rank == 0 + + def train(self, on_exception='auto'): + try: + self.logger.info("###### Training epochs started ######") + self.logger.info('Total epochs: %d'% self.n_epochs) + self.logger.info('Total steps: %d'% self.n_steps) + self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu) + self.logger.info('Total batch_size: %d'% self.batch_size_per_gpu * dist.get_world_size()) + self.logger.info('Total num of samples: %d'% len(self.train_data)) + self.logger.info("Num of callbacks for all workers: {}".format( + len(self.callback_manager.callbacks_all))) + self.logger.info("Num of callbacks for master workers: {}".format( + len(self.callback_manager.callbacks_master))) + self.logger.info("Callbacks for all workers: {}".format( + [repr(cb) for cb in self.callback_manager.callbacks_all])) + self.logger.info("Callbacks for master workers: {}".format( + [repr(cb) for cb in self.callback_manager.callbacks_master])) + + start_time = time.time() + results = {} + if self.n_epochs <= 0: + self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs)) + results['seconds'] = 0. + return results + + try: + self.callback_manager.on_train_begin() + self._train() + self.callback_manager.on_train_end() + + except BaseException as e: + self.callback_manager.on_exception(e) + if on_exception == 'auto': + if not isinstance(e, (CallbackException, KeyboardInterrupt)): + raise e + else: + self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__)) + elif on_exception == 'raise': + raise e + + results['seconds'] = round(time.time() - start_time, 2) + self.logger.info("###### Train finished ######") + self.logger.info('Total train time: {} seconds.'. format(results['seconds'])) + return results + finally: + self.close() + + def _train(self): + if self.fp16: + # skip check, done in __init__() + from apex import amp + self.step = 0 + self.epoch = 0 + self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', + leave=False, dynamic_ncols=True, disable=not self.is_master) + pbar = self.pbar + avg_loss = 0 + data_iterator = self.data_iterator + self.model.zero_grad() + for epoch in range(1, self.n_epochs + 1): + self.epoch = epoch + pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) + # early stopping + self.callback_manager.on_epoch_begin() + for batch_x, batch_y in data_iterator: + self.model.train() + self.step += 1 + _move_dict_value_to_device(batch_x, batch_y, device=self.device) + indices = data_iterator.get_batch_indices() + # negative sampling; replace unknown; re-weight batch_y + self.callback_manager.on_batch_begin(batch_x, batch_y, indices) + prediction = self._data_forward(self.model, batch_x) + + # edit prediction + self.callback_manager.on_loss_begin(batch_y, prediction) + loss = self._compute_loss(prediction, batch_y) + avg_loss += loss.item() + + # Is loss NaN or inf? requires_grad = False + self.callback_manager.on_backward_begin(loss) + + if self.fp16: + with amp.scale_loss(loss, self.optimizer) as scale_loss: + scale_loss.backward() + else: + loss.backward() + + self.callback_manager.on_backward_end() + + self._update() + self.callback_manager.on_step_end() + + if self.step % self.print_every == 0: + avg_loss = float(avg_loss) / self.print_every + print_output = "loss:{:<6.5f}".format(avg_loss) + pbar.update(self.print_every) + pbar.set_postfix_str(print_output) + avg_loss = 0 + + self.callback_manager.on_batch_end() + + if ((self.validate_every > 0 and self.step % self.validate_every == 0) or + (self.validate_every < 0 and self.step % len(data_iterator) == 0)): + self.callback_manager.on_valid_begin() + eval_res = self.callback_manager.on_validation() + eval_res = list(filter(lambda x: x is not None, eval_res)) + if len(eval_res): + eval_res, is_better = list(zip(*eval_res)) + else: + eval_res, is_better = None, None + self.callback_manager.on_valid_end( + eval_res, self.metric_key, self.optimizer, is_better) + dist.barrier() + + if self.cp_save_path and \ + self.save_every > 0 and \ + self.step % self.save_every == 0: + self.save_check_point() + + # ================= mini-batch end ==================== # + if self.save_every < 0 and self.cp_save_path: + self.save_check_point() + # lr decay; early stopping + self.callback_manager.on_epoch_end() + # =============== epochs end =================== # + pbar.close() + self.pbar = None + # ============ tqdm end ============== # + + def _update(self): + """Perform weight update on a model. + + """ + if self.step % self.update_every == 0: + self.optimizer.step() + self.model.zero_grad() + + def _data_forward(self, network, x): + x = _build_args(self._forward_func, **x) + y = network(**x) + if not isinstance(y, dict): + raise TypeError( + f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.") + return y + + def _compute_loss(self, predict, truth): + """Compute loss given prediction and ground truth. + + :param predict: prediction dict, produced by model.forward + :param truth: ground truth dict, produced by batch_y + :return: a scalar + """ + loss = self.losser(predict, truth) + if self.update_every > 1: + loss = loss / self.update_every + return loss.mean() + + def save_check_point(self, only_params=False): + # only master save models + if self.is_master: + os.makedirs(self.cp_save_path, exist_ok=True) + path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step)) + self.logger.info("Save checkpoint to {}".format(path)) + model_to_save = self.model.module + if only_params: + model_to_save = model_to_save.state_dict() + torch.save(model_to_save, path) + + def close(self): + dist.destroy_process_group() diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index bba854f5..65bd9be4 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -7,6 +7,7 @@ from typing import Any from abc import abstractmethod from copy import deepcopy from collections import Counter +from .utils import _is_iterable class SetInputOrTargetException(Exception): def __init__(self, msg, index=None, field_name=None): @@ -23,7 +24,8 @@ class AppendToTargetOrInputException(Exception): self.field_name = field_name # 标示当前field的名称 class FieldArray: - def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False): + def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False, + use_1st_ins_infer_dim_type=True): if len(content)==0: raise RuntimeError("Empty fieldarray is not allowed.") _content = content @@ -38,6 +40,7 @@ class FieldArray: # 根据input的情况设置input,target等 self._cell_ndim = None # 多少维度 self.dtype = None # 最内层的element都是什么类型的 + self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) self._is_input = False self._is_target = False @@ -77,7 +80,7 @@ class FieldArray: if value is True and \ self._is_target is False and \ self._ignore_type is False: - self._check_dtype_and_ndim() + self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type) if value is False and self._is_target is False: self.dtype = None self._cell_ndim = None @@ -95,32 +98,34 @@ class FieldArray: if value is True and \ self._is_input is False and \ self._ignore_type is False: - self._check_dtype_and_ndim() + self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type) if value is False and self._is_input is False: self.dtype = None self._cell_ndim = None self._is_target = value - def _check_dtype_and_ndim(self): + def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True): """ 检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 通过将直接报错. + :param bool only_check_1st_ins_dim_type: 是否只检查第一个元素的type和dim :return: """ cell_0 = self.content[0] index = 0 try: type_0, dim_0 = _get_ele_type_and_dim(cell_0) - for cell in self.content[1:]: - index += 1 - type_i, dim_i = _get_ele_type_and_dim(cell) - if type_i!=type_0: - raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." - ".".format(type_i, index, type_0)) - if dim_0!=dim_i: - raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " - "dimension:{}.".format(dim_i, index, dim_0)) + if not only_check_1st_ins_dim_type: + for cell in self.content[1:]: + index += 1 + type_i, dim_i = _get_ele_type_and_dim(cell) + if type_i!=type_0: + raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." + ".".format(type_i, index, type_0)) + if dim_0!=dim_i: + raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " + "dimension:{}.".format(dim_i, index, dim_0)) self._cell_ndim = dim_0 self.dtype = type_0 except SetInputOrTargetException as e: @@ -132,7 +137,7 @@ class FieldArray: :param val: 把该val append到fieldarray。 :return: """ - if (self._is_target or self._is_input) and self._ignore_type is False: + if (self._is_target or self._is_input) and self._ignore_type is False and not self._use_1st_ins_infer_dim_type: type_, dim_ = _get_ele_type_and_dim(val) if self.dtype!=type_: raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with " @@ -144,6 +149,14 @@ class FieldArray: else: self.content.append(val) + def pop(self, index): + """ + 删除该field中index处的元素 + :param int index: 从0开始的数据下标。 + :return: + """ + self.content.pop(index) + def __getitem__(self, indices): return self.get(indices, pad=False) @@ -431,15 +444,6 @@ def _get_ele_type_and_dim(cell:Any, dim=0): raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") -def _is_iterable(value): - # 检查是否是iterable的, duck typing - try: - iter(value) - return True - except BaseException as e: - return False - - class Padder: """ 别名::class:`fastNLP.Padder` :class:`fastNLP.core.field.Padder` diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py index 5408522e..9a5d9edf 100644 --- a/fastNLP/core/instance.py +++ b/fastNLP/core/instance.py @@ -35,6 +35,13 @@ class Instance(object): :param Any field: 新增field的内容 """ self.fields[field_name] = field + + def items(self): + """ + 返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value + :return: + """ + return self.fields.items() def __getitem__(self, name): if name in self.fields: diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 1f8923eb..05e5b440 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -28,6 +28,7 @@ from .utils import _check_arg_dict_list from .utils import _check_function_or_method from .utils import _get_func_signature from .utils import seq_len_to_mask +import warnings class LossBase(object): @@ -225,8 +226,9 @@ class CrossEntropyLoss(LossBase): def get_loss(self, pred, target, seq_len=None): if pred.dim() > 2: - if pred.size(1) != target.size(1): - pred = pred.transpose(1, 2) + if pred.size(1) != target.size(1): # 有可能顺序替换了 + raise RuntimeError("It seems like that your prediction's shape is (batch_size, num_labels, max_len)." + " It should be (batch_size, max_len, num_labels).") pred = pred.reshape(-1, pred.size(-1)) target = target.reshape(-1) if seq_len is not None: diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index f23eab91..8dd51eb6 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -624,7 +624,7 @@ class SpanFPreRecMetric(MetricBase): f, pre, rec = self._compute_f_pre_rec(tp, fn, fp) f_sum += f pre_sum += pre - rec_sum + rec + rec_sum += rec if not self.only_gross and tag != '': # tag!=''防止无tag的情况 f_key = 'f-{}'.format(tag) pre_key = 'pre-{}'.format(tag) @@ -814,8 +814,8 @@ class ExtractiveQAMetric(MetricBase): if not self.right_open: e += 1 te += 1 - if ts == 0 and te == int(not self.right_open): - if s == 0 and e == int(not self.right_open): + if ts == 0 and te == 1: + if s == 0 and e == 1: self.no_ans_correct += 1 self.no2no += 1 else: diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py index 3036257c..e95047b4 100644 --- a/fastNLP/core/optimizer.py +++ b/fastNLP/core/optimizer.py @@ -49,7 +49,7 @@ class NullOptimizer(Optimizer): super().__init__(None) def construct_from_pytorch(self, model_params): - pass + return self def __getattr__(self, item): def pass_func(*args, **kwargs): diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py index d8ba1ad1..9ca04fa0 100644 --- a/fastNLP/core/sampler.py +++ b/fastNLP/core/sampler.py @@ -25,9 +25,9 @@ class Sampler(object): def __call__(self, data_set): """ - :param DataSet data_set: `DataSet` 对象, 需要Sample的数据 - :return result: list(int) 其中元素的下标序列, ``data_set`` 中元素会按 ``result`` 中顺序取出 - """ + :param DataSet data_set: `DataSet` 对象, 需要Sample的数据 + :return result: list(int) 其中元素的下标序列, ``data_set`` 中元素会按 ``result`` 中顺序取出 + """ raise NotImplementedError diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index c1d270d1..691bf2ae 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -32,9 +32,16 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation """ +import time + import torch import torch.nn as nn +try: + from tqdm.auto import tqdm +except: + from .utils import _pseudo_tqdm as tqdm + from .batch import BatchIter, DataSetIter from .dataset import DataSet from .metrics import _prepare_metrics @@ -47,6 +54,7 @@ from .utils import _get_func_signature from .utils import _get_model_device from .utils import _move_model_to_device from ._parallel_utils import _data_parallel_wrapper +from ._parallel_utils import _model_contains_inner_module from functools import partial __all__ = [ @@ -79,13 +87,12 @@ class Tester(object): 如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。 :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 + :param bool use_tqdm: 是否使用tqdm来显示测试进度; 如果为False,则不会显示任何内容。 """ - def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1): + def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True): super(Tester, self).__init__() - - if not isinstance(data, DataSet): - raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.") + if not isinstance(model, nn.Module): raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.") @@ -95,6 +102,7 @@ class Tester(object): self._model = _move_model_to_device(model, device=device) self.batch_size = batch_size self.verbose = verbose + self.use_tqdm = use_tqdm if isinstance(data, DataSet): self.data_iterator = DataSetIter( @@ -106,19 +114,22 @@ class Tester(object): # check predict if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \ - (isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and - callable(self._model.module.predict)): + (_model_contains_inner_module(self._model) and hasattr(self._model.module, 'predict') and + callable(self._model.module.predict)): if isinstance(self._model, nn.DataParallel): self._predict_func_wrapper = partial(_data_parallel_wrapper('predict', self._model.device_ids, self._model.output_device), network=self._model.module) + self._predict_func = self._model.module.predict # 用于匹配参数 + elif isinstance(self._model, nn.parallel.DistributedDataParallel): self._predict_func = self._model.module.predict + self._predict_func_wrapper = self._model.module.predict # 用于调用 else: self._predict_func = self._model.predict self._predict_func_wrapper = self._model.predict else: - if isinstance(self._model, nn.DataParallel): + if _model_contains_inner_module(model): self._predict_func_wrapper = self._model.forward self._predict_func = self._model.module.forward else: @@ -139,21 +150,39 @@ class Tester(object): eval_results = {} try: with torch.no_grad(): - for batch_x, batch_y in data_iterator: - _move_dict_value_to_device(batch_x, batch_y, device=self._model_device) - pred_dict = self._data_forward(self._predict_func, batch_x) - if not isinstance(pred_dict, dict): - raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} " - f"must be `dict`, got {type(pred_dict)}.") + if not self.use_tqdm: + from .utils import _pseudo_tqdm as inner_tqdm + else: + inner_tqdm = tqdm + with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar: + pbar.set_description_str(desc="Test") + + start_time = time.time() + + for batch_x, batch_y in data_iterator: + _move_dict_value_to_device(batch_x, batch_y, device=self._model_device) + pred_dict = self._data_forward(self._predict_func, batch_x) + if not isinstance(pred_dict, dict): + raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} " + f"must be `dict`, got {type(pred_dict)}.") + for metric in self.metrics: + metric(pred_dict, batch_y) + + if self.use_tqdm: + pbar.update() + for metric in self.metrics: - metric(pred_dict, batch_y) - for metric in self.metrics: - eval_result = metric.get_metric() - if not isinstance(eval_result, dict): - raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be " - f"`dict`, got {type(eval_result)}") - metric_name = metric.__class__.__name__ - eval_results[metric_name] = eval_result + eval_result = metric.get_metric() + if not isinstance(eval_result, dict): + raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be " + f"`dict`, got {type(eval_result)}") + metric_name = metric.__class__.__name__ + eval_results[metric_name] = eval_result + + end_time = time.time() + test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!' + pbar.write(test_str) + pbar.close() except _CheckError as e: prev_func_signature = _get_func_signature(self._predict_func) _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 671e2736..0d239048 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -352,6 +352,7 @@ from .utils import _move_dict_value_to_device from .utils import _get_func_signature from .utils import _get_model_device from .utils import _move_model_to_device +from ._parallel_utils import _model_contains_inner_module class Trainer(object): @@ -389,8 +390,8 @@ class Trainer(object): 要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表 明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。 :param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。 - :param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模型。 - 保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 + :param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存 + 最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 :param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 的计算位置进行管理。支持以下的输入: @@ -421,7 +422,7 @@ class Trainer(object): num_workers=0, n_epochs=10, print_every=5, dev_data=None, metrics=None, metric_key=None, validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False, - callbacks=None, check_code_level=0): + callbacks=None, check_code_level=0, **kwargs): if prefetch and num_workers==0: num_workers = 1 if prefetch: @@ -430,23 +431,23 @@ class Trainer(object): super(Trainer, self).__init__() if not isinstance(model, nn.Module): raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") - + # check metrics and dev_data if (not metrics) and dev_data is not None: raise ValueError("No metric for dev_data evaluation.") if metrics and (dev_data is None): raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") - + # check update every assert update_every >= 1, "update_every must be no less than 1." self.update_every = int(update_every) - + # check save_path if not (save_path is None or isinstance(save_path, str)): raise ValueError("save_path can only be None or `str`.") # prepare evaluate metrics = _prepare_metrics(metrics) - + # parse metric_key # increase_better is True. It means the exp result gets better if the indicator increases. # It is true by default. @@ -458,30 +459,69 @@ class Trainer(object): self.metric_key = None # prepare loss losser = _prepare_losser(loss) - - # sampler check - if sampler is not None and not isinstance(sampler, Sampler): - raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) - if sampler is None: - sampler = RandomSampler() - elif hasattr(sampler, 'set_batch_size'): - sampler.set_batch_size(batch_size) + if isinstance(train_data, BatchIter): + if sampler is not None: + warnings.warn("sampler is ignored when train_data is a BatchIter.") + if num_workers>0: + warnings.warn("num_workers is ignored when train_data is BatchIter.") + if drop_last: + warnings.warn("drop_last is ignored when train_data is BatchIter.") + + if isinstance(model, nn.parallel.DistributedDataParallel): # 如果是分布式的 + # device为None + if device is not None: + warnings.warn("device is ignored when model is nn.parallel.DistributedDataParallel.") + device = None + # Sampler要是分布式的 + if sampler is None: + sampler = torch.utils.data.DistributedSampler(train_data) + elif not isinstance(sampler, torch.utils.data.DistributedSampler): + raise TypeError("When using nn.parallel.DistributedDataParallel, " + "sampler must be None or torch.utils.data.DistributedSampler.") + # 不能保存模型 + if save_path: + raise RuntimeError("Saving model in Distributed situation is not allowed right now.") + else: + # sampler check + if sampler is not None and not isinstance(sampler, (Sampler, torch.utils.data.Sampler)): + raise ValueError(f"The type of sampler should be fastNLP.BaseSampler or pytorch's Sampler, got {type(sampler)}") + if sampler is None: + sampler = RandomSampler() + elif hasattr(sampler, 'set_batch_size'): + sampler.set_batch_size(batch_size) if isinstance(train_data, DataSet): self.data_iterator = DataSetIter( dataset=train_data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, drop_last=drop_last) elif isinstance(train_data, BatchIter): self.data_iterator = train_data + train_data = train_data.dataset else: raise TypeError("train_data type {} not support".format(type(train_data))) - if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): - _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, - metric_key=self.metric_key, check_level=check_code_level, - batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) - # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 self.model = _move_model_to_device(model, device=device) + if _model_contains_inner_module(self.model): + self._forward_func = self.model.module.forward + else: + self._forward_func = self.model.forward + if check_code_level > -1: + # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的field名与模型的输入 + # 名是否匹配 + dev_dataset = dev_data + if isinstance(dev_data, BatchIter): + dev_dataset = None + warnings.warn("dev_data is of BatchIter type, ignore validation checking.") + check_batch_size = min(batch_size, DEFAULT_CHECK_BATCH_SIZE) + if isinstance(self.model, nn.DataParallel): + _num_devices = len(self.model.device_ids) + if batch_size//_num_devices>1: # 如果多卡是每个卡可以分多个数据的,则用每个卡给两个sample + check_batch_size = max(len(self.model.device_ids)*2, check_batch_size) + else: + check_batch_size = max(len(self.model.device_ids), check_batch_size) + _check_code(dataset=train_data, model=self.model, losser=losser, forward_func=self._forward_func, metrics=metrics, + dev_data=dev_dataset, metric_key=self.metric_key, check_level=check_code_level, + batch_size=check_batch_size) self.train_data = train_data self.dev_data = dev_data # If None, No validation. @@ -496,8 +536,7 @@ class Trainer(object): self.best_dev_epoch = None self.best_dev_step = None self.best_dev_perf = None - self.n_steps = (len(self.train_data) // self.batch_size + int( - len(self.train_data) % self.batch_size != 0)) * int(drop_last==0) * self.n_epochs + self.n_steps = len(self.data_iterator) * self.n_epochs if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer @@ -507,22 +546,23 @@ class Trainer(object): self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3) else: raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) - + self.use_tqdm = use_tqdm self.pbar = None self.print_every = abs(self.print_every) - + self.kwargs = kwargs if self.dev_data is not None: self.tester = Tester(model=self.model, data=self.dev_data, metrics=self.metrics, - batch_size=self.batch_size, + batch_size=kwargs.get("dev_batch_size", self.batch_size), device=None, # 由上面的部分处理device - verbose=0) - + verbose=0, + use_tqdm=self.use_tqdm) + self.step = 0 self.start_time = None # start timestamp - + self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) @@ -558,7 +598,7 @@ class Trainer(object): self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) start_time = time.time() print("training epochs started " + self.start_time, flush=True) - + try: self.callback_manager.on_train_begin() self._train() @@ -571,7 +611,7 @@ class Trainer(object): raise e elif on_exception == 'raise': raise e - + if self.dev_data is not None and self.best_dev_perf is not None: print( "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + @@ -589,21 +629,17 @@ class Trainer(object): finally: pass results['seconds'] = round(time.time() - start_time, 2) - + return results - + def _train(self): if not self.use_tqdm: - from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm + from .utils import _pseudo_tqdm as inner_tqdm else: inner_tqdm = tqdm self.step = 0 self.epoch = 0 start = time.time() - if isinstance(self.model, nn.DataParallel): - self._forward_func = self.model.module.forward - else: - self._forward_func = self.model.forward with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: self.pbar = pbar avg_loss = 0 @@ -621,21 +657,21 @@ class Trainer(object): # negative sampling; replace unknown; re-weight batch_y self.callback_manager.on_batch_begin(batch_x, batch_y, indices) prediction = self._data_forward(self.model, batch_x) - + # edit prediction self.callback_manager.on_loss_begin(batch_y, prediction) loss = self._compute_loss(prediction, batch_y).mean() avg_loss += loss.item() loss = loss / self.update_every - + # Is loss NaN or inf? requires_grad = False self.callback_manager.on_backward_begin(loss) self._grad_backward(loss) self.callback_manager.on_backward_end() - + self._update() self.callback_manager.on_step_end() - + if self.step % self.print_every == 0: avg_loss = float(avg_loss) / self.print_every if self.use_tqdm: @@ -649,29 +685,29 @@ class Trainer(object): pbar.set_postfix_str(print_output) avg_loss = 0 self.callback_manager.on_batch_end() - + if ((self.validate_every > 0 and self.step % self.validate_every == 0) or (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ and self.dev_data is not None: eval_res = self._do_validation(epoch=epoch, step=self.step) - eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, + eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, self.n_steps) + \ self.tester._format_eval_results(eval_res) pbar.write(eval_str + '\n') - + # ================= mini-batch end ==================== # - + # lr decay; early stopping self.callback_manager.on_epoch_end() # =============== epochs end =================== # pbar.close() self.pbar = None # ============ tqdm end ============== # - + def _do_validation(self, epoch, step): self.callback_manager.on_valid_begin() res = self.tester.test() - + is_better_eval = False if self._better_eval_result(res): if self.save_path is not None: @@ -686,7 +722,7 @@ class Trainer(object): # get validation results; adjust optimizer self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval) return res - + def _mode(self, model, is_test=False): """Train mode or Test mode. This is for PyTorch currently. @@ -698,14 +734,14 @@ class Trainer(object): model.eval() else: model.train() - + def _update(self): """Perform weight update on a model. """ if self.step % self.update_every == 0: self.optimizer.step() - + def _data_forward(self, network, x): x = _build_args(self._forward_func, **x) y = network(**x) @@ -713,7 +749,7 @@ class Trainer(object): raise TypeError( f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.") return y - + def _grad_backward(self, loss): """Compute gradient with link rules. @@ -724,7 +760,7 @@ class Trainer(object): if (self.step-1) % self.update_every == 0: self.model.zero_grad() loss.backward() - + def _compute_loss(self, predict, truth): """Compute loss given prediction and ground truth. @@ -733,7 +769,7 @@ class Trainer(object): :return: a scalar """ return self.losser(predict, truth) - + def _save_model(self, model, model_name, only_param=False): """ 存储不含有显卡信息的state_dict或model :param model: @@ -745,7 +781,7 @@ class Trainer(object): model_path = os.path.join(self.save_path, model_name) if not os.path.exists(self.save_path): os.makedirs(self.save_path, exist_ok=True) - if isinstance(model, nn.DataParallel): + if _model_contains_inner_module(model): model = model.module if only_param: state_dict = model.state_dict() @@ -756,7 +792,7 @@ class Trainer(object): model.cpu() torch.save(model, model_path) model.to(self._model_device) - + def _load_model(self, model, model_name, only_param=False): # 返回bool值指示是否成功reload模型 if self.save_path is not None: @@ -765,7 +801,7 @@ class Trainer(object): states = torch.load(model_path) else: states = torch.load(model_path).state_dict() - if isinstance(model, nn.DataParallel): + if _model_contains_inner_module(model): model.module.load_state_dict(states) else: model.load_state_dict(states) @@ -774,7 +810,7 @@ class Trainer(object): else: return False return True - + def _better_eval_result(self, metrics): """Check if the current epoch yields better validation results. @@ -800,6 +836,9 @@ class Trainer(object): is_better = False return is_better + @property + def is_master(self): + return True DEFAULT_CHECK_BATCH_SIZE = 2 DEFAULT_CHECK_NUM_BATCH = 2 @@ -821,14 +860,15 @@ def _get_value_info(_dict): strs.append(_str) return strs + from numbers import Number from .batch import _to_tensor -def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, - dev_data=None, metric_key=None, - check_level=0): + + +def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE, + dev_data=None, metric_key=None, check_level=0): # check get_loss 方法 - model_devcie = _get_model_device(model=model) - + model_device = _get_model_device(model=model) def _iter(): start_idx = 0 while start_idx List[str]: else: raise TypeError("Invalid IOB format.") return new_tags + + +def _is_iterable(value): + # 检查是否是iterable的, duck typing + try: + iter(value) + return True + except BaseException as e: + return False + + +def get_seq_len(words, pad_value=0): + """ + 给定batch_size x max_len的words矩阵,返回句子长度 + + :param words: batch_size x max_len + :return: (batch_size,) + """ + mask = words.ne(pad_value) + return mask.sum(dim=-1) diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 9ce59a8c..330d73dd 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -4,12 +4,12 @@ __all__ = [ ] from functools import wraps -from collections import Counter, defaultdict +from collections import Counter from .dataset import DataSet from .utils import Option from functools import partial import numpy as np - +from .utils import _is_iterable class VocabularyOption(Option): def __init__(self, @@ -131,11 +131,11 @@ class Vocabulary(object): """ 在新加入word时,检查_no_create_word的设置。 - :param str, List[str] word: + :param str List[str] word: :param bool no_create_entry: :return: """ - if isinstance(word, str): + if isinstance(word, str) or not _is_iterable(word): word = [word] for w in word: if no_create_entry and self.word_count.get(w, 0) == self._no_create_word.get(w, 0): @@ -257,35 +257,45 @@ class Vocabulary(object): vocab.index_dataset(train_data, dev_data, test_data, field_name='words') :param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 - :param str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. - 目前仅支持 ``str`` , ``List[str]`` , ``List[List[str]]`` - :param str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. - Default: ``None`` + :param list,str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. + 目前支持 ``str`` , ``List[str]`` + :param list,str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. + Default: ``None``. """ - def index_instance(ins): + def index_instance(field): """ 有几种情况, str, 1d-list, 2d-list :param ins: :return: """ - field = ins[field_name] - if isinstance(field, str): + if isinstance(field, str) or not _is_iterable(field): return self.to_index(field) - elif isinstance(field, list): - if not isinstance(field[0], list): + else: + if isinstance(field[0], str) or not _is_iterable(field[0]): return [self.to_index(w) for w in field] else: - if isinstance(field[0][0], list): + if not isinstance(field[0][0], str) and _is_iterable(field[0][0]): raise RuntimeError("Only support field with 2 dimensions.") return [[self.to_index(c) for c in w] for w in field] - - if new_field_name is None: - new_field_name = field_name + + new_field_name = new_field_name or field_name + + if type(new_field_name) == type(field_name): + if isinstance(new_field_name, list): + assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \ + "field_name." + elif isinstance(new_field_name, str): + field_name = [field_name] + new_field_name = [new_field_name] + else: + raise TypeError("field_name and new_field_name can only be str or List[str].") + for idx, dataset in enumerate(datasets): if isinstance(dataset, DataSet): try: - dataset.apply(index_instance, new_field_name=new_field_name) + for f_n, n_f_n in zip(field_name, new_field_name): + dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n) except Exception as e: print("When processing the `{}` dataset, the following error occurred.".format(idx)) raise e @@ -306,9 +316,8 @@ class Vocabulary(object): :param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 :param str,List[str] field_name: 可为 ``str`` 或 ``List[str]`` . - 构建词典所使用的 field(s), 支持一个或多个field - 若有多个 DataSet, 每个DataSet都必须有这些field. - 目前仅支持的field结构: ``str`` , ``List[str]`` , ``list[List[str]]`` + 构建词典所使用的 field(s), 支持一个或多个field,若有多个 DataSet, 每个DataSet都必须有这些field. 目前支持的field结构 + : ``str`` , ``List[str]`` :param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认),该选项用在接下来的模型会使用pretrain 的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev 中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。 @@ -326,14 +335,14 @@ class Vocabulary(object): def construct_vocab(ins, no_create_entry=False): for fn in field_name: field = ins[fn] - if isinstance(field, str): + if isinstance(field, str) or not _is_iterable(field): self.add_word(field, no_create_entry=no_create_entry) - elif isinstance(field, (list, np.ndarray)): - if not isinstance(field[0], (list, np.ndarray)): + else: + if isinstance(field[0], str) or not _is_iterable(field[0]): for word in field: self.add_word(word, no_create_entry=no_create_entry) else: - if isinstance(field[0][0], (list, np.ndarray)): + if not isinstance(field[0][0], str) and _is_iterable(field[0][0]): raise RuntimeError("Only support field with 2 dimensions.") for words in field: for word in words: @@ -343,8 +352,8 @@ class Vocabulary(object): if isinstance(dataset, DataSet): try: dataset.apply(construct_vocab) - except Exception as e: - print("When processing the `{}` dataset, the following error occurred.".format(idx)) + except BaseException as e: + print("When processing the `{}` dataset, the following error occurred:".format(idx)) raise e else: raise TypeError("Only DataSet type is allowed.") @@ -367,7 +376,7 @@ class Vocabulary(object): :return: bool """ return word in self._no_create_word - + def to_index(self, w): """ 将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出``ValueError``:: diff --git a/fastNLP/embeddings/__init__.py b/fastNLP/embeddings/__init__.py index 2bfb2960..4f90ac63 100644 --- a/fastNLP/embeddings/__init__.py +++ b/fastNLP/embeddings/__init__.py @@ -10,6 +10,7 @@ __all__ = [ "StaticEmbedding", "ElmoEmbedding", "BertEmbedding", + "BertWordPieceEncoder", "StackEmbedding", "LSTMCharEmbedding", "CNNCharEmbedding", @@ -20,7 +21,7 @@ __all__ = [ from .embedding import Embedding from .static_embedding import StaticEmbedding from .elmo_embedding import ElmoEmbedding -from .bert_embedding import BertEmbedding +from .bert_embedding import BertEmbedding, BertWordPieceEncoder from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding from .stack_embedding import StackEmbedding from .utils import get_embeddings \ No newline at end of file diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py index 5d46d98c..7a9738fe 100644 --- a/fastNLP/embeddings/bert_embedding.py +++ b/fastNLP/embeddings/bert_embedding.py @@ -8,10 +8,10 @@ import numpy as np from itertools import chain from ..core.vocabulary import Vocabulary -from ..io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR +from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer from .contextual_embedding import ContextualEmbedding - +import warnings class BertEmbedding(ContextualEmbedding): """ @@ -38,8 +38,8 @@ class BertEmbedding(ContextualEmbedding): :param ~fastNLP.Vocabulary vocab: 词表 :param str model_dir_or_name: 模型所在目录或者模型的名称。当传入模型所在目录时,目录中应该包含一个词表文件(以.txt作为后缀名), 权重文件(以.bin作为文件后缀名), 配置文件(以.json作为后缀名)。 - :param str layers: 输出embedding表示来自于哪些层,不同层的结果按照layers中的顺序在最后一维concat起来。以','隔开层数,可以以负数 - 去索引倒数几层。 + :param str layers: 输出embedding表示来自于哪些层,不同层的结果按照layers中的顺序在最后一维concat起来。以','隔开层数,层的序号是 + 从0开始,可以以负数去索引倒数几层。 :param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces 中计算得到它对应的表示。支持 ``last`` , ``first`` , ``avg`` , ``max``。 :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 @@ -47,27 +47,35 @@ class BertEmbedding(ContextualEmbedding): :param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样 会使得word embedding的结果比输入的结果长两个token。如果该值为True,则在使用 :class::StackEmbedding 可能会与其它类型的 embedding长度不匹配。 + :param bool pooled_cls: 返回的[CLS]是否使用预训练中的BertPool映射一下,仅在include_cls_sep时有效。如果下游任务只取[CLS]做预测, + 一般该值为True。 :param bool requires_grad: 是否需要gradient以更新Bert的权重。 + :param bool auto_truncate: 当句子words拆分为word pieces长度超过bert最大允许长度(一般为512), 自动截掉拆分后的超过510个 + word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS] + 来进行分类的任务将auto_truncate置为True。 """ def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1', - pool_method: str='first', word_dropout=0, dropout=0, requires_grad: bool=False, - include_cls_sep: bool=False): + pool_method: str='first', word_dropout=0, dropout=0, include_cls_sep: bool=False, + pooled_cls=True, requires_grad: bool=False, auto_truncate:bool=False): super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) # 根据model_dir_or_name检查是否存在并下载 if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: - PRETRAIN_URL = _get_base_url('bert') - model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] - model_url = PRETRAIN_URL + model_name - model_dir = cached_path(model_url) + model_url = _get_embedding_url('bert', model_dir_or_name.lower()) + model_dir = cached_path(model_url, name='embedding') # 检查是否存在 elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): - model_dir = model_dir_or_name + model_dir = os.path.expanduser(os.path.abspath(model_dir_or_name)) else: raise ValueError(f"Cannot recognize {model_dir_or_name}.") + self._word_sep_index = None + if '[SEP]' in vocab: + self._word_sep_index = vocab['[SEP]'] + self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers, - pool_method=pool_method, include_cls_sep=include_cls_sep) + pool_method=pool_method, include_cls_sep=include_cls_sep, + pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2) self.requires_grad = requires_grad self._embed_size = len(self.model.layers)*self.model.encoder.hidden_size @@ -83,7 +91,11 @@ class BertEmbedding(ContextualEmbedding): :param torch.LongTensor words: [batch_size, max_len] :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) """ + if self._word_sep_index: # 不能drop sep + sep_mask = words.eq(self._word_sep_index) words = self.drop_word(words) + if self._word_sep_index: + words.masked_fill_(sep_mask, self._word_sep_index) outputs = self._get_sent_reprs(words) if outputs is not None: return self.dropout(words) @@ -120,24 +132,24 @@ class BertWordPieceEncoder(nn.Module): :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` :param str layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 + :param bool pooled_cls: 返回的句子开头的[CLS]是否使用预训练中的BertPool映射一下,仅在include_cls_sep时有效。如果下游任务只取 + [CLS]做预测,一般该值为True。 :param bool requires_grad: 是否需要gradient。 """ def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1', - requires_grad: bool=False): + pooled_cls: bool = False, requires_grad: bool=False): super().__init__() - PRETRAIN_URL = _get_base_url('bert') - if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: - model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] - model_url = PRETRAIN_URL + model_name - model_dir = cached_path(model_url) + if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: + model_url = _get_embedding_url('bert', model_dir_or_name.lower()) + model_dir = cached_path(model_url, name='embedding') # 检查是否存在 - elif os.path.isdir(model_dir_or_name): + elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): model_dir = model_dir_or_name else: raise ValueError(f"Cannot recognize {model_dir_or_name}.") - self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers) + self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls) self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size self.requires_grad = requires_grad @@ -162,16 +174,25 @@ class BertWordPieceEncoder(nn.Module): def embed_size(self): return self._embed_size - def index_datasets(self, *datasets, field_name): + @property + def embedding_dim(self): + return self._embed_size + + @property + def num_embedding(self): + return self.model.encoder.config.vocab_size + + def index_datasets(self, *datasets, field_name, add_cls_sep=True): """ - 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是 - [CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。 + 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了 + bert的pad value。 - :param datasets: DataSet对象 - :param field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。 + :param ~fastNLP.DataSet datasets: DataSet对象 + :param str field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。 + :param bool add_cls_sep: 如果首尾不是[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP]。 :return: """ - self.model.index_dataset(*datasets, field_name=field_name) + self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep) def forward(self, word_pieces, token_type_ids=None): """ @@ -188,11 +209,13 @@ class BertWordPieceEncoder(nn.Module): class _WordBertModel(nn.Module): - def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first', include_cls_sep:bool=False): + def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first', + include_cls_sep:bool=False, pooled_cls:bool=False, auto_truncate:bool=False, min_freq=2): super().__init__() self.tokenzier = BertTokenizer.from_pretrained(model_dir) self.encoder = BertModel.from_pretrained(model_dir) + self._max_position_embeddings = self.encoder.config.max_position_embeddings # 检查encoder_layer_number是否合理 encoder_layer_number = len(self.encoder.encoder.layer) self.layers = list(map(int, layers.split(','))) @@ -207,12 +230,21 @@ class _WordBertModel(nn.Module): assert pool_method in ('avg', 'max', 'first', 'last') self.pool_method = pool_method self.include_cls_sep = include_cls_sep + self.pooled_cls = pooled_cls + self.auto_truncate = auto_truncate # 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP] print("Start to generating word pieces for word.") # 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 word_piece_dict = {'[CLS]':1, '[SEP]':1} # 用到的word_piece以及新增的 found_count = 0 + self._has_sep_in_vocab = '[SEP]' in vocab # 用来判断传入的数据是否需要生成token_ids + if '[sep]' in vocab: + warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.") + if "[CLS]" in vocab: + warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CSL] and [SEP] to the begin " + "and end of the input automatically, make sure you don't add [CLS] and [SEP] at the begin" + " and end.") for word, index in vocab: if index == vocab.padding_idx: # pad是个特殊的符号 word = '[PAD]' @@ -222,7 +254,8 @@ class _WordBertModel(nn.Module): if len(word_pieces)==1: if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 if index!=vocab.unknown_idx and word_pieces[0]=='[UNK]': # 说明这个词不在原始的word里面 - word_piece_dict[word] = 1 # 新增一个值 + if vocab.word_count[word]>=min_freq and not vocab._is_word_no_create_entry(word): #出现次数大于这个次数才新增 + word_piece_dict[word] = 1 # 新增一个值 continue for word_piece in word_pieces: word_piece_dict[word_piece] = 1 @@ -258,7 +291,7 @@ class _WordBertModel(nn.Module): print("Found(Or seg into word pieces) {} words out of {}.".format(found_count, len(vocab))) self._cls_index = self.tokenzier.vocab['[CLS]'] self._sep_index = self.tokenzier.vocab['[SEP]'] - self._pad_index = vocab.padding_idx + self._word_pad_index = vocab.padding_idx self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # 需要用于生成word_piece self.word_to_wordpieces = np.array(word_to_wordpieces) self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False) @@ -270,29 +303,50 @@ class _WordBertModel(nn.Module): :param words: torch.LongTensor, batch_size x max_len :return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size """ - batch_size, max_word_len = words.size() - seq_len = words.ne(self._pad_index).sum(dim=-1) - batch_word_pieces_length = self.word_pieces_lengths[words] # batch_size x max_len - word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) - max_word_piece_length = word_pieces_lengths.max().item() - # +2是由于需要加入[CLS]与[SEP] - word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index) - word_pieces[:, 0].fill_(self._cls_index) - batch_indexes = torch.arange(batch_size).to(words) - word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index - attn_masks = torch.zeros_like(word_pieces) - # 1. 获取words的word_pieces的id,以及对应的span范围 - word_indexes = words.tolist() - for i in range(batch_size): - word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]])) - word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i) - attn_masks[i, :len(word_pieces_i)+2].fill_(1) - # TODO 截掉长度超过的部分。 + with torch.no_grad(): + batch_size, max_word_len = words.size() + word_mask = words.ne(self._word_pad_index) # 为1的地方有word + seq_len = word_mask.sum(dim=-1) + batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(0), 0) # batch_size x max_len + word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size + word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item() # 表示word piece的长度(包括padding) + if word_piece_length+2>self._max_position_embeddings: + if self.auto_truncate: + word_pieces_lengths = word_pieces_lengths.masked_fill(word_pieces_lengths+2>self._max_position_embeddings, + self._max_position_embeddings-2) + else: + raise RuntimeError("After split words into word pieces, the lengths of word pieces are longer than the " + f"maximum allowed sequence length:{self._max_position_embeddings} of bert.") + + # +2是由于需要加入[CLS]与[SEP] + word_pieces = words.new_full((batch_size, min(word_piece_length+2, self._max_position_embeddings)), + fill_value=self._wordpiece_pad_index) + attn_masks = torch.zeros_like(word_pieces) + # 1. 获取words的word_pieces的id,以及对应的span范围 + word_indexes = words.cpu().numpy() + for i in range(batch_size): + word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]])) + if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2: + word_pieces_i = word_pieces_i[:self._max_position_embeddings-2] + word_pieces[i, 1:word_pieces_lengths[i]+1] = torch.LongTensor(word_pieces_i) + attn_masks[i, :word_pieces_lengths[i]+2].fill_(1) + # 添加[cls]和[sep] + word_pieces[:, 0].fill_(self._cls_index) + batch_indexes = torch.arange(batch_size).to(words) + word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index + if self._has_sep_in_vocab: #但[SEP]在vocab中出现应该才会需要token_ids + sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len + sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + token_type_ids = sep_mask_cumsum.fmod(2) + if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 + token_type_ids = token_type_ids.eq(0).float() + else: + token_type_ids = torch.zeros_like(word_pieces) # 2. 获取hidden的结果,根据word_pieces进行对应的pool计算 # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] - bert_outputs, _ = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, + bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, output_all_encoded_layers=True) - # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size + # output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size if self.include_cls_sep: outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, @@ -306,6 +360,12 @@ class _WordBertModel(nn.Module): batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len for l_index, l in enumerate(self.layers): output_layer = bert_outputs[l] + real_word_piece_length = output_layer.size(1) - 2 + if word_piece_length > real_word_piece_length: # 如果实际上是截取出来的 + paddings = output_layer.new_zeros(batch_size, + word_piece_length-real_word_piece_length, + output_layer.size(2)) + output_layer = torch.cat((output_layer, paddings), dim=1).contiguous() # 从word_piece collapse到word的表示 truncate_output_layer = output_layer[:, 1:-1] # 删除[CLS]与[SEP] batch_size x len x hidden_size outputs_seq_len = seq_len + s_shift @@ -328,7 +388,10 @@ class _WordBertModel(nn.Module): start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j+1] outputs[l_index, i, j+s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2) if self.include_cls_sep: - outputs[l_index, :, 0] = output_layer[:, 0] + if l in (len(bert_outputs)-1, -1) and self.pooled_cls: + outputs[l_index, :, 0] = pooled_cls + else: + outputs[l_index, :, 0] = output_layer[:, 0] outputs[l_index, batch_indexes, seq_len+s_shift] = output_layer[batch_indexes, seq_len+s_shift] # 3. 最终的embedding结果 return outputs diff --git a/fastNLP/embeddings/char_embedding.py b/fastNLP/embeddings/char_embedding.py index b0bd6796..783dec41 100644 --- a/fastNLP/embeddings/char_embedding.py +++ b/fastNLP/embeddings/char_embedding.py @@ -95,7 +95,7 @@ class CNNCharEmbedding(TokenEmbedding): for i in range(len(kernel_sizes))]) self._embed_size = embed_size self.fc = nn.Linear(sum(filter_nums), embed_size) - self.init_param() + self.reset_parameters() def forward(self, words): """ @@ -152,7 +152,7 @@ class CNNCharEmbedding(TokenEmbedding): continue param.requires_grad = value - def init_param(self): + def reset_parameters(self): for name, param in self.named_parameters(): if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能reset continue diff --git a/fastNLP/embeddings/contextual_embedding.py b/fastNLP/embeddings/contextual_embedding.py index 1831af4e..152b0ab9 100644 --- a/fastNLP/embeddings/contextual_embedding.py +++ b/fastNLP/embeddings/contextual_embedding.py @@ -1,4 +1,3 @@ - from abc import abstractmethod import torch @@ -9,6 +8,10 @@ from ..core.sampler import SequentialSampler from ..core.utils import _move_model_to_device, _get_model_device from .embedding import TokenEmbedding +__all__ = [ + "ContextualEmbedding" +] + class ContextualEmbedding(TokenEmbedding): def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0): diff --git a/fastNLP/embeddings/elmo_embedding.py b/fastNLP/embeddings/elmo_embedding.py index 73def086..435e0b98 100644 --- a/fastNLP/embeddings/elmo_embedding.py +++ b/fastNLP/embeddings/elmo_embedding.py @@ -8,7 +8,7 @@ import json import codecs from ..core.vocabulary import Vocabulary -from ..io.file_utils import cached_path, _get_base_url, PRETRAINED_ELMO_MODEL_DIR +from ..io.file_utils import cached_path, _get_embedding_url, PRETRAINED_ELMO_MODEL_DIR from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder from .contextual_embedding import ContextualEmbedding @@ -40,7 +40,7 @@ class ElmoEmbedding(ContextualEmbedding): :param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo所在文件夹,该文件夹下面应该有两个文件, 其中一个是以json为后缀的配置文件,另一个是以pkl为后缀的权重文件;第二种是传入ELMo版本的名称,将自动查看缓存中是否存在该模型, 没有的话将自动下载并缓存。 - :param layers: str, 指定返回的层数, 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果 + :param layers: str, 指定返回的层数(从0开始), 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果 按照这个顺序concat起来,默认为'2'。'mix'会使用可学习的权重结合不同层的表示(权重是否可训练与requires_grad保持一致, 初始化权重对三层结果进行mean-pooling, 可以通过ElmoEmbedding.set_mix_weights_requires_grad()方法只将mix weights设置为可学习。) :param requires_grad: bool, 该层是否需要gradient, 默认为False. @@ -56,10 +56,8 @@ class ElmoEmbedding(ContextualEmbedding): # 根据model_dir_or_name检查是否存在并下载 if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: - PRETRAIN_URL = _get_base_url('elmo') - model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] - model_url = PRETRAIN_URL + model_name - model_dir = cached_path(model_url) + model_url = _get_embedding_url('elmo', model_dir_or_name.lower()) + model_dir = cached_path(model_url, name='embedding') # 检查是否存在 elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): model_dir = model_dir_or_name @@ -185,8 +183,8 @@ class _ElmoModel(nn.Module): raise Exception(f"Multiple config files(*.json) or weight files(*.hdf5) detected in {model_dir}.") elif config_count == 0 or weight_count == 0: raise Exception(f"No config file or weight file found in {model_dir}") - - config = json.load(open(os.path.join(model_dir, config_file), 'r')) + with open(os.path.join(model_dir, config_file), 'r') as config_f: + config = json.load(config_f) self.weight_file = os.path.join(model_dir, weight_file) self.config = config diff --git a/fastNLP/embeddings/embedding.py b/fastNLP/embeddings/embedding.py index a02e7a20..8c5396b7 100644 --- a/fastNLP/embeddings/embedding.py +++ b/fastNLP/embeddings/embedding.py @@ -42,7 +42,12 @@ class Embedding(nn.Module): self.dropout = nn.Dropout(dropout) if not isinstance(self.embed, TokenEmbedding): - self._embed_size = self.embed.weight.size(1) + if hasattr(self.embed, 'embed_size'): + self._embed_size = self.embed.embed_size + elif hasattr(self.embed, 'embedding_dim'): + self._embed_size = self.embed.embedding_dim + else: + self._embed_size = self.embed.weight.size(1) if word_dropout>0 and not isinstance(unk_index, int): raise ValueError("When drop word is set, you need to pass in the unk_index.") else: diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py index c2aa1c49..050a7fe1 100644 --- a/fastNLP/embeddings/static_embedding.py +++ b/fastNLP/embeddings/static_embedding.py @@ -7,9 +7,11 @@ import numpy as np import warnings from ..core.vocabulary import Vocabulary -from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_base_url, cached_path +from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path from .embedding import TokenEmbedding from ..modules.utils import _get_file_name_base_on_postfix +from copy import deepcopy +from collections import defaultdict class StaticEmbedding(TokenEmbedding): """ @@ -45,15 +47,16 @@ class StaticEmbedding(TokenEmbedding): 如果输入为None则使用embedding_dim的维度随机初始化一个embedding。 :param int embedding_dim: 随机初始化的embedding的维度,仅在model_dir_or_name为None时有效。 :param bool requires_grad: 是否需要gradient. 默认为True - :param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对象。 + :param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对 :param bool lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独 为大写的词语开辟一个vector表示,则将lower设置为False。 - :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 + :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 :param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。 + :param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。 """ def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', embedding_dim=100, requires_grad: bool=True, - init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False): + init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs): super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) # 得到cache_path @@ -62,10 +65,8 @@ class StaticEmbedding(TokenEmbedding): embedding_dim = int(embedding_dim) model_path = None elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: - PRETRAIN_URL = _get_base_url('static') - model_name = PRETRAIN_STATIC_FILES[model_dir_or_name] - model_url = PRETRAIN_URL + model_name - model_path = cached_path(model_url) + model_url = _get_embedding_url('static', model_dir_or_name.lower()) + model_path = cached_path(model_url, name='embedding') # 检查是否存在 elif os.path.isfile(os.path.expanduser(os.path.abspath(model_dir_or_name))): model_path = model_dir_or_name @@ -74,27 +75,49 @@ class StaticEmbedding(TokenEmbedding): else: raise ValueError(f"Cannot recognize {model_dir_or_name}.") + # 根据min_freq缩小vocab + truncate_vocab = (vocab.min_freq is None and min_freq>1) or (vocab.min_freq and vocab.min_freq=min_freq and word_count0: - if vocab.unknown is None: # 创建一个专门的unknown - unknown_idx = len(matrix) - vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous() - else: - unknown_idx = vocab.unknown_idx - words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), - requires_grad=False) - for order, (index, vec) in enumerate(matrix.items()): - if vec is not None: - vectors[order] = vec - words_to_words[index] = order - self.words_to_words = words_to_words + if vocab.unknown is None: # 创建一个专门的unknown + unknown_idx = len(matrix) + vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous() else: - for index, vec in matrix.items(): - if vec is not None: - vectors[index] = vec + unknown_idx = vocab.unknown_idx + self.words_to_words = nn.Parameter(torch.full((len(vocab), ), fill_value=unknown_idx).long(), + requires_grad=False) + + for index, (index_in_vocab, vec) in enumerate(matrix.items()): + if vec is not None: + vectors[index] = vec + self.words_to_words[index_in_vocab] = index return vectors diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py index cd0d3527..5234b209 100644 --- a/fastNLP/io/__init__.py +++ b/fastNLP/io/__init__.py @@ -3,9 +3,9 @@ 1. 用于读入 embedding 的 :doc:`EmbedLoader ` 类, -2. 用于读入不同格式数据的 :doc:`DataSetLoader ` 类 +2. 用于读入不同格式数据的 :doc:`Loader ` 类 -3. 用于读入不同数据集并进行预处理的 :doc:`DataLoader ` 类 +3. 用于处理读入数据的 :doc:`Pipe ` 类 4. 用于保存和载入模型的类, 参考 :doc:`model_io文档` @@ -14,27 +14,56 @@ __all__ = [ 'EmbedLoader', - 'CSVLoader', - 'JsonLoader', - 'DataBundle', 'DataSetLoader', - 'ConllLoader', - 'Conll2003Loader', + 'YelpLoader', + 'YelpFullLoader', + 'YelpPolarityLoader', 'IMDBLoader', - 'MatchingLoader', - 'SNLILoader', - 'MNLILoader', - 'MTL16Loader', - 'PeopleDailyCorpusLoader', - 'QNLILoader', - 'QuoraLoader', - 'RTELoader', 'SSTLoader', 'SST2Loader', - 'YelpLoader', - + + 'ConllLoader', + 'Conll2003Loader', + 'Conll2003NERLoader', + 'OntoNotesNERLoader', + 'CTBLoader', + + 'Loader', + 'CSVLoader', + 'JsonLoader', + + 'CWSLoader', + + 'MNLILoader', + "QuoraLoader", + "SNLILoader", + "QNLILoader", + "RTELoader", + + "YelpFullPipe", + "YelpPolarityPipe", + "SSTPipe", + "SST2Pipe", + "IMDBPipe", + + "Conll2003NERPipe", + "OntoNotesNERPipe", + + "MatchingBertPipe", + "RTEBertPipe", + "SNLIBertPipe", + "QuoraBertPipe", + "QNLIBertPipe", + "MNLIBertPipe", + "MatchingPipe", + "RTEPipe", + "SNLIPipe", + "QuoraPipe", + "QNLIPipe", + "MNLIPipe", + 'ModelLoader', 'ModelSaver', ] @@ -44,4 +73,5 @@ from .base_loader import DataBundle, DataSetLoader from .dataset_loader import CSVLoader, JsonLoader from .model_io import ModelLoader, ModelSaver -from .data_loader import * +from .loader import * +from .pipe import * diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py index 5d61c16a..5cbd5bb1 100644 --- a/fastNLP/io/base_loader.py +++ b/fastNLP/io/base_loader.py @@ -5,10 +5,10 @@ __all__ = [ ] import _pickle as pickle -import os from typing import Union, Dict import os from ..core.dataset import DataSet +from ..core.vocabulary import Vocabulary class BaseLoader(object): @@ -111,7 +111,10 @@ def _uncompress(src, dst): class DataBundle: """ - 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。 + 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种 + DataSetLoader的load函数生成,可以通过以下的方法获取里面的内容 + + Example:: :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict @@ -121,6 +124,88 @@ class DataBundle: self.vocabs = vocabs or {} self.datasets = datasets or {} + def set_vocab(self, vocab, field_name): + """ + 向DataBunlde中增加vocab + + :param ~fastNLP.Vocabulary vocab: 词表 + :param str field_name: 这个vocab对应的field名称 + :return: + """ + assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary supports." + self.vocabs[field_name] = vocab + + def set_dataset(self, dataset, name): + """ + + :param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet + :param str name: dataset的名称 + :return: + """ + self.datasets[name] = dataset + + def get_dataset(self, name:str)->DataSet: + """ + 获取名为name的dataset + + :param str name: dataset的名称,一般为'train', 'dev', 'test' + :return: DataSet + """ + return self.datasets[name] + + def get_vocab(self, field_name:str)->Vocabulary: + """ + 获取field名为field_name对应的vocab + + :param str field_name: 名称 + :return: Vocabulary + """ + return self.vocabs[field_name] + + def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_field=True): + """ + 将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: + + data_bundle.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True + data_bundle.set_input('words', flag=False) # 将words这个field的input属性设置为False + + :param str field_names: field的名称 + :param bool flag: 将field_name的input状态设置为flag + :param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一 + 行的数据进行类型和维度推断本列的数据的类型和维度。 + :param bool ignore_miss_field: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错 + """ + for field_name in field_names: + for name, dataset in self.datasets.items(): + if not ignore_miss_field and not dataset.has_field(field_name): + raise KeyError(f"Field:{field_name} was not found in DataSet:{name}") + if not dataset.has_field(field_name): + continue + else: + dataset.set_input(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) + + def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_field=True): + """ + 将field_names中的field设置为target, 对data_bundle中所有的dataset执行该操作:: + + data_bundle.set_target('target', 'seq_len') # 将words和target这两个field的input属性设置为True + data_bundle.set_target('target', flag=False) # 将target这个field的input属性设置为False + + :param str field_names: field的名称 + :param bool flag: 将field_name的target状态设置为flag + :param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一 + 行的数据进行类型和维度推断本列的数据的类型和维度。 + :param bool ignore_miss_field: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错 + """ + for field_name in field_names: + for name, dataset in self.datasets.items(): + if not ignore_miss_field and not dataset.has_field(field_name): + raise KeyError(f"Field:{field_name} was not found in DataSet:{name}") + if not dataset.has_field(field_name): + continue + else: + dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) + def __repr__(self): _str = 'In total {} datasets:\n'.format(len(self.datasets)) for name, dataset in self.datasets.items(): diff --git a/fastNLP/io/config_io.py b/fastNLP/io/config_io.py index 4acdbb96..ac349080 100644 --- a/fastNLP/io/config_io.py +++ b/fastNLP/io/config_io.py @@ -1,7 +1,9 @@ """ 用于读入和处理和保存 config 文件 - .. todo:: + +.. todo:: 这个模块中的类可能被抛弃? + """ __all__ = [ "ConfigLoader", diff --git a/fastNLP/io/data_loader/__init__.py b/fastNLP/io/data_loader/__init__.py index 5d6b08b0..b3ca9021 100644 --- a/fastNLP/io/data_loader/__init__.py +++ b/fastNLP/io/data_loader/__init__.py @@ -1,4 +1,8 @@ """ +.. warning:: + + 本模块在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。 + 用于读数据集的模块, 可以读取文本分类、序列标注、Matching任务的数据集 这些模块的具体介绍如下,您可以通过阅读 :doc:`教程` 来进行了解。 diff --git a/fastNLP/io/data_loader/conll.py b/fastNLP/io/data_loader/conll.py index 9b2402a2..0285173c 100644 --- a/fastNLP/io/data_loader/conll.py +++ b/fastNLP/io/data_loader/conll.py @@ -3,38 +3,47 @@ from ...core.dataset import DataSet from ...core.instance import Instance from ..base_loader import DataSetLoader from ..file_reader import _read_conll - +from typing import Union, Dict +from ..utils import check_loader_paths +from ..base_loader import DataBundle class ConllLoader(DataSetLoader): """ 别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader` - 读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 数据中以"-DOCSTART-"开头的行将被忽略,因为 - 该符号在conll 2003中被用为文档分割符。 - - 列号从0开始, 每列对应内容为:: - - Column Type - 0 Document ID - 1 Part number - 2 Word number - 3 Word itself - 4 Part-of-Speech - 5 Parse bit - 6 Predicate lemma - 7 Predicate Frameset ID - 8 Word sense - 9 Speaker/Author - 10 Named Entities - 11:N Predicate Arguments - N Coreference - - :param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 - :param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` - :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` + 该ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: + + Example:: + + # 文件中的内容 + Nadim NNP B-NP B-PER + Ladki NNP I-NP I-PER + + AL-AIN NNP B-NP B-LOC + United NNP B-NP B-LOC + Arab NNP I-NP I-LOC + Emirates NNPS I-NP I-LOC + 1996-12-06 CD I-NP O + ... + + # 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列 + dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll') + # 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列 + dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll') + # 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field + dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll') + + dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll')中DataSet的raw_words + 列与pos列的内容都是List[str] + + 数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 + + :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 + :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` + :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` """ - def __init__(self, headers, indexes=None, dropna=False): + def __init__(self, headers, indexes=None, dropna=True): super(ConllLoader, self).__init__() if not isinstance(headers, (list, tuple)): raise TypeError( @@ -49,25 +58,74 @@ class ConllLoader(DataSetLoader): self.indexes = indexes def _load(self, path): + """ + 传入的一个文件路径,将该文件读入DataSet中,field由Loader初始化时指定的headers决定。 + + :param str path: 文件的路径 + :return: DataSet + """ ds = DataSet() for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): ins = {h: data[i] for i, h in enumerate(self.headers)} ds.append(Instance(**ins)) return ds + def load(self, paths: Union[str, Dict[str, str]]) -> DataBundle: + """ + 从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 + + 读取的field根据ConllLoader初始化时传入的headers决定。 + + :param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式 + (1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 + 名包含'train'、 'dev'、 'test'则会报错 + + Example:: + data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train, dev, test等有所变化 + # 可以通过以下的方式取出DataSet + tr_data = data_bundle.datasets['train'] + te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 + + (2) 传入文件path + + Example:: + data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' + tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet + + (3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test + + Example:: + paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} + data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" + dev_data = data_bundle.datasets['dev'] + + :return: :class:`~fastNLP.DataSet` 类的对象或 :class:`~fastNLP.io.DataBundle` 的字典 + """ + paths = check_loader_paths(paths) + datasets = {name: self._load(path) for name, path in paths.items()} + data_bundle = DataBundle(datasets=datasets) + return data_bundle + class Conll2003Loader(ConllLoader): """ 别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.data_loader.Conll2003Loader` - 读取Conll2003数据 + 该Loader用以读取Conll2003数据,conll2003的数据可以在https://github.com/davidsbatista/NER-datasets/tree/master/CONLL2003 + 找到。数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 + + 返回的DataSet将具有以下['raw_words', 'pos', 'chunks', 'ner']四个field, 每个field中的内容都是List[str]。 + + .. csv-table:: Conll2003Loader处理之 :header: "raw_words", "words", "target", "seq_len" + + "[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2 + "[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 5 + "[...]", "[...]", "[...]", . - 关于数据集的更多信息,参考: - https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data """ def __init__(self): headers = [ - 'tokens', 'pos', 'chunks', 'ner', + 'raw_words', 'pos', 'chunks', 'ner', ] super(Conll2003Loader, self).__init__(headers=headers) diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py index 481b5056..1242b432 100644 --- a/fastNLP/io/data_loader/matching.py +++ b/fastNLP/io/data_loader/matching.py @@ -121,7 +121,7 @@ class MatchingLoader(DataSetLoader): PRETRAIN_URL = _get_base_url('bert') model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] model_url = PRETRAIN_URL + model_name - model_dir = cached_path(model_url) + model_dir = cached_path(model_url, name='embedding') # 检查是否存在 elif os.path.isdir(bert_tokenizer): model_dir = bert_tokenizer diff --git a/fastNLP/io/data_loader/mtl.py b/fastNLP/io/data_loader/mtl.py index cbca413d..20824958 100644 --- a/fastNLP/io/data_loader/mtl.py +++ b/fastNLP/io/data_loader/mtl.py @@ -5,7 +5,7 @@ from ..base_loader import DataBundle from ..dataset_loader import CSVLoader from ...core.vocabulary import Vocabulary, VocabularyOption from ...core.const import Const -from ..utils import check_dataloader_paths +from ..utils import check_loader_paths class MTL16Loader(CSVLoader): @@ -38,7 +38,7 @@ class MTL16Loader(CSVLoader): src_vocab_opt: VocabularyOption = None, tgt_vocab_opt: VocabularyOption = None,): - paths = check_dataloader_paths(paths) + paths = check_loader_paths(paths) datasets = {} info = DataBundle() for name, path in paths.items(): diff --git a/fastNLP/io/data_loader/sst.py b/fastNLP/io/data_loader/sst.py index 0d881e65..c2e0eca1 100644 --- a/fastNLP/io/data_loader/sst.py +++ b/fastNLP/io/data_loader/sst.py @@ -8,7 +8,7 @@ from ...core.vocabulary import VocabularyOption, Vocabulary from ...core.dataset import DataSet from ...core.const import Const from ...core.instance import Instance -from ..utils import check_dataloader_paths, get_tokenizer +from ..utils import check_loader_paths, get_tokenizer class SSTLoader(DataSetLoader): @@ -67,7 +67,7 @@ class SSTLoader(DataSetLoader): paths, train_subtree=True, src_vocab_op: VocabularyOption = None, tgt_vocab_op: VocabularyOption = None,): - paths = check_dataloader_paths(paths) + paths = check_loader_paths(paths) input_name, target_name = 'words', 'target' src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) tgt_vocab = Vocabulary(unknown=None, padding=None) \ @@ -129,11 +129,12 @@ class SST2Loader(CSVLoader): tgt_vocab_opt: VocabularyOption = None, char_level_op=False): - paths = check_dataloader_paths(paths) + paths = check_loader_paths(paths) datasets = {} info = DataBundle() for name, path in paths.items(): dataset = self.load(path) + dataset.apply_field(lambda words:words.copy(), field_name='words', new_field_name='raw_words') datasets[name] = dataset def wordtochar(words): @@ -154,7 +155,9 @@ class SST2Loader(CSVLoader): for dataset in datasets.values(): dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) - src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT) + src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT, no_create_entry_dataset=[ + dataset for name, dataset in datasets.items() if name!='train' + ]) src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) tgt_vocab = Vocabulary(unknown=None, padding=None) \ diff --git a/fastNLP/io/data_loader/yelp.py b/fastNLP/io/data_loader/yelp.py index 333fcab0..15533b04 100644 --- a/fastNLP/io/data_loader/yelp.py +++ b/fastNLP/io/data_loader/yelp.py @@ -8,7 +8,7 @@ from ...core.instance import Instance from ...core.vocabulary import VocabularyOption, Vocabulary from ..base_loader import DataBundle, DataSetLoader from typing import Union, Dict -from ..utils import check_dataloader_paths, get_tokenizer +from ..utils import check_loader_paths, get_tokenizer class YelpLoader(DataSetLoader): @@ -62,7 +62,7 @@ class YelpLoader(DataSetLoader): src_vocab_op: VocabularyOption = None, tgt_vocab_op: VocabularyOption = None, char_level_op=False): - paths = check_dataloader_paths(paths) + paths = check_loader_paths(paths) info = DataBundle(datasets=self.load(paths)) src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) tgt_vocab = Vocabulary(unknown=None, padding=None) \ diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index ad6bbdc1..e1e06ec9 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -1,4 +1,8 @@ """ +.. warning:: + + 本模块将在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。 + dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的数据, 并返回 `DataSet` , 得到的 :class:`~fastNLP.DataSet` 对象可以直接传入 :class:`~fastNLP.Trainer` 和 :class:`~fastNLP.Tester`, 用于模型的训练和测试。 以SNLI数据集为例:: @@ -11,6 +15,7 @@ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的 # ... do stuff 为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。 + """ __all__ = [ 'CSVLoader', @@ -114,25 +119,3 @@ def _cut_long_sentence(sent, max_sample_length=200): else: cutted_sentence.append(sent) return cutted_sentence - - -def _add_seg_tag(data): - """ - - :param data: list of ([word], [pos], [heads], [head_tags]) - :return: list of ([word], [pos]) - """ - - _processed = [] - for word_list, pos_list, _, _ in data: - new_sample = [] - for word, pos in zip(word_list, pos_list): - if len(word) == 1: - new_sample.append((word, 'S-' + pos)) - else: - new_sample.append((word[0], 'B-' + pos)) - for c in word[1:-1]: - new_sample.append((c, 'M-' + pos)) - new_sample.append((word[-1], 'E-' + pos)) - _processed.append(list(map(list, zip(*new_sample)))) - return _processed diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py index 0ae0a319..6aa89b80 100644 --- a/fastNLP/io/file_reader.py +++ b/fastNLP/io/file_reader.py @@ -2,7 +2,7 @@ 此模块用于给其它模块提供读取文件的函数,没有为用户提供 API """ import json - +import warnings def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): """ @@ -91,7 +91,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): with open(path, 'r', encoding=encoding) as f: sample = [] start = next(f).strip() - if '-DOCSTART-' not in start and start!='': + if start!='': sample.append(start.split()) for line_idx, line in enumerate(f, 1): line = line.strip() @@ -103,13 +103,13 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): yield line_idx, res except Exception as e: if dropna: + warnings.warn('Invalid instance ends at line: {} has been dropped.'.format(line_idx)) continue - raise ValueError('invalid instance ends at line: {}'.format(line_idx)) + raise ValueError('Invalid instance ends at line: {}'.format(line_idx)) elif line.startswith('#'): continue else: - if not line.startswith('-DOCSTART-'): - sample.append(line.split()) + sample.append(line.split()) if len(sample) > 0: try: res = parse_conll(sample) diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index cb762eb7..9febfe4a 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -1,4 +1,3 @@ - import os from pathlib import Path from urllib.parse import urlparse @@ -7,65 +6,124 @@ import requests import tempfile from tqdm import tqdm import shutil -import hashlib - +from requests import HTTPError PRETRAINED_BERT_MODEL_DIR = { - 'en': 'bert-base-cased-f89bfe08.zip', - 'en-base-uncased': 'bert-base-uncased-3413b23c.zip', - 'en-base-cased': 'bert-base-cased-f89bfe08.zip', - 'en-large-uncased': 'bert-large-uncased-20939f45.zip', - 'en-large-cased': 'bert-large-cased-e0cf90fc.zip', - - 'en-large-cased-wwm': 'bert-large-cased-wwm-a457f118.zip', - 'en-large-uncased-wwm': 'bert-large-uncased-wwm-92a50aeb.zip', - 'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc-c7099855.zip', - - 'cn': 'bert-base-chinese-29d0a84a.zip', - 'cn-base': 'bert-base-chinese-29d0a84a.zip', - - 'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', - 'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', - 'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', + 'en': 'bert-base-cased.zip', + 'en-large-cased-wwm': 'bert-large-cased-wwm.zip', + 'en-large-uncased-wwm': 'bert-large-uncased-wwm.zip', + + 'en-large-uncased': 'bert-large-uncased.zip', + 'en-large-cased': 'bert-large-cased.zip', + + 'en-base-uncased': 'bert-base-uncased.zip', + 'en-base-cased': 'bert-base-cased.zip', + + 'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip', + + 'multi-base-cased': 'bert-base-multilingual-cased.zip', + 'multi-base-uncased': 'bert-base-multilingual-uncased.zip', + + 'cn': 'bert-chinese-wwm.zip', + 'cn-base': 'bert-base-chinese.zip', + 'cn-wwm': 'bert-chinese-wwm.zip', } PRETRAINED_ELMO_MODEL_DIR = { - 'en': 'elmo_en-d39843fe.tar.gz', - 'cn': 'elmo_cn-5e9b34e2.tar.gz' + 'en': 'elmo_en_Medium.zip', + 'en-small': "elmo_en_Small.zip", + 'en-original-5.5b': 'elmo_en_Original_5.5B.zip', + 'en-original': 'elmo_en_Original.zip', + 'en-medium': 'elmo_en_Medium.zip' } PRETRAIN_STATIC_FILES = { - 'en': 'glove.840B.300d-cc1ad5e1.tar.gz', - 'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', - 'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz", - 'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz", - 'en-fasttext': "cc.en.300.vec-d53187b2.gz", - 'cn': "tencent_cn-dab24577.tar.gz", - 'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz", + 'en': 'glove.840B.300d.zip', + + 'en-glove-6b-50d': 'glove.6B.50d.zip', + 'en-glove-6b-100d': 'glove.6B.100d.zip', + 'en-glove-6b-200d': 'glove.6B.200d.zip', + 'en-glove-6b-300d': 'glove.6B.300d.zip', + 'en-glove-42b-300d': 'glove.42B.300d.zip', + 'en-glove-840b-300d': 'glove.840B.300d.zip', + 'en-glove-twitter-27b-25d': 'glove.twitter.27B.25d.zip', + 'en-glove-twitter-27b-50d': 'glove.twitter.27B.50d.zip', + 'en-glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip', + 'en-glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip', + + 'en-word2vec-300': "GoogleNews-vectors-negative300.txt.gz", + + 'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip", + 'en-fasttext-crawl': "crawl-300d-2M.vec.zip", + + 'cn': "tencent_cn.txt.zip", + 'cn-tencent': "tencent_cn.txt.zip", + 'cn-fasttext': "cc.zh.300.vec.gz", + 'cn-sgns-literature-word': 'sgns.literature.word.txt.zip', } +DATASET_DIR = { + 'aclImdb': "imdb.zip", + "yelp-review-full": "yelp_review_full.tar.gz", + "yelp-review-polarity": "yelp_review_polarity.tar.gz", + "mnli": "MNLI.zip", + "snli": "SNLI.zip", + "qnli": "QNLI.zip", + "sst-2": "SST-2.zip", + "sst": "SST.zip", + "rte": "RTE.zip" +} + +PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, + "bert": PRETRAINED_BERT_MODEL_DIR, + "static": PRETRAIN_STATIC_FILES} + +# 用于扩展fastNLP的下载 +FASTNLP_EXTEND_DATASET_URL = 'fastnlp_dataset_url.txt' +FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt', + 'bert':'fastnlp_bert_url.txt', + 'static': 'fastnlp_static_url.txt' +} -def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: + +def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path: """ - 给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 - 将文件放入到cache_dir中 + 给定一个url,尝试通过url中的解析出来的文件名字filename到{cache_dir}/{name}/{filename}下寻找这个文件, + + 1. 如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir + 2. 如果name=None, 则没有中间的{name}这一层结构;否者中间结构就为{name} + + 如果有该文件,就直接返回路径 + + 如果没有该文件,则尝试用传入的url下载 + + 或者文件名(可以是具体的文件名,也可以是文件夹),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 + 将文件放入到cache_dir中. + + :param str url_or_filename: 文件的下载url或者文件名称。 + :param str cache_dir: 文件的缓存文件夹。如果为None,将使用"~/.fastNLP"这个默认路径 + :param str name: 中间一层的名称。如embedding, dataset + :return: """ if cache_dir is None: - dataset_cache = Path(get_defalt_path()) + data_cache = Path(get_cache_path()) else: - dataset_cache = cache_dir + data_cache = cache_dir + + if name: + data_cache = os.path.join(data_cache, name) parsed = urlparse(url_or_filename) if parsed.scheme in ("http", "https"): # URL, so get it from the cache (downloading if necessary) - return get_from_cache(url_or_filename, dataset_cache) - elif parsed.scheme == "" and Path(os.path.join(dataset_cache, url_or_filename)).exists(): + return get_from_cache(url_or_filename, Path(data_cache)) + elif parsed.scheme == "" and Path(os.path.join(data_cache, url_or_filename)).exists(): # File, and it exists. - return Path(url_or_filename) + return Path(os.path.join(data_cache, url_or_filename)) elif parsed.scheme == "": # File, but it doesn't exist. - raise FileNotFoundError("file {} not found".format(url_or_filename)) + raise FileNotFoundError("file {} not found in {}.".format(url_or_filename, data_cache)) else: # Something unknown raise ValueError( @@ -75,48 +133,143 @@ def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: def get_filepath(filepath): """ - 如果filepath中只有一个文件,则直接返回对应的全路径 - :param filepath: + 如果filepath为文件夹, + + 如果内含多个文件, 返回filepath + + 如果只有一个文件, 返回filepath + filename + + 如果filepath为文件 + + 返回filepath + + :param str filepath: 路径 :return: """ if os.path.isdir(filepath): files = os.listdir(filepath) - if len(files)==1: + if len(files) == 1: return os.path.join(filepath, files[0]) else: return filepath - return filepath + elif os.path.isfile(filepath): + return filepath + else: + raise FileNotFoundError(f"{filepath} is not a valid file or directory.") -def get_defalt_path(): +def get_cache_path(): """ - 获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 + 获取fastNLP默认cache的存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 - :return: + :return str: 存放路径 """ if 'FASTNLP_CACHE_DIR' in os.environ: fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR') - if os.path.exists(fastnlp_cache_dir): + if os.path.isdir(fastnlp_cache_dir): return fastnlp_cache_dir - raise RuntimeError("Some errors happens on cache directory.") - else: - raise RuntimeError("There function is not available right now.") + else: + raise NotADirectoryError(f"{os.environ['FASTNLP_CACHE_DIR']} is not a directory.") fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP")) return fastnlp_cache_dir def _get_base_url(name): + """ + 根据name返回下载的url地址。 + + :param str name: 支持dataset和embedding两种 + :return: + """ # 返回的URL结尾必须是/ - if 'FASTNLP_BASE_URL' in os.environ: - fastnlp_base_url = os.environ['FASTNLP_BASE_URL'] - return fastnlp_base_url - raise RuntimeError("There function is not available right now.") + environ_name = "FASTNLP_{}_URL".format(name.upper()) + + if environ_name in os.environ: + url = os.environ[environ_name] + if url.endswith('/'): + return url + else: + return url + '/' + else: + URLS = { + 'embedding': "http://dbcloud.irocn.cn:8989/api/public/dl/", + "dataset": "http://dbcloud.irocn.cn:8989/api/public/dl/dataset/" + } + if name.lower() not in URLS: + raise KeyError(f"{name} is not recognized.") + return URLS[name.lower()] + + +def _get_embedding_url(embed_type, name): + """ + 给定embedding类似和名称,返回下载url + + :param str embed_type: 支持static, bert, elmo。即embedding的类型 + :param str name: embedding的名称, 例如en, cn, based等 + :return: str, 下载的url地址 + """ + # 从扩展中寻找下载的url + _filename = FASTNLP_EXTEND_EMBEDDING_URL.get(embed_type, None) + if _filename: + url = _read_extend_url_file(_filename, name) + if url: + return url + embed_map = PRETRAIN_MAP.get(embed_type, None) + if embed_map: + filename = embed_map.get(name, None) + if filename: + url = _get_base_url('embedding') + filename + return url + raise KeyError("There is no {}. Only supports {}.".format(name, list(embed_map.keys()))) + else: + raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static") + +def _read_extend_url_file(filename, name)->str: + """ + filename中的内容使用制表符隔开,第一列是名称,第二列是下载的url地址 + + :param str filename: 在默认的路径下寻找file这个文件 + :param str name: 需要寻找的资源的名称 + :return: str or None + """ + cache_dir = get_cache_path() + filepath = os.path.join(cache_dir, filename) + if os.path.exists(filepath): + with open(filepath, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + if len(parts) == 2: + if name == parts[0]: + return parts[1] + return None + +def _get_dataset_url(name): + """ + 给定dataset的名称,返回下载url + + :param str name: 给定dataset的名称,比如imdb, sst-2等 + :return: str + """ + # 从扩展中寻找下载的url + url = _read_extend_url_file(FASTNLP_EXTEND_DATASET_URL, name) + if url: + return url + + filename = DATASET_DIR.get(name, None) + if filename: + url = _get_base_url('dataset') + filename + return url + else: + raise KeyError(f"There is no {name}.") def split_filename_suffix(filepath): """ - 给定filepath返回对应的name和suffix - :param filepath: + 给定filepath 返回对应的name和suffix. 如果后缀是多个点,仅支持.tar.gz类型 + + :param filepath: 文件路径 :return: filename, suffix """ filename = os.path.basename(filepath) @@ -127,21 +280,19 @@ def split_filename_suffix(filepath): def get_from_cache(url: str, cache_dir: Path = None) -> Path: """ - 尝试在cache_dir中寻找url定义的资源; 如果没有找到。则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。 - 如果从url中下载的资源解压后有多个文件,则返回directory的路径; 如果只有一个资源,则返回具体的路径。 - + 尝试在cache_dir中寻找url定义的资源; 如果没有找到; 则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的 + 文件解压,将解压后的文件全部放在cache_dir文件夹中。 + + 如果从url中下载的资源解压后有多个文件,则返回目录的路径; 如果只有一个资源文件,则返回具体的路径。 + + :param url: 资源的 url + :param cache_dir: cache 目录 + :return: 路径 """ cache_dir.mkdir(parents=True, exist_ok=True) filename = re.sub(r".+/", "", url) dir_name, suffix = split_filename_suffix(filename) - sep_index = dir_name[::-1].index('-') - if sep_index<0: - check_sum = None - else: - check_sum = dir_name[-sep_index+1:] - sep_index = len(dir_name) if sep_index==-1 else -sep_index-1 - dir_name = dir_name[:sep_index] # 寻找与它名字匹配的内容, 而不关心后缀 match_dir_name = match_file(dir_name, cache_dir) @@ -154,11 +305,11 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: return get_filepath(cache_path) # make HEAD request to check ETag TODO ETag可以用来判断资源是否已经更新了,之后需要加上 - response = requests.head(url, headers={"User-Agent": "fastNLP"}) - if response.status_code != 200: - raise IOError( - f"HEAD request failed for url {url} with status code {response.status_code}." - ) + # response = requests.head(url, headers={"User-Agent": "fastNLP"}) + # if response.status_code != 200: + # raise IOError( + # f"HEAD request failed for url {url} with status code {response.status_code}." + # ) # add ETag to filename if it exists # etag = response.headers.get("ETag") @@ -166,74 +317,73 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: if not cache_path.exists(): # Download to temporary file, then copy to cache dir once finished. # Otherwise you get corrupt cache entries if the download gets interrupted. - fd, temp_filename = tempfile.mkstemp() - print("%s not found in cache, downloading to %s"%(url, temp_filename)) - # GET file object req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"}) - content_length = req.headers.get("Content-Length") - total = int(content_length) if content_length is not None else None - progress = tqdm(unit="B", total=total) - sha256 = hashlib.sha256() - with open(temp_filename, "wb") as temp_file: - for chunk in req.iter_content(chunk_size=1024): - if chunk: # filter out keep-alive new chunks - progress.update(len(chunk)) - temp_file.write(chunk) - sha256.update(chunk) - # check sum - digit = sha256.hexdigest()[:8] - if not check_sum: - assert digit == check_sum, "File corrupted when download." - progress.close() - print(f"Finish download from {url}.") - - # 开始解压 - delete_temp_dir = None - if suffix in ('.zip', '.tar.gz'): - uncompress_temp_dir = tempfile.mkdtemp() - delete_temp_dir = uncompress_temp_dir - print(f"Start to uncompress file to {uncompress_temp_dir}.") - if suffix == '.zip': - unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) + if req.status_code == 200: + content_length = req.headers.get("Content-Length") + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total, unit_scale=1) + fd, temp_filename = tempfile.mkstemp() + print("%s not found in cache, downloading to %s" % (url, temp_filename)) + + with open(temp_filename, "wb") as temp_file: + for chunk in req.iter_content(chunk_size=1024 * 16): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + print(f"Finish download from {url}.") + + # 开始解压 + delete_temp_dir = None + if suffix in ('.zip', '.tar.gz'): + uncompress_temp_dir = tempfile.mkdtemp() + delete_temp_dir = uncompress_temp_dir + print(f"Start to uncompress file to {uncompress_temp_dir}") + if suffix == '.zip': + unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) + else: + untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) + filenames = os.listdir(uncompress_temp_dir) + if len(filenames) == 1: + if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): + uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) + + cache_path.mkdir(parents=True, exist_ok=True) + print("Finish un-compressing file.") else: - untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) - filenames = os.listdir(uncompress_temp_dir) - if len(filenames)==1: - if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): - uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) - - cache_path.mkdir(parents=True, exist_ok=True) - print("Finish un-compressing file.") + uncompress_temp_dir = temp_filename + cache_path = str(cache_path) + suffix + success = False + try: + # 复制到指定的位置 + print(f"Copy file to {cache_path}") + if os.path.isdir(uncompress_temp_dir): + for filename in os.listdir(uncompress_temp_dir): + if os.path.isdir(os.path.join(uncompress_temp_dir, filename)): + shutil.copytree(os.path.join(uncompress_temp_dir, filename), cache_path / filename) + else: + shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path / filename) + else: + shutil.copyfile(uncompress_temp_dir, cache_path) + success = True + except Exception as e: + print(e) + raise e + finally: + if not success: + if cache_path.exists(): + if cache_path.is_file(): + os.remove(cache_path) + else: + shutil.rmtree(cache_path) + if delete_temp_dir: + shutil.rmtree(delete_temp_dir) + os.close(fd) + os.remove(temp_filename) + return get_filepath(cache_path) else: - uncompress_temp_dir = temp_filename - cache_path = str(cache_path) + suffix - success = False - try: - # 复制到指定的位置 - print(f"Copy file to {cache_path}.") - if os.path.isdir(uncompress_temp_dir): - for filename in os.listdir(uncompress_temp_dir): - shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename) - else: - shutil.copyfile(uncompress_temp_dir, cache_path) - success = True - except Exception as e: - print(e) - raise e - finally: - if not success: - if cache_path.exists(): - if cache_path.is_file(): - os.remove(cache_path) - else: - shutil.rmtree(cache_path) - if delete_temp_dir: - shutil.rmtree(delete_temp_dir) - os.close(fd) - os.remove(temp_filename) - - return get_filepath(cache_path) + raise HTTPError(f"Status code:{req.status_code}. Fail to download from {url}.") def unzip_file(file: Path, to: Path): @@ -245,55 +395,30 @@ def unzip_file(file: Path, to: Path): zipObj.extractall(to) -def untar_gz_file(file:Path, to:Path): +def untar_gz_file(file: Path, to: Path): import tarfile with tarfile.open(file, 'r:gz') as tar: tar.extractall(to) -def match_file(dir_name: str, cache_dir: str) -> str: +def match_file(dir_name: str, cache_dir: Path) -> str: """ - 匹配的原则是,在cache_dir下的文件: (1) 与dir_name完全一致; (2) 除了后缀以外和dir_name完全一致。 + 匹配的原则是: 在cache_dir下的文件与dir_name完全一致, 或除了后缀以外和dir_name完全一致。 如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串 :param dir_name: 需要匹配的名称 :param cache_dir: 在该目录下找匹配dir_name是否存在 - :return: str + :return str: 做为匹配结果的字符串 """ files = os.listdir(cache_dir) matched_filenames = [] for file_name in files: - if re.match(dir_name+'$', file_name) or re.match(dir_name+'\\..*', file_name): + if re.match(dir_name + '$', file_name) or re.match(dir_name + '\\..*', file_name): matched_filenames.append(file_name) - if len(matched_filenames)==0: + if len(matched_filenames) == 0: return '' - elif len(matched_filenames)==1: + elif len(matched_filenames) == 1: return matched_filenames[-1] else: raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.") - - -if __name__ == '__main__': - cache_dir = Path('caches') - cache_dir = None - # 需要对cache_dir进行测试 - base_url = 'http://0.0.0.0:8888/file/download' - # if True: - # for filename in os.listdir(cache_dir): - # if os.path.isdir(os.path.join(cache_dir, filename)): - # shutil.rmtree(os.path.join(cache_dir, filename)) - # else: - # os.remove(os.path.join(cache_dir, filename)) - # 1. 测试.txt文件 - print(cached_path(base_url + '/{}'.format('txt_test-bcb4fe65.txt'), cache_dir)) - # 2. 测试.zip文件(只有一个文件) - print(cached_path(base_url + '/{}'.format('zip_test-40966d39.zip'), cache_dir)) - # 3. 测试.zip文件(有多个文件) - print(cached_path(base_url + '/{}'.format('zip_pack_test-70c0b20d.zip'), cache_dir)) - # 4. 测试.tar.gz文件 - print(cached_path(base_url + '/{}'.format('tar_gz_test-3e2679cf.tar.gz'), cache_dir)) - # 5. 测试.tar.gz多个文件 - print(cached_path(base_url + '/{}'.format('tar_gz_pack_test-08dfdccd.tar.gz'), cache_dir)) - - # 6. 测试.pkl文件 diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py new file mode 100644 index 00000000..a4e6a6f5 --- /dev/null +++ b/fastNLP/io/loader/__init__.py @@ -0,0 +1,78 @@ +""" +Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的 +三个方法: ``__init__`` , ``_load`` , ``loads`` . 其中 ``__init__(...)`` 用于申明读取参数,以及说明该Loader支持的数据格式, +读取后 :class:`~fastNLP.DataSet` 中的 `field` ; ``_load(path)`` 方法传入文件路径读取单个文件,并返回 :class:`~fastNLP.DataSet` ; +``load(paths)`` 用于读取文件夹下的文件,并返回 :class:`~fastNLP.io.DataBundle` 类型的对象 , load()方法支持以下几种类型的参数: + +0.传入None + 将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。 + +1.传入一个文件的 path + 返回的 `data_bundle` 包含一个名为 `train` 的 dataset ,可以通过 ``data_bundle.datasets['train']`` 获取 + +2.传入一个文件夹目录 + 将读取的是这个文件夹下文件名中包含 `train` , `test` , `dev` 的文件,其它文件会被忽略。假设某个目录下的文件为:: + + | + +-train.txt + +-dev.txt + +-test.txt + +-other.txt + + 在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.datasets['train']`` , ``data_bundle.datasets['dev']`` , + ``data_bundle.datasets['test']`` 获取对应的 `dataset` ,其中 `other.txt` 的内容会被忽略。假设某个目录下的文件为:: + + | + +-train.txt + +-dev.txt + + 在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.datasets['train']`` , + ``data_bundle.datasets['dev']`` 获取对应的 dataset。 + +3.传入一个字典 + 字典的的 key 为 `dataset` 的名称,value 是该 `dataset` 的文件路径:: + + paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'} + + 在 Loader().load(paths) 返回的 `data_bundle` 中可以用 ``data_bundle.datasets['train']`` , ``data_bundle.datasets['dev']`` , + ``data_bundle.datasets['test']`` 来获取对应的 `dataset` + +fastNLP 目前提供了如下的 Loader + + + +""" + +__all__ = [ + 'YelpLoader', + 'YelpFullLoader', + 'YelpPolarityLoader', + 'IMDBLoader', + 'SSTLoader', + 'SST2Loader', + + 'ConllLoader', + 'Conll2003Loader', + 'Conll2003NERLoader', + 'OntoNotesNERLoader', + 'CTBLoader', + + 'Loader', + 'CSVLoader', + 'JsonLoader', + + 'CWSLoader', + + 'MNLILoader', + "QuoraLoader", + "SNLILoader", + "QNLILoader", + "RTELoader" +] +from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader +from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader +from .csv import CSVLoader +from .cws import CWSLoader +from .json import JsonLoader +from .loader import Loader +from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py new file mode 100644 index 00000000..dd85b4fe --- /dev/null +++ b/fastNLP/io/loader/classification.py @@ -0,0 +1,369 @@ +from ...core.dataset import DataSet +from ...core.instance import Instance +from .loader import Loader +import warnings +import os +import random +import shutil +import numpy as np + +class YelpLoader(Loader): + """ + 别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader` + + 原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 + + Example:: + "1","I got 'new' tires from the..." + "1","Don't waste your time..." + + 读取YelpFull, YelpPolarity的数据。可以通过xxx下载并预处理数据。 + 读取的DataSet将具备以下的数据结构 + + .. csv-table:: + :header: "raw_words", "target" + + "I got 'new' tires from them and... ", "1" + "Don't waste your time. We had two...", "1" + "...", "..." + + """ + + def __init__(self): + super(YelpLoader, self).__init__() + + def _load(self, path: str=None): + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + sep_index = line.index(',') + target = line[:sep_index] + raw_words = line[sep_index + 1:] + if target.startswith("\""): + target = target[1:] + if target.endswith("\""): + target = target[:-1] + if raw_words.endswith("\""): + raw_words = raw_words[:-1] + if raw_words.startswith('"'): + raw_words = raw_words[1:] + raw_words = raw_words.replace('""', '"') # 替换双引号 + if raw_words: + ds.append(Instance(raw_words=raw_words, target=target)) + return ds + + +class YelpFullLoader(YelpLoader): + def download(self, dev_ratio: float = 0.1, seed: int = 0): + """ + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances + in Neural Information Processing Systems 28 (NIPS 2015) + + 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后在output_dir中有train.csv, test.csv, + dev.csv三个文件。 + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 + :param int seed: 划分dev时的随机数种子 + :return: str, 数据集的目录地址 + """ + + dataset_name = 'yelp-review-full' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否需要重新下载 + re_download = True + if dev_ratio>0: + dev_line_count = 0 + tr_line_count = 0 + with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.csv'), 'r', encoding='utf-8') as f2: + for line in f1: + tr_line_count += 1 + for line in f2: + dev_line_count += 1 + if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): + re_download = True + else: + re_download = False + if re_download: + shutil.rmtree(data_dir) + data_dir = self._get_dataset_path(dataset_name=dataset_name) + + if not os.path.exists(os.path.join(data_dir, 'dev.csv')): + if dev_ratio > 0: + assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." + random.seed(int(seed)) + try: + with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \ + open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2: + for line in f: + if random.random() < dev_ratio: + f2.write(line) + else: + f1.write(line) + os.remove(os.path.join(data_dir, 'train.csv')) + os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv')) + finally: + if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): + os.remove(os.path.join(data_dir, 'middle_file.csv')) + + return data_dir + + +class YelpPolarityLoader(YelpLoader): + def download(self, dev_ratio: float = 0.1, seed: int = 0): + """ + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances + in Neural Information Processing Systems 28 (NIPS 2015) + + 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev + + :param float dev_ratio: 如果路径中不存在dev.csv, 从train划分多少作为dev的数据. 如果为0,则不划分dev + :param int seed: 划分dev时的随机数种子 + :return: str, 数据集的目录地址 + """ + dataset_name = 'yelp-review-polarity' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否符合比例要求 + re_download = True + if dev_ratio>0: + dev_line_count = 0 + tr_line_count = 0 + with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.csv'), 'r', encoding='utf-8') as f2: + for line in f1: + tr_line_count += 1 + for line in f2: + dev_line_count += 1 + if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): + re_download = True + else: + re_download = False + if re_download: + shutil.rmtree(data_dir) + data_dir = self._get_dataset_path(dataset_name=dataset_name) + + if not os.path.exists(os.path.join(data_dir, 'dev.csv')): + if dev_ratio > 0: + assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." + random.seed(int(seed)) + try: + with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \ + open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2: + for line in f: + if random.random() < dev_ratio: + f2.write(line) + else: + f1.write(line) + os.remove(os.path.join(data_dir, 'train.csv')) + os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv')) + finally: + if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): + os.remove(os.path.join(data_dir, 'middle_file.csv')) + + return data_dir + + +class IMDBLoader(Loader): + """ + 别名::class:`fastNLP.io.IMDBLoader` :class:`fastNLP.io.loader.IMDBLoader` + + IMDBLoader读取后的数据将具有以下两列内容: raw_words: str, 需要分类的文本; target: str, 文本的标签 + DataSet具备以下的结构: + + .. csv-table:: + :header: "raw_words", "target" + + "Bromwell High is a cartoon ... ", "pos" + "Story of a man who has ...", "neg" + "...", "..." + + """ + + def __init__(self): + super(IMDBLoader, self).__init__() + + def _load(self, path: str): + dataset = DataSet() + with open(path, 'r', encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split('\t') + target = parts[0] + words = parts[1] + if words: + dataset.append(Instance(raw_words=words, target=target)) + + if len(dataset) == 0: + raise RuntimeError(f"{path} has no valid data.") + + return dataset + + def download(self, dev_ratio: float = 0.1, seed: int = 0): + """ + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + http://www.aclweb.org/anthology/P11-1015 + + 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev + + :param float dev_ratio: 如果路径中没有dev.txt。从train划分多少作为dev的数据. 如果为0,则不划分dev + :param int seed: 划分dev时的随机数种子 + :return: str, 数据集的目录地址 + """ + dataset_name = 'aclImdb' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求 + re_download = True + if dev_ratio>0: + dev_line_count = 0 + tr_line_count = 0 + with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.txt'), 'r', encoding='utf-8') as f2: + for line in f1: + tr_line_count += 1 + for line in f2: + dev_line_count += 1 + if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): + re_download = True + else: + re_download = False + if re_download: + shutil.rmtree(data_dir) + data_dir = self._get_dataset_path(dataset_name=dataset_name) + + if not os.path.exists(os.path.join(data_dir, 'dev.csv')): + if dev_ratio > 0: + assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." + random.seed(int(seed)) + try: + with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ + open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: + for line in f: + if random.random() < dev_ratio: + f2.write(line) + else: + f1.write(line) + os.remove(os.path.join(data_dir, 'train.txt')) + os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) + finally: + if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): + os.remove(os.path.join(data_dir, 'middle_file.txt')) + + return data_dir + + +class SSTLoader(Loader): + """ + 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.loader.SSTLoader` + + 读取之后的DataSet具有以下的结构 + + .. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field + :header: "raw_words" + + "(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..." + "(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..." + "..." + + raw_words列是str。 + + """ + + def __init__(self): + super().__init__() + + def _load(self, path: str): + """ + 从path读取SST文件 + + :param str path: 文件路径 + :return: DataSet + """ + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + ds.append(Instance(raw_words=line)) + return ds + + def download(self): + """ + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf + + :return: str, 数据集的目录地址 + """ + output_dir = self._get_dataset_path(dataset_name='sst') + return output_dir + + +class SST2Loader(Loader): + """ + 数据SST2的Loader + 读取之后DataSet将如下所示 + + .. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field + :header: "raw_words", "target" + + "it 's a charming and often affecting...", "1" + "unflinchingly bleak and...", "0" + "..." + + test的DataSet没有target列。 + """ + + def __init__(self): + super().__init__() + + def _load(self, path: str): + """ + 从path读取SST2文件 + + :param str path: 数据路径 + :return: DataSet + """ + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + f.readline() # 跳过header + if 'test' in os.path.split(path)[1]: + warnings.warn("SST2's test file has no target.") + for line in f: + line = line.strip() + if line: + sep_index = line.index('\t') + raw_words = line[sep_index + 1:] + if raw_words: + ds.append(Instance(raw_words=raw_words)) + else: + for line in f: + line = line.strip() + if line: + raw_words = line[:-2] + target = line[-1] + if raw_words: + ds.append(Instance(raw_words=raw_words, target=target)) + return ds + + def download(self): + """ + 自动下载数据集,如果你使用了该数据集,请引用以下的文章 + + https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf + + :return: + """ + output_dir = self._get_dataset_path(dataset_name='sst-2') + return output_dir diff --git a/fastNLP/io/loader/conll.py b/fastNLP/io/loader/conll.py new file mode 100644 index 00000000..b2c89ecc --- /dev/null +++ b/fastNLP/io/loader/conll.py @@ -0,0 +1,264 @@ +from typing import Dict, Union + +from .loader import Loader +from ...core.dataset import DataSet +from ..file_reader import _read_conll +from ...core.instance import Instance +from .. import DataBundle +from ..utils import check_loader_paths +from ...core.const import Const + + +class ConllLoader(Loader): + """ + 别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader` + + ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: + + Example:: + + # 文件中的内容 + Nadim NNP B-NP B-PER + Ladki NNP I-NP I-PER + + AL-AIN NNP B-NP B-LOC + United NNP B-NP B-LOC + Arab NNP I-NP I-LOC + Emirates NNPS I-NP I-LOC + 1996-12-06 CD I-NP O + ... + + # 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列 + dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll') + # 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列 + dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll') + # 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field + dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll') + + ConllLoader返回的DataSet的field由传入的headers确定。 + + 数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 + + :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 + :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` + :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` + + """ + def __init__(self, headers, indexes=None, dropna=True): + super(ConllLoader, self).__init__() + if not isinstance(headers, (list, tuple)): + raise TypeError( + 'invalid headers: {}, should be list of strings'.format(headers)) + self.headers = headers + self.dropna = dropna + if indexes is None: + self.indexes = list(range(len(self.headers))) + else: + if len(indexes) != len(headers): + raise ValueError + self.indexes = indexes + + def _load(self, path): + """ + 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 + + :param str path: 文件的路径 + :return: DataSet + """ + ds = DataSet() + for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): + ins = {h: data[i] for i, h in enumerate(self.headers)} + ds.append(Instance(**ins)) + return ds + + +class Conll2003Loader(ConllLoader): + """ + 用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。 + + Example:: + + Nadim NNP B-NP B-PER + Ladki NNP I-NP I-PER + + AL-AIN NNP B-NP B-LOC + United NNP B-NP B-LOC + Arab NNP I-NP I-LOC + Emirates NNPS I-NP I-LOC + 1996-12-06 CD I-NP O + ... + + 返回的DataSet的内容为 + + .. csv-table:: 下面是Conll2003Loader加载后数据具备的结构。 + :header: "raw_words", "pos", "chunk", "ner" + + "[Nadim, Ladki]", "[NNP, NNP]", "[B-NP, I-NP]", "[B-PER, I-PER]" + "[AL-AIN, United, Arab, ...]", "[NNP, NNP, NNP, ...]", "[B-NP, B-NP, I-NP, ...]", "[B-LOC, B-LOC, I-LOC, ...]" + "[...]", "[...]", "[...]", "[...]" + + """ + def __init__(self): + headers = [ + 'raw_words', 'pos', 'chunk', 'ner', + ] + super(Conll2003Loader, self).__init__(headers=headers) + + def _load(self, path): + """ + 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 + + :param str path: 文件的路径 + :return: DataSet + """ + ds = DataSet() + for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): + doc_start = False + for i, h in enumerate(self.headers): + field = data[i] + if str(field[0]).startswith('-DOCSTART-'): + doc_start = True + break + if doc_start: + continue + ins = {h: data[i] for i, h in enumerate(self.headers)} + ds.append(Instance(**ins)) + return ds + + def download(self, output_dir=None): + raise RuntimeError("conll2003 cannot be downloaded automatically.") + + +class Conll2003NERLoader(ConllLoader): + """ + 用于读取conll2003任务的NER数据。 + + Example:: + + Nadim NNP B-NP B-PER + Ladki NNP I-NP I-PER + + AL-AIN NNP B-NP B-LOC + United NNP B-NP B-LOC + Arab NNP I-NP I-LOC + Emirates NNPS I-NP I-LOC + 1996-12-06 CD I-NP O + ... + + 返回的DataSet的内容为 + + .. csv-table:: 下面是Conll2003Loader加载后数据具备的结构, target是BIO2编码 + :header: "raw_words", "target" + + "[Nadim, Ladki]", "[B-PER, I-PER]" + "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" + "[...]", "[...]" + + """ + def __init__(self): + headers = [ + 'raw_words', 'target', + ] + super().__init__(headers=headers, indexes=[0, 3]) + + def _load(self, path): + """ + 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 + + :param str path: 文件的路径 + :return: DataSet + """ + ds = DataSet() + for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): + doc_start = False + for i, h in enumerate(self.headers): + field = data[i] + if str(field[0]).startswith('-DOCSTART-'): + doc_start = True + break + if doc_start: + continue + ins = {h: data[i] for i, h in enumerate(self.headers)} + ds.append(Instance(**ins)) + return ds + + def download(self): + raise RuntimeError("conll2003 cannot be downloaded automatically.") + + +class OntoNotesNERLoader(ConllLoader): + """ + 用以读取OntoNotes的NER数据,同时也是Conll2012的NER任务数据。将OntoNote数据处理为conll格式的过程可以参考 + https://github.com/yhcc/OntoNotes-5.0-NER。OntoNoteNERLoader将取第4列和第11列的内容。 + + 返回的DataSet的内容为 + + .. csv-table:: 下面是使用OntoNoteNERLoader读取的DataSet所具备的结构, target列是BIO编码 + :header: "raw_words", "target" + + "[Nadim, Ladki]", "[B-PER, I-PER]" + "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" + "[...]", "[...]" + + """ + + def __init__(self): + super().__init__(headers=[Const.RAW_WORD, Const.TARGET], indexes=[3, 10]) + + def _load(self, path:str): + dataset = super()._load(path) + + def convert_to_bio(tags): + bio_tags = [] + flag = None + for tag in tags: + label = tag.strip("()*") + if '(' in tag: + bio_label = 'B-' + label + flag = label + elif flag: + bio_label = 'I-' + flag + else: + bio_label = 'O' + if ')' in tag: + flag = None + bio_tags.append(bio_label) + return bio_tags + + def convert_word(words): + converted_words = [] + for word in words: + word = word.replace('/.', '.') # 有些结尾的.是/.形式的 + if not word.startswith('-'): + converted_words.append(word) + continue + # 以下是由于这些符号被转义了,再转回来 + tfrs = {'-LRB-':'(', + '-RRB-': ')', + '-LSB-': '[', + '-RSB-': ']', + '-LCB-': '{', + '-RCB-': '}' + } + if word in tfrs: + converted_words.append(tfrs[word]) + else: + converted_words.append(word) + return converted_words + + dataset.apply_field(convert_word, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD) + dataset.apply_field(convert_to_bio, field_name=Const.TARGET, new_field_name=Const.TARGET) + + return dataset + + def download(self): + raise RuntimeError("Ontonotes cannot be downloaded automatically, you can refer " + "https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.") + + +class CTBLoader(Loader): + def __init__(self): + super().__init__() + + def _load(self, path:str): + pass diff --git a/fastNLP/io/loader/csv.py b/fastNLP/io/loader/csv.py new file mode 100644 index 00000000..166f912b --- /dev/null +++ b/fastNLP/io/loader/csv.py @@ -0,0 +1,32 @@ +from ...core.dataset import DataSet +from ...core.instance import Instance +from ..file_reader import _read_csv +from .loader import Loader + + +class CSVLoader(Loader): + """ + 别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` + + 读取CSV格式的数据集, 返回 ``DataSet`` 。 + + :param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 + 若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` + :param str sep: CSV文件中列与列之间的分隔符. Default: "," + :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . + Default: ``False`` + """ + + def __init__(self, headers=None, sep=",", dropna=False): + super().__init__() + self.headers = headers + self.sep = sep + self.dropna = dropna + + def _load(self, path): + ds = DataSet() + for idx, data in _read_csv(path, headers=self.headers, + sep=self.sep, dropna=self.dropna): + ds.append(Instance(**data)) + return ds + diff --git a/fastNLP/io/loader/cws.py b/fastNLP/io/loader/cws.py new file mode 100644 index 00000000..3af28116 --- /dev/null +++ b/fastNLP/io/loader/cws.py @@ -0,0 +1,41 @@ +from .loader import Loader +from ...core.dataset import DataSet +from ...core.instance import Instance + + +class CWSLoader(Loader): + """ + 分词任务数据加载器, + SigHan2005的数据可以用xxx下载并预处理 + + CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: + + Example:: + + 上海 浦东 开发 与 法制 建设 同步 + 新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) + ... + + 该Loader读取后的DataSet具有如下的结构 + + .. csv-table:: + :header: "raw_words" + + "上海 浦东 开发 与 法制 建设 同步" + "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" + "..." + """ + def __init__(self): + super().__init__() + + def _load(self, path:str): + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + ds.append(Instance(raw_words=line)) + return ds + + def download(self, output_dir=None): + raise RuntimeError("You can refer {} for sighan2005's data downloading.") diff --git a/fastNLP/io/loader/json.py b/fastNLP/io/loader/json.py new file mode 100644 index 00000000..8856b73a --- /dev/null +++ b/fastNLP/io/loader/json.py @@ -0,0 +1,40 @@ +from ...core.dataset import DataSet +from ...core.instance import Instance +from ..file_reader import _read_json +from .loader import Loader + + +class JsonLoader(Loader): + """ + 别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader` + + 读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 + + :param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name + ``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` , + `value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名 + ``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` + :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . + Default: ``False`` + """ + + def __init__(self, fields=None, dropna=False): + super(JsonLoader, self).__init__() + self.dropna = dropna + self.fields = None + self.fields_list = None + if fields: + self.fields = {} + for k, v in fields.items(): + self.fields[k] = k if v is None else v + self.fields_list = list(self.fields.keys()) + + def _load(self, path): + ds = DataSet() + for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): + if self.fields: + ins = {self.fields[k]: v for k, v in d.items()} + else: + ins = d + ds.append(Instance(**ins)) + return ds diff --git a/fastNLP/io/loader/loader.py b/fastNLP/io/loader/loader.py new file mode 100644 index 00000000..607d6920 --- /dev/null +++ b/fastNLP/io/loader/loader.py @@ -0,0 +1,75 @@ +from ...core.dataset import DataSet +from .. import DataBundle +from ..utils import check_loader_paths +from typing import Union, Dict +import os +from ..file_utils import _get_dataset_url, get_cache_path, cached_path + +class Loader: + def __init__(self): + pass + + def _load(self, path:str) -> DataSet: + raise NotImplementedError + + def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: + """ + 从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 + + 读取的field根据ConllLoader初始化时传入的headers决定。 + + :param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式 + (0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 + + (1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 + 名包含'train'、 'dev'、 'test'则会报错 + + Example:: + + data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、 + # dev、 test等有所变化,可以通过以下的方式取出DataSet + tr_data = data_bundle.datasets['train'] + te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 + + (2) 传入文件路径 + + Example:: + + data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' + tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet + + (3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test + + Example:: + + paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} + data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" + dev_data = data_bundle.datasets['dev'] + + :return: 返回的:class:`~fastNLP.io.DataBundle` + """ + if paths is None: + paths = self.download() + paths = check_loader_paths(paths) + datasets = {name: self._load(path) for name, path in paths.items()} + data_bundle = DataBundle(datasets=datasets) + return data_bundle + + def download(self): + raise NotImplementedError(f"{self.__class__} cannot download data automatically.") + + def _get_dataset_path(self, dataset_name): + """ + 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 + + :param str dataset_name: 数据集的名称 + :return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 + """ + + default_cache_path = get_cache_path() + url = _get_dataset_url(dataset_name) + output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') + + return output_dir + + diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py new file mode 100644 index 00000000..58fa0d6f --- /dev/null +++ b/fastNLP/io/loader/matching.py @@ -0,0 +1,302 @@ +import warnings +from .loader import Loader +from .json import JsonLoader +from ...core.const import Const +from .. import DataBundle +import os +from typing import Union, Dict +from ...core.dataset import DataSet +from ...core.instance import Instance + + +class MNLILoader(Loader): + """ + 读取MNLI任务的数据,读取之后的DataSet中包含以下的内容,words0是sentence1, words1是sentence2, target是gold_label, 测试集中没 + 有target列。 + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "The new rights are...", "Everyone really likes..", "neutral" + "This site includes a...", "The Government Executive...", "contradiction" + "...", "...","." + + """ + def __init__(self): + super().__init__() + + def _load(self, path:str): + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + f.readline() # 跳过header + if path.endswith("test.tsv"): + warnings.warn("RTE's test file has no target.") + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[8] + raw_words2 = parts[9] + if raw_words1 and raw_words2: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) + else: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[8] + raw_words2 = parts[9] + target = parts[-1] + if raw_words1 and raw_words2 and target: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) + return ds + + def load(self, paths:str=None): + """ + + :param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv, + test_mismatched.tsv, train.tsv文件夹 + :return: DataBundle + """ + if paths: + paths = os.path.abspath(os.path.expanduser(paths)) + else: + paths = self.download() + if not os.path.isdir(paths): + raise NotADirectoryError(f"{paths} is not a valid directory.") + + files = {'dev_matched':"dev_matched.tsv", + "dev_mismatched":"dev_mismatched.tsv", + "test_matched":"test_matched.tsv", + "test_mismatched":"test_mismatched.tsv", + "train":'train.tsv'} + + datasets = {} + for name, filename in files.items(): + filepath = os.path.join(paths, filename) + if not os.path.isfile(filepath): + if 'test' not in name: + raise FileNotFoundError(f"{name} not found in directory {filepath}.") + datasets[name] = self._load(filepath) + + data_bundle = DataBundle(datasets=datasets) + + return data_bundle + + def download(self): + """ + 如果你使用了这个数据,请引用 + + https://www.nyu.edu/projects/bowman/multinli/paper.pdf + :return: + """ + output_dir = self._get_dataset_path('mnli') + return output_dir + + +class SNLILoader(JsonLoader): + """ + 读取之后的DataSet中的field情况为 + + .. csv-table:: 下面是使用SNLILoader加载的DataSet所具备的field + :header: "raw_words1", "raw_words2", "target" + + "The new rights are...", "Everyone really likes..", "neutral" + "This site includes a...", "The Government Executive...", "entailment" + "...", "...", "." + + """ + def __init__(self): + super().__init__(fields={ + 'sentence1': Const.RAW_WORDS(0), + 'sentence2': Const.RAW_WORDS(1), + 'gold_label': Const.TARGET, + }) + + def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: + """ + 从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 + + 读取的field根据ConllLoader初始化时传入的headers决定。 + + :param str paths: 传入一个目录, 将在该目录下寻找snli_1.0_train.jsonl, snli_1.0_dev.jsonl + 和snli_1.0_test.jsonl三个文件。 + + :return: 返回的:class:`~fastNLP.io.DataBundle` + """ + _paths = {} + if paths is None: + paths = self.download() + if paths: + if os.path.isdir(paths): + if not os.path.isfile(os.path.join(paths, 'snli_1.0_train.jsonl')): + raise FileNotFoundError(f"snli_1.0_train.jsonl is not found in {paths}") + _paths['train'] = os.path.join(paths, 'snli_1.0_train.jsonl') + for filename in ['snli_1.0_dev.jsonl', 'snli_1.0_test.jsonl']: + filepath = os.path.join(paths, filename) + _paths[filename.split('_')[-1].split('.')[0]] = filepath + paths = _paths + else: + raise NotADirectoryError(f"{paths} is not a valid directory.") + + datasets = {name: self._load(path) for name, path in paths.items()} + data_bundle = DataBundle(datasets=datasets) + return data_bundle + + def download(self): + """ + 如果您的文章使用了这份数据,请引用 + + http://nlp.stanford.edu/pubs/snli_paper.pdf + + :return: str + """ + return self._get_dataset_path('snli') + + +class QNLILoader(JsonLoader): + """ + QNLI数据集的Loader, + 加载的DataSet将具备以下的field, raw_words1是question, raw_words2是sentence, target是label + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "What came into force after the new...", "As of that day...", "entailment" + "What is the first major...", "The most important tributaries", "not_entailment" + "...","." + + test数据集没有target列 + + """ + def __init__(self): + super().__init__() + + def _load(self, path): + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + f.readline() # 跳过header + if path.endswith("test.tsv"): + warnings.warn("QNLI's test file has no target.") + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + if raw_words1 and raw_words2: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) + else: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + target = parts[-1] + if raw_words1 and raw_words2 and target: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) + return ds + + def download(self): + """ + 如果您的实验使用到了该数据,请引用 + + TODO 补充 + + :return: + """ + return self._get_dataset_path('qnli') + + +class RTELoader(Loader): + """ + RTE数据的loader + 加载的DataSet将具备以下的field, raw_words1是sentence0,raw_words2是sentence1, target是label + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment" + "Yet, we now are discovering that...", "Bacteria is winning...", "entailment" + "...","." + + test数据集没有target列 + """ + def __init__(self): + super().__init__() + + def _load(self, path:str): + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + f.readline() # 跳过header + if path.endswith("test.tsv"): + warnings.warn("RTE's test file has no target.") + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + if raw_words1 and raw_words2: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) + else: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + target = parts[-1] + if raw_words1 and raw_words2 and target: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) + return ds + + def download(self): + return self._get_dataset_path('rte') + + +class QuoraLoader(Loader): + """ + Quora matching任务的数据集Loader + + 支持读取的文件中的内容,应该有以下的形式, 以制表符分隔,且前三列的内容必须是:第一列是label,第二列和第三列是句子 + + Example:: + + 1 How do I get funding for my web based startup idea ? How do I get seed funding pre product ? 327970 + 1 How can I stop my depression ? What can I do to stop being depressed ? 339556 + ... + + 加载的DataSet将具备以下的field + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "What should I do to avoid...", "1" + "How do I not sleep in a boring class...", "0" + "...","." + + """ + def __init__(self): + super().__init__() + + def _load(self, path:str): + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + target = parts[0] + if raw_words1 and raw_words2 and target: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) + return ds + + def download(self): + raise RuntimeError("Quora cannot be downloaded automatically.") diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py new file mode 100644 index 00000000..ad68f486 --- /dev/null +++ b/fastNLP/io/pipe/__init__.py @@ -0,0 +1,37 @@ +""" +Pipe用于处理通过 Loader 读取的数据,所有的 Pipe 都包含 ``process`` 和 ``process_from_file`` 两种方法。 +``process(data_bundle)`` 传入一个 :class:`~fastNLP.io.DataBundle` 类型的对象, 在传入的 `data_bundle` 上进行原位修改,并将其返回; +``process_from_file(paths)`` 传入的文件路径,返回一个 :class:`~fastNLP.io.DataBundle` 类型的对象。 +``process(data_bundle)`` 或者 ``process_from_file(paths)`` 的返回 `data_bundle` 中的 :class:`~fastNLP.DataSet` +一般都包含原文与转换为index的输入以及转换为index的target;除了 :class:`~fastNLP.DataSet` 之外, +`data_bundle` 还会包含将field转为index时所建立的词表。 + +""" +__all__ = [ + "YelpFullPipe", + "YelpPolarityPipe", + "SSTPipe", + "SST2Pipe", + "IMDBPipe", + + "Conll2003NERPipe", + "OntoNotesNERPipe", + + "MatchingBertPipe", + "RTEBertPipe", + "SNLIBertPipe", + "QuoraBertPipe", + "QNLIBertPipe", + "MNLIBertPipe", + "MatchingPipe", + "RTEPipe", + "SNLIPipe", + "QuoraPipe", + "QNLIPipe", + "MNLIPipe", +] + +from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe +from .conll import Conll2003NERPipe, OntoNotesNERPipe +from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ + MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py new file mode 100644 index 00000000..429b6552 --- /dev/null +++ b/fastNLP/io/pipe/classification.py @@ -0,0 +1,444 @@ +from nltk import Tree + +from ..base_loader import DataBundle +from ...core.vocabulary import Vocabulary +from ...core.const import Const +from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader +from ...core.dataset import DataSet +from ...core.instance import Instance + +from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance +from .pipe import Pipe +import re +nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') +from ...core.utils import cache_results + +class _CLSPipe(Pipe): + """ + 分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列 + + """ + def __init__(self, tokenizer:str='spacy', lang='en'): + self.tokenizer = get_tokenizer(tokenizer, lang=lang) + + def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): + """ + 将DataBundle中的数据进行tokenize + + :param DataBundle data_bundle: + :param str field_name: + :param str new_field_name: + :return: 传入的DataBundle对象 + """ + new_field_name = new_field_name or field_name + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) + + return data_bundle + + def _granularize(self, data_bundle, tag_map): + """ + 该函数对data_bundle中'target'列中的内容进行转换。 + + :param data_bundle: + :param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance, + 且将"1"认为是第0类。 + :return: 传入的data_bundle + """ + for name in list(data_bundle.datasets.keys()): + dataset = data_bundle.get_dataset(name) + dataset.apply_field(lambda target:tag_map.get(target, -100), field_name=Const.TARGET, + new_field_name=Const.TARGET) + dataset.drop(lambda ins:ins[Const.TARGET] == -100) + data_bundle.set_dataset(dataset, name) + return data_bundle + + +def _clean_str(words): + """ + heavily borrowed from github + https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb + :param sentence: is a str + :return: + """ + words_collection = [] + for word in words: + if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']: + continue + tt = nonalpnum.split(word) + t = ''.join(tt) + if t != '': + words_collection.append(t) + + return words_collection + + +class YelpFullPipe(_CLSPipe): + """ + 处理YelpFull的数据, 处理之后DataSet中的内容如下 + + .. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field + :header: "raw_words", "words", "target", "seq_len" + + "It 's a ...", "[4, 2, 10, ...]", 0, 10 + "Offers that ...", "[20, 40, ...]", 1, 21 + "...", "[...]", ., . + + :param bool lower: 是否对输入进行小写化。 + :param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将1、2归为1类,4、5归为一类,丢掉2;若为3, 则有3分类问题,将 + 1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + def __init__(self, lower:bool=False, granularity=5, tokenizer:str='spacy'): + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + assert granularity in (2, 3, 5), "granularity can only be 2,3,5." + self.granularity = granularity + + if granularity==2: + self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1} + elif granularity==3: + self.tag_map = {"1": 0, "2": 0, "3":1, "4": 2, "5": 2} + else: + self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4} + + def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): + """ + 将DataBundle中的数据进行tokenize + + :param DataBundle data_bundle: + :param str field_name: + :param str new_field_name: + :return: 传入的DataBundle对象 + """ + new_field_name = new_field_name or field_name + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) + dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name) + return data_bundle + + def process(self, data_bundle): + """ + 传入的DataSet应该具备如下的结构 + + .. csv-table:: + :header: "raw_words", "target" + + "I got 'new' tires from them and... ", "1" + "Don't waste your time. We had two...", "1" + "...", "..." + + :param data_bundle: + :return: + """ + + # 复制一列words + data_bundle = _add_words_field(data_bundle, lower=self.lower) + + # 进行tokenize + data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) + + # 根据granularity设置tag + data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) + + # 删除空行 + data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT) + + # index + data_bundle = _indexize(data_bundle=data_bundle) + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUT) + + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) + data_bundle.set_target(Const.TARGET) + + return data_bundle + + def process_from_file(self, paths=None): + """ + + :param paths: + :return: DataBundle + """ + data_bundle = YelpFullLoader().load(paths) + return self.process(data_bundle=data_bundle) + + +class YelpPolarityPipe(_CLSPipe): + """ + 处理YelpPolarity的数据, 处理之后DataSet中的内容如下 + + .. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field + :header: "raw_words", "words", "target", "seq_len" + + "It 's a ...", "[4, 2, 10, ...]", 0, 10 + "Offers that ...", "[20, 40, ...]", 1, 21 + "...", "[...]", ., . + + :param bool lower: 是否对输入进行小写化。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + def __init__(self, lower:bool=False, tokenizer:str='spacy'): + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + + def process(self, data_bundle): + # 复制一列words + data_bundle = _add_words_field(data_bundle, lower=self.lower) + + # 进行tokenize + data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) + # index + data_bundle = _indexize(data_bundle=data_bundle) + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUT) + + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) + data_bundle.set_target(Const.TARGET) + + return data_bundle + + def process_from_file(self, paths=None): + """ + + :param str paths: + :return: DataBundle + """ + data_bundle = YelpPolarityLoader().load(paths) + return self.process(data_bundle=data_bundle) + + +class SSTPipe(_CLSPipe): + """ + 别名::class:`fastNLP.io.SSTPipe` :class:`fastNLP.io.pipe.SSTPipe` + + 经过该Pipe之后,DataSet中具备的field如下所示 + + .. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field + :header: "raw_words", "words", "target", "seq_len" + + "It 's a ...", "[4, 2, 10, ...]", 0, 16 + "Offers that ...", "[20, 40, ...]", 1, 18 + "...", "[...]", ., . + + :param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False`` + :param bool train_subtree: 是否将train集通过子树扩展数据。 + :param bool lower: 是否对输入进行小写化。 + :param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将0、1归为1类,3、4归为一类,丢掉2;若为3, 则有3分类问题,将 + 0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + + def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'): + super().__init__(tokenizer=tokenizer, lang='en') + self.subtree = subtree + self.train_tree = train_subtree + self.lower = lower + assert granularity in (2, 3, 5), "granularity can only be 2,3,5." + self.granularity = granularity + + if granularity==2: + self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1} + elif granularity==3: + self.tag_map = {"0": 0, "1": 0, "2":1, "3": 2, "4": 2} + else: + self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} + + def process(self, data_bundle:DataBundle): + """ + 对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与 + + .. csv-table:: + :header: "raw_words" + + "(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..." + "(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..." + "..." + + :param ~fastNLP.io.DataBundle data_bundle: 需要处理的DataBundle对象 + :return: + """ + # 先取出subtree + for name in list(data_bundle.datasets.keys()): + dataset = data_bundle.get_dataset(name) + ds = DataSet() + use_subtree = self.subtree or (name == 'train' and self.train_tree) + for ins in dataset: + raw_words = ins['raw_words'] + tree = Tree.fromstring(raw_words) + if use_subtree: + for t in tree.subtrees(): + raw_words = " ".join(t.leaves()) + instance = Instance(raw_words=raw_words, target=t.label()) + ds.append(instance) + else: + instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label()) + ds.append(instance) + data_bundle.set_dataset(ds, name) + + _add_words_field(data_bundle, lower=self.lower) + + # 进行tokenize + data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) + + # 根据granularity设置tag + data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) + + # index + data_bundle = _indexize(data_bundle=data_bundle) + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUT) + + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) + data_bundle.set_target(Const.TARGET) + + return data_bundle + + def process_from_file(self, paths=None): + data_bundle = SSTLoader().load(paths) + return self.process(data_bundle=data_bundle) + + +class SST2Pipe(_CLSPipe): + """ + 加载SST2的数据, 处理完成之后DataSet将拥有以下的field + + .. csv-table:: + :header: "raw_words", "words", "target", "seq_len" + + "it 's a charming and... ", "[3, 4, 5, 6, 7,...]", 1, 43 + "unflinchingly bleak and...", "[10, 11, 7,...]", 1, 21 + "...", "...", ., . + + :param bool lower: 是否对输入进行小写化。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + def __init__(self, lower=False, tokenizer='spacy'): + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + + def process(self, data_bundle:DataBundle): + """ + 可以处理的DataSet应该具备如下的结构 + + .. csv-table:: + :header: "raw_words", "target" + + "it 's a charming and... ", 1 + "unflinchingly bleak and...", 1 + "...", "..." + + :param data_bundle: + :return: + """ + _add_words_field(data_bundle, self.lower) + + data_bundle = self._tokenize(data_bundle=data_bundle) + + src_vocab = Vocabulary() + src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT, + no_create_entry_dataset=[dataset for name,dataset in data_bundle.datasets.items() if + name != 'train']) + src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) + + tgt_vocab = Vocabulary(unknown=None, padding=None) + tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) + datasets = [] + for name, dataset in data_bundle.datasets.items(): + if dataset.has_field(Const.TARGET): + datasets.append(dataset) + tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET) + + data_bundle.set_vocab(src_vocab, Const.INPUT) + data_bundle.set_vocab(tgt_vocab, Const.TARGET) + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUT) + + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) + data_bundle.set_target(Const.TARGET) + + return data_bundle + + def process_from_file(self, paths=None): + """ + + :param str paths: 如果为None,则自动下载并缓存到fastNLP的缓存地址。 + :return: DataBundle + """ + data_bundle = SST2Loader().load(paths) + return self.process(data_bundle) + + +class IMDBPipe(_CLSPipe): + """ + 经过本Pipe处理后DataSet将如下 + + .. csv-table:: 输出DataSet的field + :header: "raw_words", "words", "target", "seq_len" + + "Bromwell High is a cartoon ... ", "[3, 5, 6, 9, ...]", 0, 20 + "Story of a man who has ...", "[20, 43, 9, 10, ...]", 1, 31 + "...", "[...]", ., . + + 其中raw_words为str类型,是原文; words是转换为index的输入; target是转换为index的目标值; + words列被设置为input; target列被设置为target。 + + :param bool lower: 是否将words列的数据小写。 + :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 + """ + def __init__(self, lower:bool=False, tokenizer:str='spacy'): + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + + def process(self, data_bundle:DataBundle): + """ + 期待的DataBunlde中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型 + + .. csv-table:: 输入DataSet的field + :header: "raw_words", "target" + + "Bromwell High is a cartoon ... ", "pos" + "Story of a man who has ...", "neg" + "...", "..." + + :param DataBunlde data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和target两个field,且raw_words列应该为str, + target列应该为str。 + :return: DataBundle + """ + # 替换
+ def replace_br(raw_words): + raw_words = raw_words.replace("
", ' ') + return raw_words + + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD) + + _add_words_field(data_bundle, lower=self.lower) + self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT) + _indexize(data_bundle) + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUT) + dataset.set_input(Const.INPUT, Const.INPUT_LEN) + dataset.set_target(Const.TARGET) + + return data_bundle + + def process_from_file(self, paths=None): + """ + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = IMDBLoader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle + + + diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py new file mode 100644 index 00000000..0379a45b --- /dev/null +++ b/fastNLP/io/pipe/conll.py @@ -0,0 +1,148 @@ +from .pipe import Pipe +from .. import DataBundle +from .utils import iob2, iob2bioes +from ...core.const import Const +from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader +from .utils import _indexize, _add_words_field + + +class _NERPipe(Pipe): + """ + NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 + (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 + Vocabulary转换为index。 + + raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 + target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target, seq_len。 + + :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 + :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 + :param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为-100。 + """ + + def __init__(self, encoding_type: str = 'bio', lower: bool = False, target_pad_val=0): + if encoding_type == 'bio': + self.convert_tag = iob2 + else: + self.convert_tag = lambda words: iob2bioes(iob2(words)) + self.lower = lower + self.target_pad_val = int(target_pad_val) + + def process(self, data_bundle: DataBundle) -> DataBundle: + """ + 支持的DataSet的field为 + + .. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader + :header: "raw_words", "target" + + "[Nadim, Ladki]", "[B-PER, I-PER]" + "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" + "[...]", "[...]" + + + :param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 + 在传入DataBundle基础上原位修改。 + :return: DataBundle + + Example:: + + data_bundle = Conll2003Loader().load('/path/to/conll2003/') + data_bundle = Conll2003NERPipe().process(data_bundle) + + # 获取train + tr_data = data_bundle.get_dataset('train') + + # 获取target这个field的词表 + target_vocab = data_bundle.get_vocab('target') + # 获取words这个field的词表 + word_vocab = data_bundle.get_vocab('words') + + """ + # 转换tag + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) + + _add_words_field(data_bundle, lower=self.lower) + + # index + _indexize(data_bundle) + + input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] + target_fields = [Const.TARGET, Const.INPUT_LEN] + + for name, dataset in data_bundle.datasets.items(): + dataset.set_pad_val(Const.TARGET, self.target_pad_val) + dataset.add_seq_len(Const.INPUT) + + data_bundle.set_input(*input_fields) + data_bundle.set_target(*target_fields) + + return data_bundle + + def process_from_file(self, paths) -> DataBundle: + """ + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = Conll2003NERLoader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle + + +class Conll2003NERPipe(_NERPipe): + """ + Conll2003的NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 + (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 + Vocabulary转换为index。 + 经过该Pipe过后,DataSet中的内容如下所示 + + .. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader + :header: "raw_words", "words", "target", "seq_len" + + "[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2 + "[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 10 + "[...]", "[...]", "[...]", . + + raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 + target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。 + + :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 + :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 + :param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为-100。 + """ + + def process_from_file(self, paths) -> DataBundle: + """ + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = Conll2003NERLoader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle + + +class OntoNotesNERPipe(_NERPipe): + """ + 处理OntoNotes的NER数据,处理之后DataSet中的field情况为 + + .. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader + :header: "raw_words", "words", "target", "seq_len" + + "[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2 + "[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 6 + "[...]", "[...]", "[...]", . + + :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 + :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 + :param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为-100。 + """ + + def process_from_file(self, paths): + data_bundle = OntoNotesNERLoader().load(paths) + return self.process(data_bundle) diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py new file mode 100644 index 00000000..9f7c7d68 --- /dev/null +++ b/fastNLP/io/pipe/matching.py @@ -0,0 +1,252 @@ + +from .pipe import Pipe +from .utils import get_tokenizer +from ...core.const import Const +from ...core.vocabulary import Vocabulary +from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader + + +class MatchingBertPipe(Pipe): + """ + Matching任务的Bert pipe,输出的DataSet将包含以下的field + + .. csv-table:: + :header: "raw_words1", "raw_words2", "words", "target", "seq_len" + + "The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", 1, 10 + "This site includes a...", "The Government Executive...", "[11, 12, 13,...]", 0, 5 + "...", "...", "[...]", ., . + + words列是将raw_words1(即premise), raw_words2(即hypothesis)使用"[SEP]"链接起来转换为index的。 + words列被设置为input,target列被设置为target和input(设置为input以方便在forward函数中计算loss, + 如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数的形参名进行传参). + + :param bool lower: 是否将word小写化。 + :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 + """ + def __init__(self, lower=False, tokenizer: str='raw'): + super().__init__() + + self.lower = bool(lower) + self.tokenizer = get_tokenizer(tokenizer=tokenizer) + + def _tokenize(self, data_bundle, field_names, new_field_names): + """ + + :param DataBundle data_bundle: DataBundle. + :param list field_names: List[str], 需要tokenize的field名称 + :param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。 + :return: 输入的DataBundle对象 + """ + for name, dataset in data_bundle.datasets.items(): + for field_name, new_field_name in zip(field_names, new_field_names): + dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, + new_field_name=new_field_name) + return data_bundle + + def process(self, data_bundle): + for dataset in data_bundle.datasets.values(): + if dataset.has_field(Const.TARGET): + dataset.drop(lambda x: x[Const.TARGET] == '-') + + for name, dataset in data_bundle.datasets.items(): + dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0)) + dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1)) + + if self.lower: + for name, dataset in data_bundle.datasets.items(): + dataset[Const.INPUTS(0)].lower() + dataset[Const.INPUTS(1)].lower() + + data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)], + [Const.INPUTS(0), Const.INPUTS(1)]) + + # concat两个words + def concat(ins): + words0 = ins[Const.INPUTS(0)] + words1 = ins[Const.INPUTS(1)] + words = words0 + ['[SEP]'] + words1 + return words + + for name, dataset in data_bundle.datasets.items(): + dataset.apply(concat, new_field_name=Const.INPUT) + dataset.delete_field(Const.INPUTS(0)) + dataset.delete_field(Const.INPUTS(1)) + + word_vocab = Vocabulary() + word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], + field_name=Const.INPUT, + no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if + 'train' not in name]) + word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) + + target_vocab = Vocabulary(padding=None, unknown=None) + target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) + has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if + dataset.has_field(Const.TARGET)] + target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) + + data_bundle.set_vocab(word_vocab, Const.INPUT) + data_bundle.set_vocab(target_vocab, Const.TARGET) + + input_fields = [Const.INPUT, Const.INPUT_LEN, Const.TARGET] + target_fields = [Const.TARGET] + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUT) + dataset.set_input(*input_fields, flag=True) + dataset.set_target(*target_fields, flag=True) + + return data_bundle + + +class RTEBertPipe(MatchingBertPipe): + def process_from_file(self, paths=None): + data_bundle = RTELoader().load(paths) + return self.process(data_bundle) + + +class SNLIBertPipe(MatchingBertPipe): + def process_from_file(self, paths=None): + data_bundle = SNLILoader().load(paths) + return self.process(data_bundle) + + +class QuoraBertPipe(MatchingBertPipe): + def process_from_file(self, paths): + data_bundle = QuoraLoader().load(paths) + return self.process(data_bundle) + + +class QNLIBertPipe(MatchingBertPipe): + def process_from_file(self, paths=None): + data_bundle = QNLILoader().load(paths) + return self.process(data_bundle) + + +class MNLIBertPipe(MatchingBertPipe): + def process_from_file(self, paths=None): + data_bundle = MNLILoader().load(paths) + return self.process(data_bundle) + + +class MatchingPipe(Pipe): + """ + Matching任务的Pipe。输出的DataSet将包含以下的field + + .. csv-table:: + :header: "raw_words1", "raw_words2", "words1", "words2", "target", "seq_len1", "seq_len2" + + "The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", "[10, 20, 6]", 1, 10, 13 + "This site includes a...", "The Government Executive...", "[11, 12, 13,...]", "[2, 7, ...]", 0, 6, 7 + "...", "...", "[...]", "[...]", ., ., . + + words1是premise,words2是hypothesis。其中words1,words2,seq_len1,seq_len2被设置为input;target被设置为target + 和input(设置为input以方便在forward函数中计算loss,如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数 + 的形参名进行传参)。 + + :param bool lower: 是否将所有raw_words转为小写。 + :param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. spacy是使用spacy切分,raw就是用空格切分。 + """ + def __init__(self, lower=False, tokenizer: str='raw'): + super().__init__() + + self.lower = bool(lower) + self.tokenizer = get_tokenizer(tokenizer=tokenizer) + + def _tokenize(self, data_bundle, field_names, new_field_names): + """ + + :param DataBundle data_bundle: DataBundle. + :param list field_names: List[str], 需要tokenize的field名称 + :param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。 + :return: 输入的DataBundle对象 + """ + for name, dataset in data_bundle.datasets.items(): + for field_name, new_field_name in zip(field_names, new_field_names): + dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, + new_field_name=new_field_name) + return data_bundle + + def process(self, data_bundle): + """ + 接受的DataBundle中的DataSet应该具有以下的field, target列可以没有 + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "The new rights are...", "Everyone really likes..", "entailment" + "This site includes a...", "The Government Executive...", "not_entailment" + "...", "..." + + :param data_bundle: + :return: + """ + data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)], + [Const.INPUTS(0), Const.INPUTS(1)]) + + for dataset in data_bundle.datasets.values(): + if dataset.has_field(Const.TARGET): + dataset.drop(lambda x: x[Const.TARGET] == '-') + + if self.lower: + for name, dataset in data_bundle.datasets.items(): + dataset[Const.INPUTS(0)].lower() + dataset[Const.INPUTS(1)].lower() + + word_vocab = Vocabulary() + word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], + field_name=[Const.INPUTS(0), Const.INPUTS(1)], + no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if + 'train' not in name]) + word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)]) + + target_vocab = Vocabulary(padding=None, unknown=None) + target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) + has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if + dataset.has_field(Const.TARGET)] + target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) + + data_bundle.set_vocab(word_vocab, Const.INPUTS(0)) + data_bundle.set_vocab(target_vocab, Const.TARGET) + + input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET] + target_fields = [Const.TARGET] + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0)) + dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1)) + dataset.set_input(*input_fields, flag=True) + dataset.set_target(*target_fields, flag=True) + + return data_bundle + + +class RTEPipe(MatchingPipe): + def process_from_file(self, paths=None): + data_bundle = RTELoader().load(paths) + return self.process(data_bundle) + + +class SNLIPipe(MatchingPipe): + def process_from_file(self, paths=None): + data_bundle = SNLILoader().load(paths) + return self.process(data_bundle) + + +class QuoraPipe(MatchingPipe): + def process_from_file(self, paths): + data_bundle = QuoraLoader().load(paths) + return self.process(data_bundle) + +class QNLIPipe(MatchingPipe): + def process_from_file(self, paths=None): + data_bundle = QNLILoader().load(paths) + return self.process(data_bundle) + + +class MNLIPipe(MatchingPipe): + def process_from_file(self, paths=None): + data_bundle = MNLILoader().load(paths) + return self.process(data_bundle) + diff --git a/fastNLP/io/pipe/pipe.py b/fastNLP/io/pipe/pipe.py new file mode 100644 index 00000000..76cc00ec --- /dev/null +++ b/fastNLP/io/pipe/pipe.py @@ -0,0 +1,9 @@ +from .. import DataBundle + + +class Pipe: + def process(self, data_bundle: DataBundle) -> DataBundle: + raise NotImplementedError + + def process_from_file(self, paths) -> DataBundle: + raise NotImplementedError diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py new file mode 100644 index 00000000..48454b67 --- /dev/null +++ b/fastNLP/io/pipe/utils.py @@ -0,0 +1,142 @@ +from typing import List +from ...core.vocabulary import Vocabulary +from ...core.const import Const + +def iob2(tags:List[str])->List[str]: + """ + 检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format + + :param tags: 需要转换的tags + """ + for i, tag in enumerate(tags): + if tag == "O": + continue + split = tag.split("-") + if len(split) != 2 or split[0] not in ["I", "B"]: + raise TypeError("The encoding schema is not a valid IOB type.") + if split[0] == "B": + continue + elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2 + tags[i] = "B" + tag[1:] + elif tags[i - 1][1:] == tag[1:]: + continue + else: # conversion IOB1 to IOB2 + tags[i] = "B" + tag[1:] + return tags + +def iob2bioes(tags:List[str])->List[str]: + """ + 将iob的tag转换为bioes编码 + :param tags: + :return: + """ + new_tags = [] + for i, tag in enumerate(tags): + if tag == 'O': + new_tags.append(tag) + else: + split = tag.split('-')[0] + if split == 'B': + if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I': + new_tags.append(tag) + else: + new_tags.append(tag.replace('B-', 'S-')) + elif split == 'I': + if i + 1Dict[str, str]: +def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: """ - 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果 - { - 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 - 'test': 'xxx' # 可能有,也可能没有 - ... - } - 如果paths为不合法的,将直接进行raise相应的错误 + 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: - :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 + { + 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 + 'test': 'xxx' # 可能有,也可能没有 + ... + } + + 如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。 + + :param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 :return: """ - if isinstance(paths, str): + if isinstance(paths, (str, Path)): + paths = os.path.abspath(os.path.expanduser(paths)) if os.path.isfile(paths): return {'train': paths} elif os.path.isdir(paths): @@ -37,6 +41,8 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: path_pair = ('test', filename) if path_pair: files[path_pair[0]] = os.path.join(paths, path_pair[1]) + if 'train' not in files: + raise KeyError(f"There is no train file in {paths}.") return files else: raise FileNotFoundError(f"{paths} is not a valid file path.") @@ -47,8 +53,10 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: raise KeyError("You have to include `train` in your dict.") for key, value in paths.items(): if isinstance(key, str) and isinstance(value, str): + value = os.path.abspath(os.path.expanduser(value)) if not os.path.isfile(value): raise TypeError(f"{value} is not a valid file.") + paths[key] = value else: raise TypeError("All keys and values in paths should be str.") return paths diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py index adecab60..ad7750ec 100644 --- a/fastNLP/models/bert.py +++ b/fastNLP/models/bert.py @@ -2,13 +2,14 @@ bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. """ +import os import torch from torch import nn from .base_model import BaseModel from ..core.const import Const from ..modules.encoder import BertModel -from ..modules.encoder.bert import BertConfig +from ..modules.encoder.bert import BertConfig, CONFIG_FILE class BertForSequenceClassification(BaseModel): @@ -54,6 +55,7 @@ class BertForSequenceClassification(BaseModel): self.num_labels = num_labels if bert_dir is not None: self.bert = BertModel.from_pretrained(bert_dir) + config = BertConfig(os.path.join(bert_dir, CONFIG_FILE)) else: if config is None: config = BertConfig(30522) @@ -67,20 +69,20 @@ class BertForSequenceClassification(BaseModel): model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) return model - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + def forward(self, words, seq_len=None, target=None): + _, pooled_output = self.bert(words, attention_mask=seq_len, output_all_encoded_layers=False) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) - if labels is not None: + if target is not None: loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + loss = loss_fct(logits, target) return {Const.OUTPUT: logits, Const.LOSS: loss} else: return {Const.OUTPUT: logits} - def predict(self, input_ids, token_type_ids=None, attention_mask=None): - logits = self.forward(input_ids, token_type_ids, attention_mask) + def predict(self, words, seq_len=None): + logits = self.forward(words, seq_len=seq_len)[Const.OUTPUT] return {Const.OUTPUT: torch.argmax(logits, dim=-1)} @@ -140,7 +142,8 @@ class BertForMultipleChoice(BaseModel): model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir) return model - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): + def forward(self, words, seq_len1=None, seq_len2=None, target=None): + input_ids, token_type_ids, attention_mask = words, seq_len1, seq_len2 flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) @@ -149,15 +152,15 @@ class BertForMultipleChoice(BaseModel): logits = self.classifier(pooled_output) reshaped_logits = logits.view(-1, self.num_choices) - if labels is not None: + if target is not None: loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) + loss = loss_fct(reshaped_logits, target) return {Const.OUTPUT: reshaped_logits, Const.LOSS: loss} else: return {Const.OUTPUT: reshaped_logits} - def predict(self, input_ids, token_type_ids=None, attention_mask=None): - logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT] + def predict(self, words, seq_len1=None, seq_len2=None,): + logits = self.forward(words, seq_len1=seq_len1, seq_len2=seq_len2)[Const.OUTPUT] return {Const.OUTPUT: torch.argmax(logits, dim=-1)} @@ -219,27 +222,27 @@ class BertForTokenClassification(BaseModel): model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) return model - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + def forward(self, words, seq_len1=None, seq_len2=None, target=None): + sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False) sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) - if labels is not None: + if target is not None: loss_fct = nn.CrossEntropyLoss() # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 + if seq_len2 is not None: + active_loss = seq_len2.view(-1) == 1 active_logits = logits.view(-1, self.num_labels)[active_loss] - active_labels = labels.view(-1)[active_loss] + active_labels = target.view(-1)[active_loss] loss = loss_fct(active_logits, active_labels) else: - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + loss = loss_fct(logits.view(-1, self.num_labels), target.view(-1)) return {Const.OUTPUT: logits, Const.LOSS: loss} else: return {Const.OUTPUT: logits} - def predict(self, input_ids, token_type_ids=None, attention_mask=None): - logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT] + def predict(self, words, seq_len1=None, seq_len2=None): + logits = self.forward(words, seq_len1, seq_len2)[Const.OUTPUT] return {Const.OUTPUT: torch.argmax(logits, dim=-1)} @@ -304,34 +307,34 @@ class BertForQuestionAnswering(BaseModel): model = cls(config=config, bert_dir=pretrained_model_dir) return model - def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + def forward(self, words, seq_len1=None, seq_len2=None, target1=None, target2=None): + sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False) logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) - if start_positions is not None and end_positions is not None: + if target1 is not None and target2 is not None: # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) + if len(target1.size()) > 1: + target1 = target1.squeeze(-1) + if len(target2.size()) > 1: + target2 = target2.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) - start_positions.clamp_(0, ignored_index) - end_positions.clamp_(0, ignored_index) + target1.clamp_(0, ignored_index) + target2.clamp_(0, ignored_index) loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) + start_loss = loss_fct(start_logits, target1) + end_loss = loss_fct(end_logits, target2) total_loss = (start_loss + end_loss) / 2 return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits, Const.LOSS: total_loss} else: return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits} - def predict(self, input_ids, token_type_ids=None, attention_mask=None, **kwargs): - logits = self.forward(input_ids, token_type_ids, attention_mask) + def predict(self, words, seq_len1=None, seq_len2=None): + logits = self.forward(words, seq_len1, seq_len2) start_logits = logits[Const.OUTPUTS(0)] end_logits = logits[Const.OUTPUTS(1)] return {Const.OUTPUTS(0): torch.argmax(start_logits, dim=-1), diff --git a/fastNLP/models/snli.py b/fastNLP/models/snli.py index 8e35b6bc..3be942e8 100644 --- a/fastNLP/models/snli.py +++ b/fastNLP/models/snli.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from torch.nn import CrossEntropyLoss from .base_model import BaseModel -from ..embeddings.embedding import TokenEmbedding +from ..embeddings.embedding import TokenEmbedding, Embedding from ..core.const import Const from ..core.utils import seq_len_to_mask @@ -21,18 +21,21 @@ class ESIM(BaseModel): ESIM model的一个PyTorch实现 论文参见: https://arxiv.org/pdf/1609.06038.pdf - :param fastNLP.TokenEmbedding init_embedding: 初始化的TokenEmbedding + :param init_embedding: 初始化的Embedding :param int hidden_size: 隐藏层大小,默认值为Embedding的维度 :param int num_labels: 目标标签种类数量,默认值为3 :param float dropout_rate: dropout的比率,默认值为0.3 :param float dropout_embed: 对Embedding的dropout比率,默认值为0.1 """ - def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3, + def __init__(self, init_embedding, hidden_size=None, num_labels=3, dropout_rate=0.3, dropout_embed=0.1): super(ESIM, self).__init__() - self.embedding = init_embedding + if isinstance(init_embedding, TokenEmbedding) or isinstance(init_embedding, Embedding): + self.embedding = init_embedding + else: + self.embedding = Embedding(init_embedding) self.dropout_embed = EmbedDropout(p=dropout_embed) if hidden_size is None: hidden_size = self.embedding.embed_size diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py index ce175df1..e73b2c40 100644 --- a/fastNLP/modules/encoder/bert.py +++ b/fastNLP/modules/encoder/bert.py @@ -563,6 +563,8 @@ class WordpieceTokenizer(object): output_tokens.append(self.unk_token) else: output_tokens.extend(sub_tokens) + if len(output_tokens)==0: #防止里面全是空格或者回车符号 + return [self.unk_token] return output_tokens @@ -848,7 +850,7 @@ class _WordPieceBertModel(nn.Module): """ - def __init__(self, model_dir: str, layers: str = '-1'): + def __init__(self, model_dir: str, layers: str = '-1', pooled_cls:bool=False): super().__init__() self.tokenzier = BertTokenizer.from_pretrained(model_dir) @@ -867,8 +869,9 @@ class _WordPieceBertModel(nn.Module): self._cls_index = self.tokenzier.vocab['[CLS]'] self._sep_index = self.tokenzier.vocab['[SEP]'] self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # 需要用于生成word_piece + self.pooled_cls = pooled_cls - def index_dataset(self, *datasets, field_name): + def index_dataset(self, *datasets, field_name, add_cls_sep=True): """ 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是 [CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。 @@ -884,10 +887,11 @@ class _WordPieceBertModel(nn.Module): tokens = self.tokenzier.wordpiece_tokenizer.tokenize(word) word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens) word_pieces.extend(word_piece_ids) - if word_pieces[0] != self._cls_index: - word_pieces.insert(0, self._cls_index) - if word_pieces[-1] != self._sep_index: - word_pieces.insert(-1, self._sep_index) + if add_cls_sep: + if word_pieces[0] != self._cls_index: + word_pieces.insert(0, self._cls_index) + if word_pieces[-1] != self._sep_index: + word_pieces.insert(-1, self._sep_index) return word_pieces for index, dataset in enumerate(datasets): @@ -909,10 +913,13 @@ class _WordPieceBertModel(nn.Module): batch_size, max_len = word_pieces.size() attn_masks = word_pieces.ne(self._wordpiece_pad_index) - bert_outputs, _ = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, + bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, output_all_encoded_layers=True) # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1))) for l_index, l in enumerate(self.layers): - outputs[l_index] = bert_outputs[l] + bert_output = bert_outputs[l] + if l==len(bert_outputs) and self.pooled_cls: + bert_output[:, 0] = pooled_cls + outputs[l_index] = bert_output return outputs diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index dbae9c73..ead75711 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -3,7 +3,8 @@ from functools import reduce import torch import torch.nn as nn import torch.nn.init as init - +import glob +import os def initial_parameter(net, initial_method=None): """A method used to initialize the weights of PyTorch models. @@ -111,7 +112,7 @@ def get_dropout_mask(drop_p: float, tensor: torch.Tensor): 根据tensor的形状,生成一个mask :param drop_p: float, 以多大的概率置为0。 - :param tensor:torch.Tensor + :param tensor: torch.Tensor :return: torch.FloatTensor. 与tensor一样的shape """ mask_x = torch.ones_like(tensor) @@ -119,7 +120,6 @@ def get_dropout_mask(drop_p: float, tensor: torch.Tensor): training=False, inplace=True) return mask_x -import glob def _get_file_name_base_on_postfix(dir_path, postfix): """ @@ -128,9 +128,9 @@ def _get_file_name_base_on_postfix(dir_path, postfix): :param postfix: 形如".bin", ".json"等 :return: str,文件的路径 """ - files = glob.glob(os.path.join(dir_path, '*' + postfix)) + files = list(filter(lambda filename:filename.endswith(postfix), os.listdir(os.path.join(dir_path)))) if len(files) == 0: - raise FileNotFoundError(f"There is no file endswith *.{postfix} file in {dir_path}") + raise FileNotFoundError(f"There is no file endswith *{postfix} file in {dir_path}") elif len(files) > 1: - raise FileExistsError(f"There are multiple *.{postfix} files in {dir_path}") + raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}") return os.path.join(dir_path, files[0]) \ No newline at end of file diff --git a/reproduction/joint_cws_parse/models/CharParser.py b/reproduction/joint_cws_parse/models/CharParser.py index c07c070e..7d89cacb 100644 --- a/reproduction/joint_cws_parse/models/CharParser.py +++ b/reproduction/joint_cws_parse/models/CharParser.py @@ -224,11 +224,11 @@ class CharBiaffineParser(BiaffineParser): batch_size, seq_len, _ = arc_pred.shape flip_mask = (mask == 0) - _arc_pred = arc_pred.clone() - _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) + # _arc_pred = arc_pred.clone() + _arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf')) - arc_true[:, 0].fill_(-1) - label_true[:, 0].fill_(-1) + arc_true.data[:, 0].fill_(-1) + label_true.data[:, 0].fill_(-1) arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1) label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1) diff --git a/reproduction/joint_cws_parse/train.py b/reproduction/joint_cws_parse/train.py index 0c34614b..ed4b07f0 100644 --- a/reproduction/joint_cws_parse/train.py +++ b/reproduction/joint_cws_parse/train.py @@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import StepLR from fastNLP import Tester from fastNLP import GradientClipCallback, LRScheduler import os +from fastNLP import cache_results def set_random_seed(random_seed=666): import random, numpy, torch @@ -39,43 +40,42 @@ label_mlp_size = 100 batch_size = 32 update_every = 4 n_epochs = 100 -data_folder = '' # 填写在数据所在文件夹, 文件夹下应该有train, dev, test等三个文件 -vector_folder = '' # 预训练的vector,下面应该包含三个文件: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt +data_name = 'new_ctb7' #################################################### +data_folder = f'/remote-home/hyan01/exps/JointCwsPosParser/data/{data_name}/output' # 填写在数据所在文件夹, 文件夹下应该有train, dev, test等三个文件 +vector_folder = '/remote-home/hyan01/exps/CWS/pretrain/vectors' # 预训练的vector,下面应该包含三个文件: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt set_random_seed(1234) device = 0 -# @cache_results('caches/{}.pkl'.format(data_name)) -# def get_data(): -data = CTBxJointLoader().process(data_folder) - -char_labels_vocab = data.vocabs['char_labels'] - -pre_chars_vocab = data.vocabs['pre_chars'] -pre_bigrams_vocab = data.vocabs['pre_bigrams'] -pre_trigrams_vocab = data.vocabs['pre_trigrams'] - -chars_vocab = data.vocabs['chars'] -bigrams_vocab = data.vocabs['bigrams'] -trigrams_vocab = data.vocabs['trigrams'] - -pre_chars_embed = StaticEmbedding(pre_chars_vocab, - model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'), - init_method=uniform_init, normalize=False) -pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data/pre_chars_embed.embedding.weight.data.std() -pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab, - model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'), - init_method=uniform_init, normalize=False) -pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data/pre_bigrams_embed.embedding.weight.data.std() -pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab, - model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'), - init_method=uniform_init, normalize=False) -pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data/pre_trigrams_embed.embedding.weight.data.std() - - # return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data - -# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data() +@cache_results('caches/{}.pkl'.format(data_name)) +def get_data(): + data = CTBxJointLoader().process(data_folder) + char_labels_vocab = data.vocabs['char_labels'] + + pre_chars_vocab = data.vocabs['pre_chars'] + pre_bigrams_vocab = data.vocabs['pre_bigrams'] + pre_trigrams_vocab = data.vocabs['pre_trigrams'] + + chars_vocab = data.vocabs['chars'] + bigrams_vocab = data.vocabs['bigrams'] + trigrams_vocab = data.vocabs['trigrams'] + pre_chars_embed = StaticEmbedding(pre_chars_vocab, + model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'), + init_method=uniform_init, normalize=False) + pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data / pre_chars_embed.embedding.weight.data.std() + pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab, + model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'), + init_method=uniform_init, normalize=False) + pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data / pre_bigrams_embed.embedding.weight.data.std() + pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab, + model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'), + init_method=uniform_init, normalize=False) + pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data / pre_trigrams_embed.embedding.weight.data.std() + + return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data + +chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data() print(data) model = CharParser(char_vocab_size=len(chars_vocab), @@ -104,11 +104,24 @@ optimizer = optim.Adam([param for param in model.parameters() if param.requires_ sampler = BucketSampler(seq_len_field_name='seq_lens') callbacks = [] + +from fastNLP.core.callback import Callback +from torch.optim.lr_scheduler import LambdaLR +class SchedulerCallback(Callback): + def __init__(self, scheduler): + super().__init__() + self.scheduler = scheduler + + def on_backward_end(self): + if self.step % self.update_every==0: + self.scheduler.step() + +scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000)) # scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000)) -scheduler = StepLR(optimizer, step_size=18, gamma=0.75) -# optim_callback = OptimizerCallback(optimizer, scheduler, update_every) +# scheduler = StepLR(optimizer, step_size=18, gamma=0.75) +scheduler_callback = SchedulerCallback(scheduler) # callbacks.append(optim_callback) -scheduler_callback = LRScheduler(scheduler) +# scheduler_callback = LRScheduler(scheduler) callbacks.append(scheduler_callback) callbacks.append(GradientClipCallback(clip_type='value', clip_value=5)) @@ -119,6 +132,6 @@ callbacks.append(dev_callback) trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3, validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer, - check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True, + check_code_level=0, metric_key='u_f1', sampler=sampler, num_workers=2, use_tqdm=True, device=device, callbacks=callbacks, update_every=update_every) trainer.train() \ No newline at end of file diff --git a/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py b/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py index cec5ab76..0d292bdc 100644 --- a/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py +++ b/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py @@ -51,7 +51,7 @@ class ChineseNERLoader(DataSetLoader): :param paths: :param bool, bigrams: 是否包含生成bigram feature, [a, b, c, d] -> [ab, bc, cd, d] :param bool, trigrams: 是否包含trigram feature,[a, b, c, d] -> [abc, bcd, cd, d] - :return: DataBundle + :return: ~fastNLP.io.DataBundle 包含以下的fields raw_chars: List[str] chars: List[int] diff --git a/reproduction/text_classification/README.md b/reproduction/text_classification/README.md index 8bdfb9fe..96ea7a10 100644 --- a/reproduction/text_classification/README.md +++ b/reproduction/text_classification/README.md @@ -11,6 +11,13 @@ LSTM+self_attention:论文链接[A Structured Self-attentive Sentence Embedding] AWD-LSTM:论文链接[Regularizing and Optimizing LSTM Language Models](https://arxiv.org/pdf/1708.02182.pdf) +#数据集来源 +IMDB:http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz +SST-2:https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8 +SST:https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip +yelp_full:https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M +yelp_polarity:https://drive.google.com/drive/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M + # 数据集及复现结果汇总 使用fastNLP复现的结果vs论文汇报结果(/前为fastNLP实现,后面为论文报道,-表示论文没有在该数据集上列出结果) diff --git a/reproduction/text_classification/train_char_cnn.py b/reproduction/text_classification/train_char_cnn.py index 0b8fc535..3482de70 100644 --- a/reproduction/text_classification/train_char_cnn.py +++ b/reproduction/text_classification/train_char_cnn.py @@ -203,7 +203,7 @@ callbacks.append( def train(model,datainfo,loss,metrics,optimizer,num_epochs=100): trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),batch_size=ops.batch_size, metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=[0,1,2], check_code_level=-1, - n_epochs=num_epochs) + n_epochs=num_epochs,callbacks=callbacks) print(trainer.train()) diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index 0228f207..9c05c334 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -1,4 +1,5 @@ import os +import sys import unittest from fastNLP import DataSet @@ -79,6 +80,16 @@ class TestDataSetMethods(unittest.TestCase): self.assertFalse("x" in dd.field_arrays) self.assertTrue("y" in dd.field_arrays) + def test_delete_instance(self): + dd = DataSet() + old_length = 2 + dd.add_field("x", [[1, 2, 3]] * old_length) + dd.add_field("y", [[1, 2, 3, 4]] * old_length) + dd.delete_instance(0) + self.assertEqual(len(dd), old_length-1) + dd.delete_instance(0) + self.assertEqual(len(dd), old_length-2) + def test_getitem(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ins_1, ins_0 = ds[0], ds[1] diff --git a/test/core/test_dist_trainer.py b/test/core/test_dist_trainer.py new file mode 100644 index 00000000..c6879634 --- /dev/null +++ b/test/core/test_dist_trainer.py @@ -0,0 +1,167 @@ +import unittest + +import numpy as np +import torch.cuda +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP import CrossEntropyLoss, BCELoss +from fastNLP import SGD +from fastNLP.core.dist_trainer import DistTrainer, get_local_rank +from fastNLP.models.base_model import NaiveClassifier +import shutil +import os +import subprocess +from argparse import ArgumentParser +from fastNLP.core.callback import EchoCallback +from fastNLP import AccuracyMetric + +def prepare_fake_dataset(): + mean = np.array([-3, -3]) + cov = np.array([[1, 0], [0, 1]]) + class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) + + mean = np.array([3, 3]) + cov = np.array([[1, 0], [0, 1]]) + class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) + + data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=0) for item in class_A] + + [Instance(x=[float(item[0]), float(item[1])], y=1) for item in class_B]) + return data_set + +def prepare_fake_dataset2(*args, size=100): + ys = np.random.randint(4, size=100, dtype=np.int64) + data = {'y': ys} + for arg in args: + data[arg] = np.random.randn(size, 5) + return DataSet(data=data) + +def set_rng_seed(seed): + np.random.seed(seed) + +def prepare_env(): + def prepare_fake_dataset(): + mean = np.array([-3, -3]) + cov = np.array([[1, 0], [0, 1]]) + class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) + + mean = np.array([3, 3]) + cov = np.array([[1, 0], [0, 1]]) + class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) + + data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + + [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) + return data_set + + data_set = prepare_fake_dataset() + data_set.set_input("x") + data_set.set_target("y") + model = NaiveClassifier(2, 1) + return data_set, model + +class TestDistTrainer(unittest.TestCase): + save_path = './save_cp' + + def run1(self): + # test distributed training + print('local rank', get_local_rank()) + set_rng_seed(100) + data_set = prepare_fake_dataset() + data_set.set_input("x", flag=True) + data_set.set_target("y", flag=True) + + model = NaiveClassifier(2, 2) + + trainer = DistTrainer( + model=model, train_data=data_set, optimizer=SGD(lr=0.1), + loss=CrossEntropyLoss(pred="predict", target="y"), + batch_size_per_gpu=8, n_epochs=3, print_every=50, save_path=self.save_path, + ) + trainer.train() + """ + # 应该正确运行 + """ + if trainer.is_master and os.path.exists(self.save_path): + shutil.rmtree(self.save_path) + + def run2(self): + # test fp16 with distributed training + print('local rank', get_local_rank()) + set_rng_seed(100) + data_set = prepare_fake_dataset() + data_set.set_input("x", flag=True) + data_set.set_target("y", flag=True) + + model = NaiveClassifier(2, 2) + + trainer = DistTrainer( + model=model, train_data=data_set, optimizer=SGD(lr=0.1), + loss=CrossEntropyLoss(pred="predict", target="y"), + batch_size_per_gpu=8, n_epochs=3, print_every=50, save_path=self.save_path, + fp16='O1' + ) + trainer.train() + """ + # 应该正确运行 + """ + if trainer.is_master and os.path.exists(self.save_path): + shutil.rmtree(self.save_path) + + def run3(self): + set_rng_seed(100) + data_set, model = prepare_env() + trainer = DistTrainer( + data_set, model, optimizer=None, + loss=BCELoss(pred="predict", target="y"), + n_epochs=3, print_every=50, + callbacks_all=[EchoCallback('callbacks_all')], + callbacks_master=[EchoCallback('callbacks_master')] + ) + trainer.train() + + def run4(self): + set_rng_seed(100) + data_set, model = prepare_env() + + train_set, dev_set = data_set.split(0.3) + + model = NaiveClassifier(2, 1) + + trainer = DistTrainer( + train_set, model, optimizer=SGD(lr=0.1), + loss=BCELoss(pred="predict", target="y"), + batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set, + metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None, + ) + trainer.train() + """ + # 应该正确运行 + """ + + def run_dist(self, run_id): + if torch.cuda.is_available(): + ngpu = min(2, torch.cuda.device_count()) + path = __file__ + cmd = ['python', '-m', 'torch.distributed.launch', + '--nproc_per_node', str(ngpu), path, '--test', str(run_id)] + print(' '.join(cmd)) + subprocess.check_call(cmd) + + def test_normal_run(self): + self.run_dist(1) + + def no_test_fp16(self): + self.run_dist(2) + + def test_callback(self): + self.run_dist(3) + + def test_dev_data(self): + self.run_dist(4) + +if __name__ == '__main__': + runner = TestDistTrainer() + parser = ArgumentParser() + parser.add_argument('--test', type=int) + args, _ = parser.parse_known_args() + if args.test and hasattr(runner, 'run%s'%args.test): + getattr(runner, 'run%s'%args.test)() diff --git a/test/core/test_field.py b/test/core/test_field.py index e9053f37..c46e2de2 100644 --- a/test/core/test_field.py +++ b/test/core/test_field.py @@ -170,22 +170,22 @@ class TestFieldArray(unittest.TestCase): def test_append(self): with self.assertRaises(Exception): - fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) + fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True, use_1st_ins_infer_dim_type=False) fa.append(0) with self.assertRaises(Exception): - fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True) + fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True, use_1st_ins_infer_dim_type=False) fa.append([1, 2, 3, 4, 5]) with self.assertRaises(Exception): - fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) + fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True, use_1st_ins_infer_dim_type=False) fa.append([]) with self.assertRaises(Exception): - fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) + fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True, use_1st_ins_infer_dim_type=False) fa.append(["str", 0, 0, 0, 1.89]) - fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True) + fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True, use_1st_ins_infer_dim_type=False) fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) self.assertEqual(len(fa), 3) self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 9c8a586c..236066d6 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -7,7 +7,7 @@ from fastNLP import AccuracyMetric from fastNLP.core.metrics import _pred_topk, _accuracy_topk from fastNLP.core.vocabulary import Vocabulary from collections import Counter -from fastNLP.core.metrics import SpanFPreRecMetric +from fastNLP.core.metrics import SpanFPreRecMetric, ExtractiveQAMetric def _generate_tags(encoding_type, number_labels=4): @@ -347,3 +347,46 @@ class TestUsefulFunctions(unittest.TestCase): _ = _pred_topk(np.random.randint(0, 3, size=(10, 1))) # 跑通即可 + + +class TestExtractiveQAMetric(unittest.TestCase): + + def test_cast_1(self): + qa_prediction = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, + -0.3782, 0.8240], + [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, -1.1563, + -0.3562, -1.4116], + [-1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, + -2.0023, 0.0075], + [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, + 0.3832, -0.1540], + [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, + -1.3508, -0.9513], + [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, + -0.0842, -0.4294]], + + [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, + -1.4138, -0.8853], + [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, + -1.0726, 0.0364], + [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, + -0.8836, -0.9320], + [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, + -1.6857, 1.1571], + [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, + 3.5837, 1.0184], + [1.6495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, + -0.9025, 0.0864]]]) + qa_prediction = qa_prediction.permute(1, 2, 0) + pred1, pred2 = qa_prediction.split(1, dim=-1) + pred1 = pred1.squeeze(-1) + pred2 = pred2.squeeze(-1) + target1 = torch.LongTensor([3, 0, 2, 4, 4, 0]) + target2 = torch.LongTensor([4, 1, 6, 8, 7, 1]) + metric = ExtractiveQAMetric() + metric.evaluate(pred1, pred2, target1, target2) + result = metric.get_metric() + truth = {'EM': 62.5, 'f_1': 72.5, 'noAns-f_1': 50.0, 'noAns-EM': 50.0, 'hasAns-f_1': 95.0, 'hasAns-EM': 75.0} + for k, v in truth.items(): + self.assertTrue(k in result) + self.assertEqual(v, result[k]) diff --git a/test/data_for_tests/sample_mnli.tsv b/test/data_for_tests/sample_mnli.tsv new file mode 100644 index 00000000..9a30b95b --- /dev/null +++ b/test/data_for_tests/sample_mnli.tsv @@ -0,0 +1,12 @@ +index promptID pairID genre sentence1_binary_parse sentence2_binary_parse sentence1_parse sentence2_parse sentence1 sentence2 label1 label2 label3 label4 label5 gold_label +0 63735 63735n slate ( ( The ( new rights ) ) ( are ( nice enough ) ) ) ( Everyone ( really ( likes ( the ( newest benefits ) ) ) ) ) (ROOT (S (NP (DT The) (JJ new) (NNS rights)) (VP (VBP are) (ADJP (JJ nice) (RB enough))))) (ROOT (S (NP (NN Everyone)) (VP (ADVP (RB really)) (VBZ likes) (NP (DT the) (JJS newest) (NNS benefits))))) The new rights are nice enough Everyone really likes the newest benefits neutral entailment neutral neutral neutral neutral +1 91383 91383c government ( ( This site ) ( ( includes ( ( ( ( a list ) ( of ( all ( award winners ) ) ) ) and ) ( ( a ( searchable database ) ) ( of ( Government ( Executive articles ) ) ) ) ) ) . ) ) ( ( ( The ( Government ( Executive articles ) ) ) ( housed ( on ( the website ) ) ) ) ( ( ( are not ) ( able ( to ( be searched ) ) ) ) . ) ) (ROOT (S (NP (DT This) (NN site)) (VP (VBZ includes) (NP (NP (NP (DT a) (NN list)) (PP (IN of) (NP (DT all) (NN award) (NNS winners)))) (CC and) (NP (NP (DT a) (JJ searchable) (NN database)) (PP (IN of) (NP (NNP Government) (NNP Executive) (NNS articles)))))) (. .))) (ROOT (S (NP (NP (DT The) (NNP Government) (NNP Executive) (NNS articles)) (VP (VBN housed) (PP (IN on) (NP (DT the) (NN website))))) (VP (VBP are) (RB not) (ADJP (JJ able) (S (VP (TO to) (VP (VB be) (ADJP (JJ searched))))))) (. .))) This site includes a list of all award winners and a searchable database of Government Executive articles. The Government Executive articles housed on the website are not able to be searched. contradiction contradiction contradiction contradiction contradiction contradiction +2 755 755e telephone ( ( ( ( uh ( i ( ( do n't ) ( know ( ( i i ) ( have ( ( mixed emotions ) ( about ( him ( ( uh sometimes ) ( i ( like him ) ) ) ) ) ) ) ) ) ) ) ) but ) ( ( at ( the ( same times ) ) ) ( i ( love ( to ( see somebody ) ) ) ) ) ) ( beat him ) ) ( I ( ( ( ( ( ( like him ) ( for ( the ( most part ) ) ) ) , ) but ) ( ( would still ) ( enjoy ( seeing ( someone ( beat him ) ) ) ) ) ) . ) ) (ROOT (SINV (S (S (INTJ (UH uh)) (NP (FW i)) (VP (VBP do) (RB n't) (VP (VB know) (NP (NP (FW i) (FW i)) (SBAR (S (VP (VBP have) (VP (VBN mixed) (NP (NNS emotions)) (PP (IN about) (S (NP (PRP him)) (VP (VBG uh) (ADVP (RB sometimes)) (NP (NP (FW i)) (PP (IN like) (NP (PRP him))))))))))))))) (CC but) (S (PP (IN at) (NP (DT the) (JJ same) (NNS times))) (NP (FW i)) (VP (VBP love) (S (VP (TO to) (VP (VB see) (NP (NN somebody)))))))) (VP (VBD beat)) (NP (PRP him)))) (ROOT (S (NP (PRP I)) (VP (VP (VBP like) (NP (PRP him)) (PP (IN for) (NP (DT the) (JJS most) (NN part)))) (, ,) (CC but) (VP (MD would) (ADVP (RB still)) (VP (VB enjoy) (S (VP (VBG seeing) (S (NP (NN someone)) (VP (VB beat) (NP (PRP him))))))))) (. .))) uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him I like him for the most part, but would still enjoy seeing someone beat him. entailment entailment entailment entailment entailment entailment +3 78013 78013c telephone ( yeah ( ( i i ) ( think ( ( my ( favorite restaurant ) ) ( ( is always ) ( been ( ( the ( one closest ) ) ( you ( ( know ( the closest ) ) ( ( as long ) ( as ( it ( 's ( it ( meets ( ( the ( minimum criteria ) ) ( you ( know ( of ( good food ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ( ( My ( favorite restaurants ) ) ( ( ( ( are always ) ( ( ( ( ( at least ) a ) hundred ) miles ) away ) ) ( from ( my house ) ) ) . ) ) (ROOT (S (VP (VB yeah) (NP (NP (FW i) (FW i)) (SBAR (S (VP (VBP think) (SBAR (S (NP (PRP$ my) (JJ favorite) (NN restaurant)) (VP (VBZ is) (ADVP (RB always)) (VP (VBN been) (NP (NP (DT the) (CD one) (JJS closest)) (SBAR (S (NP (PRP you)) (VP (VBP know) (NP (DT the) (JJS closest)) (ADVP (ADVP (RB as) (RB long)) (SBAR (IN as) (S (NP (PRP it)) (VP (VBZ 's) (SBAR (S (NP (PRP it)) (VP (VBZ meets) (NP (NP (DT the) (JJ minimum) (NNS criteria)) (SBAR (S (NP (PRP you)) (VP (VBP know) (PP (IN of) (NP (JJ good) (NN food))))))))))))))))))))))))))))) (ROOT (S (NP (PRP$ My) (JJ favorite) (NNS restaurants)) (VP (VBP are) (ADVP (RB always)) (ADVP (NP (QP (IN at) (JJS least) (DT a) (CD hundred)) (NNS miles)) (RB away)) (PP (IN from) (NP (PRP$ my) (NN house)))) (. .))) yeah i i think my favorite restaurant is always been the one closest you know the closest as long as it's it meets the minimum criteria you know of good food My favorite restaurants are always at least a hundred miles away from my house. contradiction contradiction contradiction contradiction contradiction contradiction +4 96377 96377c telephone ( i ( ( do n't ) ( know ( um ( do ( you ( do ( ( a lot ) ( of camping ) ) ) ) ) ) ) ) ) ( I ( ( know exactly ) . ) ) (ROOT (S (NP (FW i)) (VP (VBP do) (RB n't) (VP (VB know) (SBAR (S (NP (NN um)) (VP (VBP do) (SBAR (S (NP (PRP you)) (VP (VBP do) (NP (NP (DT a) (NN lot)) (PP (IN of) (NP (NN camping)))))))))))))) (ROOT (S (NP (PRP I)) (VP (VBP know) (ADVP (RB exactly))) (. .))) i don't know um do you do a lot of camping I know exactly. contradiction contradiction contradiction contradiction contradiction contradiction +5 139749 139749c telephone ( well ( that ( would ( be ( ( a help ) ( i ( wish ( they ( would ( do ( that ( ( ( here ( we ( have ( got ( so ( ( little ( landfill space ) ) ( left ( that ( we ( 're ( going ( to ( ( run out ) ( before ( ( the end ) ( of ( this decade ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) and ) ( it ( ( 's really ) ( going ( to be ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ( We ( ( have ( plenty ( of ( space ( in ( the landfill ) ) ) ) ) ) . ) ) (ROOT (FRAG (ADVP (RB well)) (SBAR (WHNP (WDT that)) (S (VP (MD would) (VP (VB be) (NP (NP (DT a) (NN help)) (SBAR (S (NP (FW i)) (VP (VBP wish) (SBAR (S (NP (PRP they)) (VP (MD would) (VP (VB do) (SBAR (IN that) (S (S (ADVP (RB here)) (NP (PRP we)) (VP (VBP have) (VP (VBN got) (SBAR (IN so) (S (NP (JJ little) (NN landfill) (NN space)) (VP (VBD left) (SBAR (IN that) (S (NP (PRP we)) (VP (VBP 're) (VP (VBG going) (S (VP (TO to) (VP (VB run) (PRT (RP out)) (PP (IN before) (NP (NP (DT the) (NN end)) (PP (IN of) (NP (DT this) (NN decade)))))))))))))))))) (CC and) (S (NP (PRP it)) (VP (VBZ 's) (ADVP (RB really)) (VP (VBG going) (S (VP (TO to) (VP (VB be))))))))))))))))))))))) (ROOT (S (NP (PRP We)) (VP (VBP have) (NP (NP (RB plenty)) (PP (IN of) (NP (NP (NN space)) (PP (IN in) (NP (DT the) (NN landfill))))))) (. .))) well that would be a help i wish they would do that here we have got so little landfill space left that we're going to run out before the end of this decade and it's really going to be We have plenty of space in the landfill. contradiction contradiction contradiction contradiction contradiction contradiction +6 101415 101415c telephone ( yeah ( ( ( i know ) and ) ( i ( did ( that ( ( ( all ( through college ) ) and ) ( it ( worked too ) ) ) ) ) ) ) ) ( I ( ( ( did ( that all ) ) ( through college ) ) ( but ( it ( never worked ) ) ) ) ) (ROOT (S (VP (VB yeah) (S (S (NP (FW i)) (VP (VBP know))) (CC and) (S (NP (FW i)) (VP (VBD did) (SBAR (IN that) (S (S (NP (DT all)) (PP (IN through) (NP (NN college)))) (CC and) (S (NP (PRP it)) (VP (VBD worked) (ADVP (RB too)))))))))))) (ROOT (S (NP (PRP I)) (VP (VBD did) (ADVP (IN that) (DT all)) (PP (IN through) (NP (NN college))) (SBAR (CC but) (S (NP (PRP it)) (ADVP (RB never)) (VP (VBD worked))))))) yeah i know and i did that all through college and it worked too I did that all through college but it never worked contradiction contradiction contradiction contradiction contradiction contradiction +7 93958 93958n travel ( ( ( ( ( Calcutta ( seems ( to ( be ( ( the ( only ( other ( production center ) ) ) ) ( ( having ( any pretensions ) ) ( to ( ( artistic creativity ) ( at all ) ) ) ) ) ) ) ) ) , ) but ) ( ironically ( you ( ( 're actually ) ( ( more ( likely ( to ( see ( ( the works ) ( of ( ( ( Satyajit Ray ) or ) ( ( Mrinal Sen ) ( shown ( in ( Europe ( or ( North America ) ) ) ) ) ) ) ) ) ) ) ) ) ( than ( in ( India itself ) ) ) ) ) ) ) ) . ) ( ( Most ( of ( ( Mrinal ( Sen 's ) ) work ) ) ) ( ( can ( be ( found ( in ( European collections ) ) ) ) ) . ) ) (ROOT (S (S (NP (NNP Calcutta)) (VP (VBZ seems) (S (VP (TO to) (VP (VB be) (NP (NP (DT the) (JJ only) (JJ other) (NN production) (NN center)) (VP (VBG having) (NP (DT any) (NNS pretensions)) (PP (TO to) (NP (NP (JJ artistic) (NN creativity)) (ADVP (IN at) (DT all))))))))))) (, ,) (CC but) (S (ADVP (RB ironically)) (NP (PRP you)) (VP (VBP 're) (ADVP (RB actually)) (ADJP (ADJP (RBR more) (JJ likely) (S (VP (TO to) (VP (VB see) (NP (NP (DT the) (NNS works)) (PP (IN of) (NP (NP (NNP Satyajit) (NNP Ray)) (CC or) (NP (NP (NNP Mrinal) (NNP Sen)) (VP (VBN shown) (PP (IN in) (NP (NNP Europe) (CC or) (NNP North) (NNP America)))))))))))) (ADVP (IN than) (PP (IN in) (S (VP (VBG India) (NP (PRP itself))))))))) (. .))) (ROOT (S (NP (NP (JJS Most)) (PP (IN of) (NP (NP (NNP Mrinal) (NNP Sen) (POS 's)) (NN work)))) (VP (MD can) (VP (VB be) (VP (VBN found) (PP (IN in) (NP (JJ European) (NNS collections)))))) (. .))) Calcutta seems to be the only other production center having any pretensions to artistic creativity at all, but ironically you're actually more likely to see the works of Satyajit Ray or Mrinal Sen shown in Europe or North America than in India itself. Most of Mrinal Sen's work can be found in European collections. neutral neutral entailment neutral neutral neutral +8 12567 12567c slate ( ( If ( ( that investor ) ( were ( willing ( to ( pay ( extra ( for ( ( the security ) ( of ( limited downside ) ) ) ) ) ) ) ) ) ) ) ( , ( she ( ( could ( ( buy ( put options ) ) ( with ( ( a ( strike price ) ) ( of ( ( ( $ 98 ) , ) ( which ( would ( ( ( lock ( in ( ( her profit ) ( on ( ( the shares ) ( at ( $ 18 ) ) ) ) ) ) ) , ) ( less ( whatever ( ( the options ) cost ) ) ) ) ) ) ) ) ) ) ) ) . ) ) ) ) ( ( THe ( strike price ) ) ( ( could ( be ( $ 8 ) ) ) . ) ) (ROOT (S (SBAR (IN If) (S (NP (DT that) (NN investor)) (VP (VBD were) (ADJP (JJ willing) (S (VP (TO to) (VP (VB pay) (NP (NP (JJ extra)) (PP (IN for) (NP (NP (DT the) (NN security)) (PP (IN of) (NP (JJ limited) (NN downside))))))))))))) (, ,) (NP (PRP she)) (VP (MD could) (VP (VB buy) (NP (NN put) (NNS options)) (PP (IN with) (NP (NP (DT a) (NN strike) (NN price)) (PP (IN of) (NP (NP ($ $) (CD 98)) (, ,) (SBAR (WHNP (WDT which)) (S (VP (MD would) (VP (VB lock) (PP (IN in) (NP (NP (PRP$ her) (NN profit)) (PP (IN on) (NP (NP (DT the) (NNS shares)) (PP (IN at) (NP ($ $) (CD 18))))))) (, ,) (ADVP (ADVP (RBR less)) (SBAR (WHNP (WDT whatever)) (S (NP (DT the) (NNS options)) (VP (VBD cost))))))))))))))) (. .))) (ROOT (S (NP (NNP THe) (NN strike) (NN price)) (VP (MD could) (VP (VB be) (NP ($ $) (CD 8)))) (. .))) If that investor were willing to pay extra for the security of limited downside, she could buy put options with a strike price of $98, which would lock in her profit on the shares at $18, less whatever the options cost. THe strike price could be $8. contradiction contradiction contradiction contradiction contradiction contradiction +9 117487 117487n slate ( ( 3 -RRB- ) ( ( Dare ( you ( ( ( rise ( to ( ( ( ( the occasion ) , ) ( like Raskolnikov ) ) , ) ) ) and ) ( reject ( ( the ( petty rules ) ) ( that ( govern ( lesser men ) ) ) ) ) ) ) ) ? ) ) ( ( ( Would you ) ( ( ( rise up ) and ) ( defeaat ( ( all ( evil lords ) ) ( in ( the town ) ) ) ) ) ) ? ) (ROOT (S (LST (LS 3) (-RRB- -RRB-)) (VP (VB Dare) (S (NP (PRP you)) (VP (VP (VB rise) (PP (TO to) (NP (NP (DT the) (NN occasion)) (, ,) (PP (IN like) (NP (NNP Raskolnikov))) (, ,)))) (CC and) (VP (VB reject) (NP (NP (DT the) (JJ petty) (NNS rules)) (SBAR (WHNP (WDT that)) (S (VP (VBP govern) (NP (JJR lesser) (NNS men)))))))))) (. ?))) (ROOT (SQ (MD Would) (NP (PRP you)) (VP (VP (VB rise) (PRT (RP up))) (CC and) (VP (VB defeaat) (NP (NP (DT all) (JJ evil) (NNS lords)) (PP (IN in) (NP (DT the) (NN town)))))) (. ?))) 3) Dare you rise to the occasion, like Raskolnikov, and reject the petty rules that govern lesser men? Would you rise up and defeaat all evil lords in the town? neutral neutral neutral neutral neutral neutral +10 9616 9616c travel ( ( The ( ( most important ) directions ) ) ( ( ( are ( simply ( ( up and ) up ) ) ) ( ( ( ( ( ( ( ( leads eventually ) ( to ( the cathedral ) ) ) and ) ( fortress ( commanding ( the hilltop ) ) ) ) , ) and ) down ) ( inevitably ( ( leads ( to ( one ( of ( three gates ) ) ) ) ) ( through ( ( the wall ) ( to ( the ( new town ) ) ) ) ) ) ) ) ) . ) ) ( Go ( ( downwards ( to ( one ( of ( ( ( the gates ) , ) ( ( all ( of which ) ) ( will ( ( lead you ) ( into ( the cathedral ) ) ) ) ) ) ) ) ) ) . ) ) (ROOT (S (NP (DT The) (ADJP (RBS most) (JJ important)) (NNS directions)) (VP (VBP are) (PRN (ADVP (RB simply)) (ADVP (RB up) (CC and) (RB up))) (VP (VP (VBZ leads) (ADVP (RB eventually)) (PP (TO to) (NP (DT the) (NN cathedral)))) (CC and) (VP (VBZ fortress) (NP (JJ commanding) (DT the) (NN hilltop))) (, ,) (CC and) (ADVP (RB down)) (VP (ADVP (RB inevitably)) (VBZ leads) (PP (TO to) (NP (NP (CD one)) (PP (IN of) (NP (CD three) (NNS gates))))) (PP (IN through) (NP (NP (DT the) (NN wall)) (PP (TO to) (NP (DT the) (JJ new) (NN town)))))))) (. .))) (ROOT (S (NP (NNP Go)) (VP (VBZ downwards) (PP (TO to) (NP (NP (CD one)) (PP (IN of) (NP (NP (DT the) (NNS gates)) (, ,) (SBAR (WHNP (DT all) (WHPP (IN of) (WHNP (WDT which)))) (S (VP (MD will) (VP (VB lead) (NP (PRP you)) (PP (IN into) (NP (DT the) (NN cathedral)))))))))))) (. .))) The most important directions are simply up and up leads eventually to the cathedral and fortress commanding the hilltop, and down inevitably leads to one of three gates through the wall to the new town. Go downwards to one of the gates, all of which will lead you into the cathedral. contradiction contradiction entailment contradiction contradiction contradiction diff --git a/test/embeddings/__init__.py b/test/embeddings/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/embeddings/test_elmo_embedding.py b/test/embeddings/test_elmo_embedding.py new file mode 100644 index 00000000..a087f0a4 --- /dev/null +++ b/test/embeddings/test_elmo_embedding.py @@ -0,0 +1,21 @@ + +import unittest +from fastNLP import Vocabulary +from fastNLP.embeddings import ElmoEmbedding +import torch +import os + +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") +class TestDownload(unittest.TestCase): + def test_download_small(self): + # import os + vocab = Vocabulary().add_word_lst("This is a test .".split()) + elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='en-small') + words = torch.LongTensor([[0, 1, 2]]) + print(elmo_embed(words).size()) + + +# 首先保证所有权重可以加载;上传权重;验证可以下载 + + + diff --git a/test/embeddings/test_static_embedding.py b/test/embeddings/test_static_embedding.py index 0c8fc739..ca97dd75 100644 --- a/test/embeddings/test_static_embedding.py +++ b/test/embeddings/test_static_embedding.py @@ -3,13 +3,110 @@ import unittest from fastNLP.embeddings import StaticEmbedding from fastNLP import Vocabulary import torch +import os class TestRandomSameEntry(unittest.TestCase): def test_same_vector(self): - vocab = Vocabulary().add_word_lst(["The", "the", "THE"]) + vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"]) embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True) - words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE"]]]) + words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]]) words = embed(words) embed_0 = words[0, 0] - for i in range(1, words.size(1)): + for i in range(1, 3): assert torch.sum(embed_0==words[0, i]).eq(len(embed_0)) + embed_0 = words[0, 3] + for i in range(3, 5): + assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0)) + + @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") + def test_same_vector2(self): + vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"]) + embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.6B.100d.txt', + lower=True) + words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]]) + words = embed(words) + embed_0 = words[0, 0] + for i in range(1, 3): + assert torch.sum(embed_0==words[0, i]).eq(len(embed_0)) + embed_0 = words[0, 3] + for i in range(3, 5): + assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0)) + + @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") + def test_same_vector3(self): + # 验证lower + word_lst = ["The", "the"] + no_create_word_lst = ['of', 'Of', 'With', 'with'] + vocab = Vocabulary().add_word_lst(word_lst) + vocab.add_word_lst(no_create_word_lst, no_create_entry=True) + embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', + lower=True) + words = torch.LongTensor([[vocab.to_index(word) for word in word_lst+no_create_word_lst]]) + words = embed(words) + + lowered_word_lst = [word.lower() for word in word_lst] + lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst] + lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst) + lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True) + lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', + lower=False) + lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst+lowered_no_create_word_lst]]) + lowered_words = lowered_embed(lowered_words) + + all_words = word_lst + no_create_word_lst + + for idx, (word_i, word_j) in enumerate(zip(words[0], lowered_words[0])): + with self.subTest(idx=idx, word=all_words[idx]): + assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size) + + @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") + def test_same_vector4(self): + # 验证在有min_freq下的lower + word_lst = ["The", "the", "the", "The", "a", "A"] + no_create_word_lst = ['of', 'Of', "Of", "of", 'With', 'with'] + all_words = word_lst[:-2] + no_create_word_lst[:-2] + vocab = Vocabulary(min_freq=2).add_word_lst(word_lst) + vocab.add_word_lst(no_create_word_lst, no_create_entry=True) + embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', + lower=True) + words = torch.LongTensor([[vocab.to_index(word) for word in all_words]]) + words = embed(words) + + lowered_word_lst = [word.lower() for word in word_lst] + lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst] + lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst) + lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True) + lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', + lower=False) + lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]]) + lowered_words = lowered_embed(lowered_words) + + for idx in range(len(all_words)): + word_i, word_j = words[0, idx], lowered_words[0, idx] + with self.subTest(idx=idx, word=all_words[idx]): + assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size) + + @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") + def test_same_vector5(self): + # 检查通过使用min_freq后的word是否内容一致 + word_lst = ["they", "the", "they", "the", 'he', 'he', "a", "A"] + no_create_word_lst = ['of', "of", "she", "she", 'With', 'with'] + all_words = word_lst[:-2] + no_create_word_lst[:-2] + vocab = Vocabulary().add_word_lst(word_lst) + vocab.add_word_lst(no_create_word_lst, no_create_entry=True) + embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', + lower=False, min_freq=2) + words = torch.LongTensor([[vocab.to_index(word) for word in all_words]]) + words = embed(words) + + min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst) + min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True) + min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', + lower=False) + min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word.lower()) for word in all_words]]) + min_freq_words = min_freq_embed(min_freq_words) + + for idx in range(len(all_words)): + word_i, word_j = words[0, idx], min_freq_words[0, idx] + with self.subTest(idx=idx, word=all_words[idx]): + assert torch.sum(word_i == word_j).eq(min_freq_embed.embed_size) \ No newline at end of file diff --git a/test/io/loader/test_classification_loader.py b/test/io/loader/test_classification_loader.py new file mode 100644 index 00000000..28f08921 --- /dev/null +++ b/test/io/loader/test_classification_loader.py @@ -0,0 +1,19 @@ + +import unittest +from fastNLP.io.loader.classification import YelpFullLoader +from fastNLP.io.loader.classification import YelpPolarityLoader +from fastNLP.io.loader.classification import IMDBLoader +from fastNLP.io.loader.classification import SST2Loader +from fastNLP.io.loader.classification import SSTLoader +import os + +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") +class TestDownload(unittest.TestCase): + def test_download(self): + for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]: + loader().download() + + def test_load(self): + for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]: + data_bundle = loader().load() + print(data_bundle) diff --git a/test/io/loader/test_matching_loader.py b/test/io/loader/test_matching_loader.py new file mode 100644 index 00000000..5c1a91f1 --- /dev/null +++ b/test/io/loader/test_matching_loader.py @@ -0,0 +1,22 @@ + +import unittest +from fastNLP.io.loader.matching import RTELoader +from fastNLP.io.loader.matching import QNLILoader +from fastNLP.io.loader.matching import SNLILoader +from fastNLP.io.loader.matching import QuoraLoader +from fastNLP.io.loader.matching import MNLILoader +import os + +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") +class TestDownload(unittest.TestCase): + def test_download(self): + for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]: + loader().download() + with self.assertRaises(Exception): + QuoraLoader().load() + + def test_load(self): + for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]: + data_bundle = loader().load() + print(data_bundle) + diff --git a/test/io/pipe/test_classification.py b/test/io/pipe/test_classification.py new file mode 100644 index 00000000..39dc71e0 --- /dev/null +++ b/test/io/pipe/test_classification.py @@ -0,0 +1,13 @@ +import unittest +import os + +from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe + +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") +class TestPipe(unittest.TestCase): + def test_process_from_file(self): + for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: + with self.subTest(pipe=pipe): + print(pipe) + data_bundle = pipe(tokenizer='raw').process_from_file() + print(data_bundle) diff --git a/test/io/pipe/test_matching.py b/test/io/pipe/test_matching.py new file mode 100644 index 00000000..c057bb0c --- /dev/null +++ b/test/io/pipe/test_matching.py @@ -0,0 +1,26 @@ + +import unittest +import os + +from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe +from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe + + +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") +class TestPipe(unittest.TestCase): + def test_process_from_file(self): + for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]: + with self.subTest(pipe=pipe): + print(pipe) + data_bundle = pipe(tokenizer='raw').process_from_file() + print(data_bundle) + + +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") +class TestBertPipe(unittest.TestCase): + def test_process_from_file(self): + for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]: + with self.subTest(pipe=pipe): + print(pipe) + data_bundle = pipe(tokenizer='raw').process_from_file() + print(data_bundle) diff --git a/test/io/test_data_loader.py b/test/io/test_data_loader.py new file mode 100644 index 00000000..5b1bb749 --- /dev/null +++ b/test/io/test_data_loader.py @@ -0,0 +1,15 @@ +import unittest + +from fastNLP.core.const import Const +from fastNLP.io.data_loader import MNLILoader + + +class TestDataLoader(unittest.TestCase): + + def test_mnli_loader(self): + ds = MNLILoader().process('test/data_for_tests/sample_mnli.tsv', + to_lower=True, get_index=True, seq_len_type='mask') + self.assertTrue('train' in ds.datasets) + self.assertTrue(len(ds.datasets) == 1) + self.assertTrue(len(ds.datasets['train']) == 11) + self.assertTrue(isinstance(ds.datasets['train'][0][Const.INPUT_LENS(0)], list)) diff --git a/test/models/test_snli.py b/test/models/test_snli.py new file mode 100644 index 00000000..7a588a4c --- /dev/null +++ b/test/models/test_snli.py @@ -0,0 +1,9 @@ +import unittest +from .model_runner import * +from fastNLP.models.snli import ESIM + + +class TestSNLIModel(unittest.TestCase): + def test_snli(self): + model = ESIM((VOCAB_SIZE, 10), num_labels=NUM_CLS, dropout_rate=0) + RUNNER.run_model_with_task(NLI, model) diff --git a/test/modules/__init__.py b/test/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/modules/decoder/__init__.py b/test/modules/decoder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/modules/encoder/test_bert.py b/test/modules/decoder/test_bert.py similarity index 100% rename from test/modules/encoder/test_bert.py rename to test/modules/decoder/test_bert.py