diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..2b2b2b35 --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +.gitignore + +.DS_Store +.ipynb_checkpoints +*.pyc +__pycache__ +*.swp +.vscode/ +.idea/** + +caches + +# fitlog +.fitlog +logs/ +.fitconfig diff --git a/.travis.yml b/.travis.yml index 559fc86e..210d158a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ install: - pip install pytest-cov # command to run tests script: - - pytest --cov=./ + - pytest --cov=./ test/ after_success: - bash <(curl -s https://codecov.io/bash) diff --git a/docs/source/index.rst b/docs/source/index.rst index 03a192dc..da510437 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -56,6 +56,7 @@ fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models 快速入门 详细指南 科研指南 + 注释语法 API 文档 ------------- diff --git a/docs/source/user/example.rst b/docs/source/user/example.rst new file mode 100644 index 00000000..55588c79 --- /dev/null +++ b/docs/source/user/example.rst @@ -0,0 +1,104 @@ +====== +大标题 +====== + +.. note:: + 中文标题需要符号的数量至少是中文字数的两倍 + +.. warning:: + 符号的数量只可以多,不可以少。 + +小标题1 +########### + +小标题2 +********* + +小标题3(正常使用) +======================== + +小标题4 +------------------- + +参考 http://docutils.sourceforge.net/docs/user/rst/quickref.html + +常见语法 +============ + +*emphasis* + +**strong** + +`text` + +``inline literal`` + +http://docutils.sf.net/ 孤立的网址会自动生成链接 + +显示为特定的文字的链接 `sohu `_ + +突出显示的 + 上面文字 + +正常缩进 + + 形成锻炼 + + + +特殊模块 +============ + +选项会自动识别 + +-v An option +-o file Same with value +--delta A long option +--delta=len Same with value + + +图片 + +.. image:: ../figures/procedures.PNG + :height: 200 + :width: 560 + :scale: 50 + :alt: alternate text + :align: center + +显示一个冒号的代码块:: + + 中间要空一行 + +:: + + 不显示冒号的代码块 + +.. code-block:: python + :linenos: + :emphasize-lines: 1,3 + + print("专业的代码块") + print("") + print("有行号和高亮") + +数学块 + +.. math:: + + H_2O + Na = NaOH + H_2 \uparrow + + +各种连接 +=========== + +:doc:`/user/with_fitlog` + +:mod:`~fastNLP.core.batch` + +:class:`~fastNLP.Batch` + +~表示指显示最后一项 + +:meth:`fastNLP.DataSet.apply` + diff --git a/docs/source/user/quickstart.rst b/docs/source/user/quickstart.rst index 43056a26..12e541b7 100644 --- a/docs/source/user/quickstart.rst +++ b/docs/source/user/quickstart.rst @@ -49,7 +49,7 @@ .. code-block:: python from fastNLP.models import CNNText - model = CNNText((len(vocab),50), num_classes=5, padding=2, dropout=0.1) + model = CNNText((len(vocab),50), num_classes=5, dropout=0.1) :class:`~fastNLP.models.CNNText` 的网络结构如下:: @@ -121,4 +121,4 @@ In Epoch:6/Step:12, got best dev performance:AccuracyMetric: acc=0.8 Reloaded the best model. 
-这份教程只是简单地介绍了使用 fastNLP 工作的流程,具体的细节分析见 :doc:`/user/tutorial_one` \ No newline at end of file +这份教程只是简单地介绍了使用 fastNLP 工作的流程,具体的细节分析见 :doc:`/user/tutorial_one` diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index c67e5919..e666f65f 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -12,7 +12,11 @@ fastNLP 中最常用的组件可以直接从 fastNLP 包中 import ,他们的 __all__ = [ "Instance", "FieldArray", - "Batch", + + "DataSetIter", + "BatchIter", + "TorchLoaderIter", + "Vocabulary", "DataSet", "Const", diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index d6ab8983..792bff66 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -14,7 +14,7 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa 介绍core 的子模块的分工,好像必要性不大 """ -from .batch import Batch +from .batch import DataSetIter, BatchIter, TorchLoaderIter from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC from .const import Const from .dataset import DataSet diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 109d4fe9..ca48a8e1 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -3,7 +3,9 @@ batch 模块实现了 fastNLP 所需的 Batch 类。 """ __all__ = [ - "Batch" + "BatchIter", + "DataSetIter", + "TorchLoaderIter", ] import atexit @@ -12,8 +14,11 @@ from queue import Empty, Full import numpy as np import torch import torch.multiprocessing as mp +import torch.utils.data +from numbers import Number -from .sampler import RandomSampler +from .sampler import SequentialSampler +from .dataset import DataSet _python_is_exit = False @@ -26,160 +31,163 @@ def _set_python_is_exit(): atexit.register(_set_python_is_exit) -class Batch(object): - """ - 别名::class:`fastNLP.Batch` :class:`fastNLP.core.batch.Batch` - - Batch 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出, - 组成 `x` 和 `y`:: - - batch = Batch(data_set, batch_size=16, sampler=SequentialSampler()) - num_batch = len(batch) - for batch_x, batch_y in batch: - # do stuff ... - - :param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 - :param int batch_size: 取出的batch大小 - :param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.RandomSampler`. - - Default: ``None`` - :param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`. - - Default: ``False`` - :param bool prefetch: 若为 ``True`` 使用多进程预先取出下一batch. 
- - Default: ``False`` - """ - - def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False): +class DataSetGetter: + def __init__(self, dataset: DataSet, as_numpy=False): self.dataset = dataset - self.batch_size = batch_size - if sampler is None: - sampler = RandomSampler() - self.sampler = sampler + self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input} + self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target} self.as_numpy = as_numpy - self.idx_list = None - self.curidx = 0 - self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0) - self.cur_batch_indices = None - self.prefetch = prefetch - self.lengths = 0 - - def fetch_one(self): - if self.curidx >= len(self.idx_list): - return None + self.idx_list = list(range(len(dataset))) + + def __getitem__(self, idx: int): + # mapping idx to sampled idx + idx = self.idx_list[idx] + inputs = {n:f.get(idx) for n, f in self.inputs.items()} + targets = {n:f.get(idx) for n, f in self.targets.items()} + return idx, inputs, targets + + def __len__(self): + return len(self.dataset) + + def collate_fn(self, batch: list): + batch_x = {n:[] for n in self.inputs.keys()} + batch_y = {n:[] for n in self.targets.keys()} + indices = [] + for idx, x, y in batch: + indices.append(idx) + for n, v in x.items(): + batch_x[n].append(v) + for n, v in y.items(): + batch_y[n].append(v) + + def pad_batch(batch_dict, field_array): + for n, vlist in batch_dict.items(): + f = field_array[n] + if f.padder is None: + batch_dict[n] = np.array(vlist) + else: + data = f.pad(vlist) + if not self.as_numpy: + try: + data, flag = _to_tensor(data, f.dtype) + except TypeError as e: + print(f"Field {n} cannot be converted to torch.tensor.") + raise e + batch_dict[n] = data + return batch_dict + + return (indices, + pad_batch(batch_x, self.inputs), + pad_batch(batch_y, self.targets)) + + def set_idx_list(self, idx_list): + if len(idx_list) != len(self.idx_list): + raise ValueError + self.idx_list = idx_list + + def __getattr__(self, item): + if hasattr(self.dataset, item): + return getattr(self.dataset, item) else: - endidx = min(self.curidx + self.batch_size, len(self.idx_list)) - batch_x, batch_y = {}, {} - - indices = self.idx_list[self.curidx:endidx] - self.cur_batch_indices = indices - - for field_name, field in self.dataset.get_all_fields().items(): - if field.is_target or field.is_input: - batch = field.get(indices) - if not self.as_numpy and field.padder is not None: - batch = _to_tensor(batch, field.dtype) - if field.is_target: - batch_y[field_name] = batch - if field.is_input: - batch_x[field_name] = batch - - self.curidx = endidx - return batch_x, batch_y - + raise AttributeError("'DataSetGetter' object has no attribute '{}'".format(item)) + + +class SamplerAdapter(torch.utils.data.Sampler): + def __init__(self, sampler, dataset): + self.sampler = sampler + self.dataset = dataset + def __iter__(self): - """ - Iterate on dataset, fetch batch data. 
Fetch process don't block the iterate process - :return: - """ - if self.prefetch: - return self._run_batch_iter(self) - - def batch_iter(): - self.init_iter() - while 1: - res = self.fetch_one() - if res is None: - break - yield res - - return batch_iter() - + return iter(self.sampler(self.dataset)) + + +class BatchIter: + def __init__(self): + self.dataiter = None + self.num_batches = None + self.cur_batch_indices = None + self.batch_size = None + def init_iter(self): - self.idx_list = self.sampler(self.dataset) - self.curidx = 0 - self.lengths = self.dataset.get_length() - + pass + + @staticmethod + def get_num_batches(num_samples, batch_size, drop_last): + num_batches = num_samples // batch_size + if not drop_last and (num_samples % batch_size > 0): + num_batches += 1 + return num_batches + + def __iter__(self): + self.init_iter() + for indices, batch_x, batch_y in self.dataiter: + self.cur_batch_indices = indices + yield batch_x, batch_y + + def get_batch_indices(self): + return self.cur_batch_indices + def __len__(self): return self.num_batches - - def get_batch_indices(self): - """ - 取得当前batch在DataSet中所在的index下标序列 - :return list(int) indexes: 下标序列 - """ - return self.cur_batch_indices - - @staticmethod - def _run_fetch(batch, q): - try: - global _python_is_exit - batch.init_iter() - # print('start fetch') - while 1: - res = batch.fetch_one() - # print('fetch one') - while 1: - try: - q.put(res, timeout=3) - break - except Full: - if _python_is_exit: - return - if res is None: - # print('fetch done, waiting processing') - break - # print('fetch exit') - except Exception as e: - q.put(e) - finally: - q.join() - - @staticmethod - def _run_batch_iter(batch): - q = mp.JoinableQueue(maxsize=10) - fetch_p = mp.Process(target=Batch._run_fetch, args=(batch, q)) - fetch_p.daemon = True - fetch_p.start() - # print('fork fetch process') - while 1: - try: - res = q.get(timeout=1) - q.task_done() - # print('get fetched') - if res is None: - break - elif isinstance(res, Exception): - raise res - yield res - except Empty as e: - if fetch_p.is_alive(): - continue - else: - break - fetch_p.terminate() - fetch_p.join() - # print('iter done') + @property + def dataset(self): + return self.dataiter.dataset + + +class DataSetIter(BatchIter): + def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, + num_workers=0, pin_memory=False, drop_last=False, + timeout=0, worker_init_fn=None): + super().__init__() + assert isinstance(dataset, DataSet) + sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) + dataset = DataSetGetter(dataset, as_numpy) + collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None + self.dataiter = torch.utils.data.DataLoader( + dataset=dataset, batch_size=batch_size, sampler=sampler, + collate_fn=collate_fn, num_workers=num_workers, + pin_memory=pin_memory, drop_last=drop_last, + timeout=timeout, worker_init_fn=worker_init_fn) + self.num_batches = self.get_num_batches(len(dataset), batch_size, drop_last) + self.batch_size = batch_size + + +class TorchLoaderIter(BatchIter): + def __init__(self, dataset): + super().__init__() + assert isinstance(dataset, torch.utils.data.DataLoader) + self.dataiter = dataset + self.num_batches = self.get_num_batches(len(dataset), dataset.batch_size, dataset.drop_last) + self.batch_size = dataset.batch_size + +class OnlineDataGettter: + # TODO + pass -def _to_tensor(batch, dtype): + +class OnlineDataIter(BatchIter): + # TODO + def __init__(self, dataset, batch_size=1, buffer_size=10000, 
sampler=None, as_numpy=False, + num_workers=0, pin_memory=False, drop_last=False, + timeout=0, worker_init_fn=None, **kwargs): + super().__init__() + + +def _to_tensor(batch, field_dtype): try: - if dtype in (int, np.int8, np.int16, np.int32, np.int64): - batch = torch.LongTensor(batch) - if dtype in (float, np.float32, np.float64): - batch = torch.FloatTensor(batch) - except: - pass - return batch + if field_dtype is not None and isinstance(field_dtype, type)\ + and issubclass(field_dtype, Number) \ + and not isinstance(batch, torch.Tensor): + if issubclass(batch.dtype.type, np.floating): + new_batch = torch.as_tensor(batch).float() # 默认使用float32 + elif issubclass(batch.dtype.type, np.integer): + new_batch = torch.as_tensor(batch).long() # 复用内存地址,避免复制 + else: + new_batch = torch.as_tensor(batch) + return new_batch, True + else: + return batch, False + except Exception as e: + raise e diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index e617cf2a..483f6dc1 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -438,26 +438,29 @@ class EarlyStopCallback(Callback): class FitlogCallback(Callback): """ - 该callback将loss和progress自动写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 - 一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 - 并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 - fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 + 别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback` + + 该callback可将loss和progress写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 + 一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 + 并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 + fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 :param DataSet,dict(DataSet) data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过 dict的方式传入。如果仅传入DataSet, 则被命名为test :param Tester tester: Tester对象,将在on_valid_end时调用。tester中的DataSet会被称为为`test` - :param int verbose: 是否在终端打印内容,0不打印 + :param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得 + 大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。 + :param int verbose: 是否在终端打印evaluation的结果,0不打印。 :param bool log_exception: fitlog是否记录发生的exception信息 """ - # 还没有被导出到 fastNLP 层 - # 别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback` - - def __init__(self, data=None, tester=None, verbose=0, log_exception=False): + + def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False): super().__init__() self.datasets = {} self.testers = {} self._log_exception = log_exception + assert isinstance(log_loss_every, int) and log_loss_every>=0 if tester is not None: assert isinstance(tester, Tester), "Only fastNLP.Tester allowed." assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data." 
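The new ``DataSetIter`` drives a ``torch.utils.data.DataLoader`` built from a ``DataSetGetter`` and a ``SamplerAdapter``, so code that used the removed ``Batch`` class now constructs a ``DataSetIter`` instead (or wraps an existing ``DataLoader`` in ``TorchLoaderIter``). A minimal sketch of the new usage, based only on the signatures introduced in this diff and assuming a toy ``DataSet`` with hypothetical field names ``words`` and ``target``:

.. code-block:: python

    from fastNLP import DataSet, Instance
    from fastNLP.core.sampler import SequentialSampler
    from fastNLP.core.batch import DataSetIter, TorchLoaderIter

    # toy data: word-id sequences of different lengths plus an integer label
    ds = DataSet()
    for words, label in [([1, 2, 3], 1), ([4, 5], 0)]:
        ds.append(Instance(words=words, target=label))
    ds.add_seq_len('words')              # new DataSet helper added in this diff
    ds.set_input('words', 'seq_len')
    ds.set_target('target')

    # before: Batch(ds, batch_size=2, sampler=SequentialSampler())
    # after:  DataSetIter pads each field (AutoPadder by default) and converts
    #         numeric fields to tensors inside the DataLoader's collate_fn
    data_iter = DataSetIter(ds, batch_size=2, sampler=SequentialSampler(), num_workers=0)
    for batch_x, batch_y in data_iter:
        print(data_iter.get_batch_indices(), batch_x['words'].shape, batch_y['target'])

    # an existing torch DataLoader could be wrapped directly instead:
    # data_iter = TorchLoaderIter(my_torch_dataloader)  # hypothetical loader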
@@ -477,7 +480,9 @@ class FitlogCallback(Callback): raise TypeError("data receives dict[DataSet] or DataSet object.") self.verbose = verbose - + self._log_loss_every = log_loss_every + self._avg_loss = 0 + def on_train_begin(self): if (len(self.datasets) > 0 or len(self.testers) > 0) and self.trainer.dev_data is None: raise RuntimeError("Trainer has no dev data, you cannot pass extra data to do evaluation.") @@ -490,8 +495,12 @@ class FitlogCallback(Callback): fitlog.add_progress(total_steps=self.n_steps) def on_backward_begin(self, loss): - fitlog.add_loss(loss.item(), name='loss', step=self.step, epoch=self.epoch) - + if self._log_loss_every>0: + self._avg_loss += loss.item() + if self.step%self._log_loss_every==0: + fitlog.add_loss(self._avg_loss/self._log_loss_every, name='loss', step=self.step, epoch=self.epoch) + self._avg_loss = 0 + def on_valid_end(self, eval_result, metric_key, optimizer, better_result): if better_result: eval_result = deepcopy(eval_result) @@ -518,7 +527,7 @@ class FitlogCallback(Callback): def on_exception(self, exception): fitlog.finish(status=1) if self._log_exception: - fitlog.add_other(str(exception), name='except_info') + fitlog.add_other(repr(exception), name='except_info') class LRScheduler(Callback): diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 9f24adf2..4cd1ad9c 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -285,7 +285,8 @@ from .field import AutoPadder from .field import FieldArray from .instance import Instance from .utils import _get_func_signature - +from .field import AppendToTargetOrInputException +from .field import SetInputOrTargetException class DataSet(object): """ @@ -422,7 +423,7 @@ class DataSet(object): if len(self.field_arrays) == 0: # DataSet has no field yet for name, field in instance.fields.items(): - field = field.tolist() if isinstance(field, np.ndarray) else field + # field = field.tolist() if isinstance(field, np.ndarray) else field self.field_arrays[name] = FieldArray(name, [field]) # 第一个样本,必须用list包装起来 else: if len(self.field_arrays) != len(instance.fields): @@ -431,7 +432,11 @@ class DataSet(object): .format(len(self.field_arrays), len(instance.fields))) for name, field in instance.fields.items(): assert name in self.field_arrays - self.field_arrays[name].append(field) + try: + self.field_arrays[name].append(field) + except AppendToTargetOrInputException as e: + print(f"Cannot append to field:{name}.") + raise e def add_fieldarray(self, field_name, fieldarray): """ @@ -549,6 +554,7 @@ class DataSet(object): self.field_arrays[new_name].name = new_name else: raise KeyError("DataSet has no field named {}.".format(old_name)) + return self def set_target(self, *field_names, flag=True): """ @@ -565,7 +571,11 @@ class DataSet(object): assert isinstance(flag, bool), "Only bool type supported." 
for name in field_names: if name in self.field_arrays: - self.field_arrays[name].is_target = flag + try: + self.field_arrays[name].is_target = flag + except SetInputOrTargetException as e: + print(f"Cannot set field:{name} as target.") + raise e else: raise KeyError("{} is not a valid field name.".format(name)) @@ -581,7 +591,11 @@ class DataSet(object): """ for name in field_names: if name in self.field_arrays: - self.field_arrays[name].is_input = flag + try: + self.field_arrays[name].is_input = flag + except SetInputOrTargetException as e: + print(f"Cannot set field:{name} as input, exception happens at the {e.index} value.") + raise e else: raise KeyError("{} is not a valid field name.".format(name)) @@ -748,7 +762,20 @@ class DataSet(object): self._add_apply_field(results, new_field_name, kwargs) return results - + + def add_seq_len(self, field_name:str, new_field_name='seq_len'): + """ + 将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。 + + :param field_name: str. + :return: + """ + if self.has_field(field_name=field_name): + self.apply_field(len, field_name, new_field_name=new_field_name) + else: + raise KeyError(f"Field:{field_name} not found.") + return self + def drop(self, func, inplace=True): """ func接受一个Instance,返回bool值。返回值为True时,该Instance会被移除或者加入到返回的DataSet中。 @@ -778,7 +805,7 @@ class DataSet(object): """ 将DataSet按照ratio的比例拆分,返回两个DataSet - :param float ratio: 0 1: - # list 跟 非list 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) - # >1维list - inner_type_set = set() - for l in content: - [inner_type_set.add(type(obj)) for obj in l] - if list not in inner_type_set: - # 二维list - self.content_dim = 2 - return self._basic_type_detection(inner_type_set) - else: - if len(inner_type_set) == 1: - # >2维list - inner_inner_type_set = set() - for _2d_list in content: - for _1d_list in _2d_list: - [inner_inner_type_set.add(type(obj)) for obj in _1d_list] - if list in inner_inner_type_set: - raise RuntimeError("FieldArray cannot handle 4-D or more-D list.") - # 3维list - self.content_dim = 3 - return self._basic_type_detection(inner_inner_type_set) - else: - # list 跟 非list 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set))) - else: - # 一维list - for content_type in type_set: - if content_type not in self.BASIC_TYPES: - raise RuntimeError("Unexpected data type in Field '{}'. Expect one of {}. Got {}.".format( - self.name, self.BASIC_TYPES, content_type)) - self.content_dim = 1 - return self._basic_type_detection(type_set) - - def _basic_type_detection(self, type_set): - """ - :param type_set: a set of Python types - :return: one of self.BASIC_TYPES - """ - if len(type_set) == 1: - return type_set.pop() - elif len(type_set) == 2: - # 有多个basic type; 可能需要up-cast - if float in type_set and int in type_set: - # up-cast int to float - return float - else: - # str 跟 int 或者 float 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) + + def _check_dtype_and_ndim(self): + """ + 检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 + 通过将直接报错. + + :return: + """ + cell_0 = self.content[0] + index = 0 + try: + type_0, dim_0 = _get_ele_type_and_dim(cell_0) + for cell in self.content[1:]: + index += 1 + type_i, dim_i = _get_ele_type_and_dim(cell) + if type_i!=type_0: + raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." 
+ ".".format(type_i, index, type_0)) + if dim_0!=dim_i: + raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " + "dimension:{}.".format(dim_i, index, dim_0)) + self._cell_ndim = dim_0 + self.dtype = type_0 + except SetInputOrTargetException as e: + e.index = index + raise e + + def append(self, val:Any): + """ + :param val: 把该val append到fieldarray。 + :return: + """ + if (self._is_target or self._is_input) and self._ignore_type is False: + type_, dim_ = _get_ele_type_and_dim(val) + if self.dtype!=type_: + raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with " + f"previous values(type:{self.dtype}).") + if self._cell_ndim!=dim_: + raise AppendToTargetOrInputException(f"Value(dim:{dim_}) are of different dimensions with " + f"previous values(dim:{self._cell_ndim}).") + self.content.append(val) else: - # str, int, float混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) - - def _1d_list_check(self, val): - """如果不是1D list就报错 - """ - type_set = set((type(obj) for obj in val)) - if any(obj not in self.BASIC_TYPES for obj in type_set): - raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) - self._basic_type_detection(type_set) - # otherwise: _basic_type_detection will raise error - return True - - def _2d_list_check(self, val): - """如果不是2D list 就报错 - """ - type_set = set(type(obj) for obj in val) - if list(type_set) != [list]: - raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) - inner_type_set = set() - for l in val: - for obj in l: - inner_type_set.add(type(obj)) - self._basic_type_detection(inner_type_set) - return True - - @staticmethod - def _map_to_np_type(basic_type): - type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} - return type_mapping[basic_type] - - def __repr__(self): - return "FieldArray {}: {}".format(self.name, self.content.__repr__()) - - def append(self, val): - """将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有 - 的内容是匹配的。 - - :param Any val: 需要append的值。 - """ - if self.ignore_type is False: - if isinstance(val, list): - pass - elif isinstance(val, tuple): # 确保最外层是list - val = list(val) - elif isinstance(val, np.ndarray): - val = val.tolist() - elif any((isinstance(val, t) for t in self.BASIC_TYPES)): - pass - else: - raise RuntimeError( - "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) - - if self.is_input is True or self.is_target is True: - if type(val) == list: - if len(val) == 0: - raise ValueError("Cannot append an empty list.") - if self.content_dim == 2 and self._1d_list_check(val): - # 1维list检查 - pass - elif self.content_dim == 3 and self._2d_list_check(val): - # 2维list检查 - pass - else: - raise RuntimeError( - "Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) - elif type(val) in self.BASIC_TYPES and self.content_dim == 1: - # scalar检查 - if type(val) == float and self.pytype == int: - self.pytype = float - self.dtype = self._map_to_np_type(self.pytype) - else: - raise RuntimeError( - "Unexpected data type {}. 
Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) - self.content.append(val) - + self.content.append(val) + def __getitem__(self, indices): return self.get(indices, pad=False) - + def __setitem__(self, idx, val): assert isinstance(idx, int) + if (self._is_target or self._is_input) and self.ignore_type is False: # 需要检测类型 + type_, dim_ = _get_ele_type_and_dim(val) + if self.dtype!=type_: + raise RuntimeError(f"Value(type:{type_}) are of different types with " + f"other values(type:{self.dtype}).") + if self._cell_ndim!=dim_: + raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with " + f"previous values(dim:{self._cell_ndim}).") self.content[idx] = val - + def get(self, indices, pad=True): """ 根据给定的indices返回内容 @@ -257,14 +170,17 @@ class FieldArray(object): if isinstance(indices, int): return self.content[indices] if self.is_input is False and self.is_target is False: - raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) - + raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name)) + contents = [self.content[i] for i in indices] if self.padder is None or pad is False: return np.array(contents) else: - return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype) - + return self.pad(contents) + + def pad(self, contents): + return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) + def set_padder(self, padder): """ 设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 @@ -276,7 +192,7 @@ class FieldArray(object): self.padder = deepcopy(padder) else: self.padder = None - + def set_pad_val(self, pad_val): """ 修改padder的pad_val. @@ -286,7 +202,7 @@ class FieldArray(object): if self.padder is not None: self.padder.set_pad_val(pad_val) return self - + def __len__(self): """ Returns the size of FieldArray. @@ -294,7 +210,7 @@ class FieldArray(object): :return int length: """ return len(self.content) - + def to(self, other): """ 将other的属性复制给本FieldArray(other必须为FieldArray类型). @@ -303,22 +219,225 @@ class FieldArray(object): :param other: :class:`~fastNLP.FieldArray` 从哪个field拷贝属性 :return: :class:`~fastNLP.FieldArray` """ - assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other)) - + assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other)) + + self.ignore_type = other.ignore_type self.is_input = other.is_input self.is_target = other.is_target self.padder = other.padder - self.ignore_type = other.ignore_type - + return self + def split(self, sep:str=None, inplace:bool=True): + """ + 依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值 + + :param sep: 分割符,如果为None则直接调用str.split()。 + :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 + :return: List[List[str]] or self + """ + new_contents = [] + for index, cell in enumerate(self.content): + try: + new_contents.append(cell.split(sep)) + except Exception as e: + print(f"Exception happens when process value in index {index}.") + raise e + return self._after_process(new_contents, inplace=inplace) + + def int(self, inplace:bool=True): + """ + 将本field中的值调用int(cell). 
支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), + (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) + + :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 + :return: List[int], List[List[int]], self + """ + new_contents = [] + for index, cell in enumerate(self.content): + try: + if isinstance(cell, list): + new_contents.append([int(value) for value in cell]) + else: + new_contents.append(int(cell)) + except Exception as e: + print(f"Exception happens when process value in index {index}.") + print(e) + return self._after_process(new_contents, inplace=inplace) + + def float(self, inplace=True): + """ + 将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), + (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) + + :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 + :return: + """ + new_contents = [] + for index, cell in enumerate(self.content): + try: + if isinstance(cell, list): + new_contents.append([float(value) for value in cell]) + else: + new_contents.append(float(cell)) + except Exception as e: + print(f"Exception happens when process value in index {index}.") + raise e + return self._after_process(new_contents, inplace=inplace) + + def bool(self, inplace=True): + """ + 将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), + (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) + + :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 + :return: + """ + new_contents = [] + for index, cell in enumerate(self.content): + try: + if isinstance(cell, list): + new_contents.append([bool(value) for value in cell]) + else: + new_contents.append(bool(cell)) + except Exception as e: + print(f"Exception happens when process value in index {index}.") + raise e + + return self._after_process(new_contents, inplace=inplace) + + def lower(self, inplace=True): + """ + 将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), + (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) + + :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 + :return: List[int], List[List[int]], self + """ + new_contents = [] + for index, cell in enumerate(self.content): + try: + if isinstance(cell, list): + new_contents.append([value.lower() for value in cell]) + else: + new_contents.append(cell.lower()) + except Exception as e: + print(f"Exception happens when process value in index {index}.") + raise e + return self._after_process(new_contents, inplace=inplace) + + def upper(self, inplace=True): + """ + 将本field中的值调用cell.lower(). 
支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), + (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) + + :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 + :return: List[int], List[List[int]], self + """ + new_contents = [] + for index, cell in enumerate(self.content): + try: + if isinstance(cell, list): + new_contents.append([value.upper() for value in cell]) + else: + new_contents.append(cell.upper()) + except Exception as e: + print(f"Exception happens when process value in index {index}.") + raise e + return self._after_process(new_contents, inplace=inplace) -def _is_iterable(content): + def value_count(self): + """ + 返回该field下不同value的数量。多用于统计label数量 + + :return: Counter, key是label,value是出现次数 + """ + count = Counter() + + def cum(cell): + if _is_iterable(cell) and not isinstance(cell, str): + for cell_ in cell: + cum(cell_) + else: + count[cell] += 1 + for cell in self.content: + cum(cell) + return count + + def _after_process(self, new_contents, inplace): + """ + 当调用处理函数之后,决定是否要替换field。 + + :param new_contents: + :param inplace: + :return: self或者生成的content + """ + if inplace: + self.content = new_contents + try: + self.is_input = self.is_input + self.is_target = self.is_input + except SetInputOrTargetException as e: + print("The newly generated field cannot be set as input or target.") + raise e + return self + else: + return new_contents + + +def _get_ele_type_and_dim(cell:Any, dim=0): + """ + 识别cell的类别与dimension的数量 + + numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html + :param cell: + :param dim: + :return: + """ + if isinstance(cell, (str, Number, np.bool_)): + if hasattr(cell, 'dtype'): + return cell.dtype.type, dim + return type(cell), dim + elif isinstance(cell, list): + dim += 1 + res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] + types = set([i for i,j in res]) + dims = set([j for i,j in res]) + if len(types)>1: + raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) + elif len(types)==0: + raise SetInputOrTargetException("Empty value encountered.") + if len(dims)>1: + raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) + return types.pop(), dims.pop() + elif isinstance(cell, torch.Tensor): + return cell.dtype, cell.dim() + dim # 如果是torch.mean的结果是0 + elif isinstance(cell, np.ndarray): + if cell.dtype != np.dtype('O'): # 如果不是object的话说明是well-formatted的了 + return cell.dtype.type, cell.ndim + dim # dtype.type返回的会是np.int32, np.float等 + # 否则需要继续往下iterate + dim += 1 + res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] + types = set([i for i,j in res]) + dims = set([j for i,j in res]) + if len(types)>1: + raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) + elif len(types)==0: + raise SetInputOrTargetException("Empty value encountered.") + if len(dims)>1: + raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) + return types.pop(), dims.pop() + else: # 包含tuple, set, dict以及其它的类型 + raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") + + +def _is_iterable(value): + # 检查是否是iterable的, duck typing try: - _ = (e for e in content) - except TypeError: + iter(value) + return True + except BaseException as e: return False - return True class Padder: @@ -327,32 +446,35 @@ class Padder: 所有padder都需要继承这个类,并覆盖__call__方法。 用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 - + .. 
py:function:: __call__(self, contents, field_name, field_ele_dtype): 传入的是List内容。假设有以下的DataSet。 - + :param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 deepcopy一份。 :param str, field_name: field的名称。 :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 :return: np.array([padded_element]) - + """ - + def __init__(self, pad_val=0, **kwargs): self.pad_val = pad_val - + def set_pad_val(self, pad_val): self.pad_val = pad_val - - def __call__(self, contents, field_name, field_ele_dtype): + + @abstractmethod + def __call__(self, contents, field_name, field_ele_dtype, dim:int): """ 传入的是List内容。假设有以下的DataSet。 :param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 deepcopy一份。 :param str, field_name: field的名称。 - :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 + :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True, + 该这个值为None。 + :param dim: 这个field的维度。当ignore_type为True时,该值为None :return: np.array([padded_element]) Example:: @@ -394,50 +516,86 @@ class AutoPadder(Padder): 根据contents的数据自动判定是否需要做padding。 1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 - 型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行pad + 型为str, [[1,2], ...]的元素类型为int)的数据不为数值类型则不会进行pad + + 2 如果元素类型为数值类型,比如np.int64, np.float64, int, float, torch.int64等 - 2 如果元素类型为(np.int64, np.float64), + 2.1 如果该field的内容为数值类型(包括int, float等),比如为seq_len, 则不进行padding - 2.1 如果该field的内容为(np.int64, np.float64),比如为seq_len, 则不进行padding + 2.2 如果该field的内容等价于一维list, 那么会将Batch中的List pad为一样长。 - 2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 - 即如果Instance中field形如[1, 2, 3, ...],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad + 2.3 如果该field的内容等价于二维list,那么会按照英语character padding的方式进行padding。如果是character padding建议使用 + :class: fastNLP.EngChar2DPadder. + + 2.4 如果该field的内容等价于三维list,则如果每个instance在每个维度上相等,会组成一个batch的tensor返回,这种情况应该是为图片 + 的情况。 + + 3 其它情况不进行处理,返回一个np.array类型。 """ - def __init__(self, pad_val=0): - """ - :param pad_val: int, padding的位置使用该index - """ super().__init__(pad_val=pad_val) - - def _is_two_dimension(self, contents): - """ - 判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. 
[[[1,2], [3, 4, 5]], [[4,5]]]有三个维度 - :param contents: - :return: - """ - value = contents[0] - if isinstance(value, (np.ndarray, list)): - value = value[0] - if isinstance(value, (np.ndarray, list)): - return False - return True - return False - - def __call__(self, contents, field_name, field_ele_dtype): - - if not _is_iterable(contents[0]): - array = np.array([content for content in contents], dtype=field_ele_dtype) - elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): - max_len = max([len(content) for content in contents]) - array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) - for i, content in enumerate(contents): - array[i][:len(content)] = content - elif field_ele_dtype is None: - array = np.array(contents) # 当ignore_type=True时,直接返回contents - else: # should only be str - array = np.array([content for content in contents]) - return array + + def __call__(self, contents, field_name, field_ele_dtype, dim): + if field_ele_dtype: + if dim>3: + return np.array(contents) + if isinstance(field_ele_dtype, type) and \ + (issubclass(field_ele_dtype, np.number) or issubclass(field_ele_dtype, Number)): + if dim==0: + array = np.array(contents, dtype=field_ele_dtype) + elif dim==1: + max_len = max(map(len, contents)) + array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + array[i, :len(content_i)] = content_i + elif dim==2: + max_len = max(map(len, contents)) + max_word_len = max([max([len(content_ii) for content_ii in content_i]) for + content_i in contents]) + array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + for j, content_ii in enumerate(content_i): + array[i, j, :len(content_ii)] = content_ii + else: + shape = np.shape(contents) + if len(shape)==4: # 说明各dimension是相同的大小 + array = np.array(contents, dtype=field_ele_dtype) + else: + raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") + return array + elif str(field_ele_dtype).startswith('torch'): + if dim==0: + tensor = torch.tensor(contents).to(field_ele_dtype) + elif dim==1: + max_len = max(map(len, contents)) + tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + tensor[i, :len(content_i)] = torch.tensor(content_i) + elif dim==2: + max_len = max(map(len, contents)) + max_word_len = max([max([len(content_ii) for content_ii in content_i]) for + content_i in contents]) + tensor = torch.full((len(contents), max_len, max_word_len), fill_value=self.pad_val, + dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + for j, content_ii in enumerate(content_i): + tensor[i, j, :len(content_ii)] = torch.tensor(content_ii) + else: + shapes = set([np.shape(content_i) for content_i in contents]) + if len(shapes)>1: + raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") + shape = shapes.pop() + if len(shape)==3: + tensor = torch.full([len(contents)]+list(shape), fill_value=self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype) + else: + raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") + return tensor + else: + return np.array(contents) # 不进行任何操作 + else: + return np.array(contents) class EngChar2DPadder(Padder): @@ 
-463,7 +621,7 @@ class EngChar2DPadder(Padder): dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder """ - + def __init__(self, pad_val=0, pad_length=0): """ :param pad_val: int, pad的位置使用该index @@ -471,32 +629,10 @@ class EngChar2DPadder(Padder): 都pad或截取到该长度. """ super().__init__(pad_val=pad_val) - + self.pad_length = pad_length - - def _exactly_three_dims(self, contents, field_name): - """ - 检查传入的contents是否刚好是3维,如果不是3维就报错。理论上,第一个维度是batch,第二个维度是word,第三个维度是character - :param contents: - :param field_name: str - :return: - """ - if not isinstance(contents, list): - raise TypeError("contents should be a list, not {}.".format(type(contents))) - value = contents[0] - try: - value = value[0] - except: - raise ValueError("Field:{} only has one dimension.".format(field_name)) - try: - value = value[0] - except: - raise ValueError("Field:{} only has two dimensions.".format(field_name)) - - if _is_iterable(value): - raise ValueError("Field:{} has more than 3 dimension.".format(field_name)) - - def __call__(self, contents, field_name, field_ele_dtype): + + def __call__(self, contents, field_name, field_ele_dtype, dim): """ 期望输入类似于 [ @@ -510,24 +646,24 @@ class EngChar2DPadder(Padder): :param field_ele_dtype :return: """ - if field_ele_dtype not in (np.int64, np.float64): + if field_ele_dtype not in (np.int64, np.float64, int, float): raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format( field_name, field_ele_dtype )) - self._exactly_three_dims(contents, field_name) + assert dim==2, f"Field:{field_name} has {dim}, EngChar2DPadder only supports input with 2 dimensions." if self.pad_length < 1: - max_char_length = max(max([[len(char_lst) for char_lst in word_lst] for word_lst in contents])) + max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents]) else: max_char_length = self.pad_length max_sent_length = max(len(word_lst) for word_lst in contents) batch_size = len(contents) dtype = type(contents[0][0][0]) - + padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, dtype=dtype) for b_idx, word_lst in enumerate(contents): for c_idx, char_lst in enumerate(word_lst): chars = char_lst[:max_char_length] padded_array[b_idx, c_idx, :len(chars)] = chars - + return padded_array diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 9dc02f3d..62e7a8c8 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -34,14 +34,23 @@ class LossBase(object): """ def __init__(self): - self.param_map = {} + self._param_map = {} # key是fun的参数,value是以该值从传入的dict取出value self._checked = False - + + @property + def param_map(self): + if len(self._param_map) == 0: # 如果为空说明还没有初始化 + func_spect = inspect.getfullargspec(self.get_loss) + func_args = [arg for arg in func_spect.args if arg != 'self'] + for arg in func_args: + self._param_map[arg] = arg + return self._param_map + def get_loss(self, *args, **kwargs): raise NotImplementedError def _init_param_map(self, key_map=None, **kwargs): - """检查key_map和其他参数map,并将这些映射关系添加到self.param_map + """检查key_map和其他参数map,并将这些映射关系添加到self._param_map :param dict key_map: 表示key的映射关系 :param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 @@ -53,30 +62,30 @@ class LossBase(object): raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) for key, value in key_map.items(): if value is None: - self.param_map[key] = key + self._param_map[key] = key continue if not isinstance(key, str): raise TypeError(f"key in key_map must be `str`, 
not `{type(key)}`.") if not isinstance(value, str): raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") - self.param_map[key] = value + self._param_map[key] = value value_counter[value].add(key) for key, value in kwargs.items(): if value is None: - self.param_map[key] = key + self._param_map[key] = key continue if not isinstance(value, str): raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") - self.param_map[key] = value + self._param_map[key] = value value_counter[value].add(key) for value, key_set in value_counter.items(): if len(key_set) > 1: raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") - # check consistence between signature and param_map + # check consistence between signature and _param_map func_spect = inspect.getfullargspec(self.get_loss) func_args = [arg for arg in func_spect.args if arg != 'self'] - for func_param, input_param in self.param_map.items(): + for func_param, input_param in self._param_map.items(): if func_param not in func_args: raise NameError( f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the " @@ -96,7 +105,7 @@ class LossBase(object): :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. """ fast_param = {} - if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: + if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: fast_param['pred'] = list(pred_dict.values())[0] fast_param['target'] = list(target_dict.values())[0] return fast_param @@ -115,49 +124,41 @@ class LossBase(object): return loss if not self._checked: - # 1. check consistence between signature and param_map + # 1. check consistence between signature and _param_map func_spect = inspect.getfullargspec(self.get_loss) func_args = set([arg for arg in func_spect.args if arg != 'self']) - for func_arg, input_arg in self.param_map.items(): + for func_arg, input_arg in self._param_map.items(): if func_arg not in func_args: raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.") - # 2. only part of the param_map are passed, left are not + # 2. only part of the _param_map are passed, left are not for arg in func_args: - if arg not in self.param_map: - self.param_map[arg] = arg # This param does not need mapping. + if arg not in self._param_map: + self._param_map[arg] = arg # This param does not need mapping. self._evaluate_args = func_args - self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} - - # need to wrap inputs in dict. 
+ self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()} + mapped_pred_dict = {} mapped_target_dict = {} - duplicated = [] - for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): - not_duplicate_flag = 0 - if input_arg in self._reverse_param_map: - mapped_arg = self._reverse_param_map[input_arg] - not_duplicate_flag += 1 - else: - mapped_arg = input_arg + for input_arg, mapped_arg in self._reverse_param_map.items(): if input_arg in pred_dict: mapped_pred_dict[mapped_arg] = pred_dict[input_arg] - not_duplicate_flag += 1 if input_arg in target_dict: mapped_target_dict[mapped_arg] = target_dict[input_arg] - not_duplicate_flag += 1 - if not_duplicate_flag == 3: - duplicated.append(input_arg) # missing if not self._checked: + duplicated = [] + for input_arg, mapped_arg in self._reverse_param_map.items(): + if input_arg in pred_dict and input_arg in target_dict: + duplicated.append(input_arg) check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict]) # replace missing. missing = check_res.missing replaced_missing = list(missing) for idx, func_arg in enumerate(missing): # Don't delete `` in this information, nor add `` - replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ + replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \ f"in `{self.__class__.__name__}`)" check_res = _CheckRes(missing=replaced_missing, @@ -170,6 +171,8 @@ class LossBase(object): if check_res.missing or check_res.duplicated: raise _CheckError(check_res=check_res, func_signature=_get_func_signature(self.get_loss)) + self._checked = True + refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict) loss = self.get_loss(**refined_args) @@ -204,15 +207,12 @@ class LossFunc(LossBase): super(LossFunc, self).__init__() _check_function_or_method(func) + self.get_loss = func if key_map is not None: if not isinstance(key_map, dict): raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}") - self.param_map = key_map - if len(kwargs) > 0: - for key, val in kwargs.items(): - self.param_map.update({key: val}) + self._init_param_map(key_map, **kwargs) - self.get_loss = func class CrossEntropyLoss(LossBase): @@ -232,12 +232,16 @@ class CrossEntropyLoss(LossBase): """ def __init__(self, pred=None, target=None, padding_idx=-100): - # TODO 需要做一些检查,F.cross_entropy在计算时,如果pred是(16, 10 ,4), target的形状按道理应该是(16, 10), 但实际需要(16,4) super(CrossEntropyLoss, self).__init__() self._init_param_map(pred=pred, target=target) self.padding_idx = padding_idx def get_loss(self, pred, target): + if pred.dim()>2: + if pred.size()[:2]==target.size(): + # F.cross_entropy在计算时,如果pred是(16, 10 ,4), 会在第二维上去log_softmax, 所以需要交换一下位置 + pred = pred.transpose(1, 2) + return F.cross_entropy(input=pred, target=target, ignore_index=self.padding_idx) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index f633a80f..d54bf8ec 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -22,7 +22,7 @@ from .utils import _check_arg_dict_list from .utils import _get_func_signature from .utils import seq_len_to_mask from .vocabulary import Vocabulary - +from abc import abstractmethod class MetricBase(object): """ @@ -115,17 +115,28 @@ class MetricBase(object): """ def __init__(self): - self.param_map = {} # key is param in function, value is input param. + self._param_map = {} # key is param in function, value is input param. 
self._checked = False - + + @property + def param_map(self): + if len(self._param_map) == 0: # 如果为空说明还没有初始化 + func_spect = inspect.getfullargspec(self.evaluate) + func_args = [arg for arg in func_spect.args if arg != 'self'] + for arg in func_args: + self._param_map[arg] = arg + return self._param_map + + @abstractmethod def evaluate(self, *args, **kwargs): raise NotImplementedError - + + @abstractmethod def get_metric(self, reset=True): raise NotImplemented def _init_param_map(self, key_map=None, **kwargs): - """检查key_map和其他参数map,并将这些映射关系添加到self.param_map + """检查key_map和其他参数map,并将这些映射关系添加到self._param_map :param dict key_map: 表示key的映射关系 :param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 @@ -137,30 +148,30 @@ class MetricBase(object): raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) for key, value in key_map.items(): if value is None: - self.param_map[key] = key + self._param_map[key] = key continue if not isinstance(key, str): raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") if not isinstance(value, str): raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") - self.param_map[key] = value + self._param_map[key] = value value_counter[value].add(key) for key, value in kwargs.items(): if value is None: - self.param_map[key] = key + self._param_map[key] = key continue if not isinstance(value, str): raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") - self.param_map[key] = value + self._param_map[key] = value value_counter[value].add(key) for value, key_set in value_counter.items(): if len(key_set) > 1: raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") - # check consistence between signature and param_map + # check consistence between signature and _param_map func_spect = inspect.getfullargspec(self.evaluate) func_args = [arg for arg in func_spect.args if arg != 'self'] - for func_param, input_param in self.param_map.items(): + for func_param, input_param in self._param_map.items(): if func_param not in func_args: raise NameError( f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the " @@ -175,7 +186,7 @@ class MetricBase(object): :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. """ fast_param = {} - if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: + if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: fast_param['pred'] = list(pred_dict.values())[0] fast_param['target'] = list(target_dict.values())[0] return fast_param @@ -204,42 +215,35 @@ class MetricBase(object): if not self._checked: if not callable(self.evaluate): raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") - # 1. check consistence between signature and param_map + # 1. check consistence between signature and _param_map func_spect = inspect.getfullargspec(self.evaluate) func_args = set([arg for arg in func_spect.args if arg != 'self']) - for func_arg, input_arg in self.param_map.items(): + for func_arg, input_arg in self._param_map.items(): if func_arg not in func_args: raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.") - # 2. only part of the param_map are passed, left are not + # 2. only part of the _param_map are passed, left are not for arg in func_args: - if arg not in self.param_map: - self.param_map[arg] = arg # This param does not need mapping. 
+ if arg not in self._param_map: + self._param_map[arg] = arg # This param does not need mapping. self._evaluate_args = func_args - self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} + self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()} # need to wrap inputs in dict. mapped_pred_dict = {} mapped_target_dict = {} - duplicated = [] - for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): - not_duplicate_flag = 0 - if input_arg in self._reverse_param_map: - mapped_arg = self._reverse_param_map[input_arg] - not_duplicate_flag += 1 - else: - mapped_arg = input_arg + for input_arg, mapped_arg in self._reverse_param_map.items(): if input_arg in pred_dict: mapped_pred_dict[mapped_arg] = pred_dict[input_arg] - not_duplicate_flag += 1 if input_arg in target_dict: mapped_target_dict[mapped_arg] = target_dict[input_arg] - not_duplicate_flag += 1 - if not_duplicate_flag == 3: - duplicated.append(input_arg) # missing if not self._checked: + duplicated = [] + for input_arg, mapped_arg in self._reverse_param_map.items(): + if input_arg in pred_dict and input_arg in target_dict: + duplicated.append(input_arg) check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) # only check missing. # replace missing. @@ -247,7 +251,7 @@ class MetricBase(object): replaced_missing = list(missing) for idx, func_arg in enumerate(missing): # Don't delete `` in this information, nor add `` - replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ + replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \ f"in `{self.__class__.__name__}`)" check_res = _CheckRes(missing=replaced_missing, @@ -260,10 +264,10 @@ class MetricBase(object): if check_res.missing or check_res.duplicated: raise _CheckError(check_res=check_res, func_signature=_get_func_signature(self.evaluate)) + self._checked = True refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) self.evaluate(**refined_args) - self._checked = True return @@ -409,6 +413,37 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None): ] +def _bioes_tag_to_spans(tags, ignore_labels=None): + """ + 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。 + 返回[('singer', (1, 4))] (左闭右开区间) + + :param tags: List[str], + :param ignore_labels: List[str], 在该list中的label将被忽略 + :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] + """ + ignore_labels = set(ignore_labels) if ignore_labels else set() + + spans = [] + prev_bioes_tag = None + for idx, tag in enumerate(tags): + tag = tag.lower() + bioes_tag, label = tag[:1], tag[2:] + if bioes_tag in ('b', 's'): + spans.append((label, [idx, idx])) + elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]: + spans[-1][1][1] = idx + elif bioes_tag == 'o': + pass + else: + spans.append((label, [idx, idx])) + prev_bioes_tag = bioes_tag + return [(span[0], (span[1][0], span[1][1] + 1)) + for span in spans + if span[0] not in ignore_labels + ] + + def _bio_tag_to_spans(tags, ignore_labels=None): """ 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 @@ -438,7 +473,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None): class SpanFPreRecMetric(MetricBase): - """ + r""" 别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` 在序列标注问题中,以span的方式计算F, pre, rec. 
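The newly added ``_bioes_tag_to_spans`` helper above (wired into ``SpanFPreRecMetric`` through the new ``encoding_type='bioes'`` branch further down) decodes a BIOES tag sequence into ``(label, (start, end))`` spans with a right-open end index, mirroring ``_bio_tag_to_spans`` and ``_bmeso_tag_to_spans``. A small sketch of the expected behaviour, derived from the code added above (calling the private helper directly is for illustration only):

.. code-block:: python

    from fastNLP.core.metrics import _bioes_tag_to_spans

    tags = ['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'S-song']
    # B/I/E tags sharing a label merge into one span; S opens and closes its own
    spans = _bioes_tag_to_spans(tags)
    print(spans)   # expected: [('singer', (1, 4)), ('song', (5, 6))]  (right-open)

    spans = _bioes_tag_to_spans(tags, ignore_labels=['song'])
    print(spans)   # expected: [('singer', (1, 4))]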
@@ -469,15 +504,15 @@ class SpanFPreRecMetric(MetricBase): :param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 :param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 :param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_len'取数据。 - :param str encoding_type: 目前支持bio, bmes + :param str encoding_type: 目前支持bio, bmes, bmeso, bioes :param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 个label :param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 label的f1, pre, rec :param str f_type: 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) - :param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 - 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 + :param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . + 常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 """ def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, @@ -497,6 +532,8 @@ class SpanFPreRecMetric(MetricBase): self.tag_to_span_func = _bio_tag_to_spans elif self.encoding_type == 'bmeso': self.tag_to_span_func = _bmeso_tag_to_spans + elif self.encoding_type == 'bioes': + self.tag_to_span_func = _bioes_tag_to_spans else: raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") @@ -699,17 +736,17 @@ def _pred_topk(y_prob, k=1): class SQuADMetric(MetricBase): - """ + r""" 别名::class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric` SQuAD数据集metric - :param pred1: 参数映射表中`pred1`的映射关系,None表示映射关系为`pred1`->`pred1` - :param pred2: 参数映射表中`pred2`的映射关系,None表示映射关系为`pred2`->`pred2` - :param target1: 参数映射表中`target1`的映射关系,None表示映射关系为`target1`->`target1` - :param target2: 参数映射表中`target2`的映射关系,None表示映射关系为`target2`->`target2` - :param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 - 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 + :param pred1: 参数映射表中 `pred1` 的映射关系,None表示映射关系为 `pred1` -> `pred1` + :param pred2: 参数映射表中 `pred2` 的映射关系,None表示映射关系为 `pred2` -> `pred2` + :param target1: 参数映射表中 `target1` 的映射关系,None表示映射关系为 `target1` -> `target1` + :param target2: 参数映射表中 `target2` 的映射关系,None表示映射关系为 `target2` -> `target2` + :param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . + 常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 :param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 :param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py index 4f37e105..06e586c6 100644 --- a/fastNLP/core/predictor.py +++ b/fastNLP/core/predictor.py @@ -6,7 +6,7 @@ from collections import defaultdict import torch -from . import Batch +from . import DataSetIter from . import DataSet from . 
import SequentialSampler from .utils import _build_args @@ -44,8 +44,7 @@ class Predictor(object): self.network.eval() batch_output = defaultdict(list) - data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False, - prefetch=False) + data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) if hasattr(self.network, "predict"): predict_func = self.network.predict diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 883e0d01..4cdd4ffb 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -37,7 +37,7 @@ import warnings import torch import torch.nn as nn -from .batch import Batch +from .batch import BatchIter, DataSetIter from .dataset import DataSet from .metrics import _prepare_metrics from .sampler import SequentialSampler @@ -82,7 +82,7 @@ class Tester(object): :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 """ - def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1): + def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1): super(Tester, self).__init__() if not isinstance(data, DataSet): @@ -96,6 +96,14 @@ class Tester(object): self._model = _move_model_to_device(model, device=device) self.batch_size = batch_size self.verbose = verbose + + if isinstance(data, DataSet): + self.data_iterator = DataSetIter( + dataset=data, batch_size=batch_size, num_workers=num_workers, sampler=SequentialSampler()) + elif isinstance(data, BatchIter): + self.data_iterator = data + else: + raise TypeError("data type {} not support".format(type(data))) # 如果是DataParallel将没有办法使用predict方法 if isinstance(self._model, nn.DataParallel): @@ -112,7 +120,10 @@ class Tester(object): raise TypeError(f"`{_model_name}.predict` must be callable to be used " f"for evaluation, not `{type(self._predict_func)}`.") else: - self._predict_func = self._model.forward + if isinstance(model, nn.DataParallel): + self._predict_func = self._model.module.forward + else: + self._predict_func = self._model.forward def test(self): """开始进行验证,并返回验证结果。 @@ -124,7 +135,7 @@ class Tester(object): self._model_device = _get_model_device(self._model) network = self._model self._mode(network, is_test=True) - data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False) + data_iterator = self.data_iterator eval_results = {} try: with torch.no_grad(): diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 2523a957..a303f742 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -311,8 +311,9 @@ try: from tqdm.auto import tqdm except: from .utils import _pseudo_tqdm as tqdm +import warnings -from .batch import Batch +from .batch import DataSetIter, BatchIter from .callback import CallbackManager, CallbackException from .dataset import DataSet from .losses import _prepare_losser @@ -320,7 +321,6 @@ from .metrics import _prepare_metrics from .optimizer import Optimizer from .sampler import Sampler from .sampler import RandomSampler -from .sampler import SequentialSampler from .tester import Tester from .utils import _CheckError from .utils import _build_args @@ -351,6 +351,8 @@ class Trainer(object): :param int batch_size: 训练和验证的时候的batch大小。 :param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` :param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler` + :param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch + :param 
num_workers: int, 有多少个线程来进行数据pad处理。 :param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 :param int n_epochs: 需要优化迭代多少次。 @@ -367,7 +369,6 @@ class Trainer(object): :param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。 :param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模型。 保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 - :param prefetch: bool, 是否使用额外的进程对产生batch数据。理论上会使得Batch迭代更快。 :param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 的计算位置进行管理。支持以下的输入: @@ -394,16 +395,17 @@ class Trainer(object): """ def __init__(self, train_data, model, optimizer=None, loss=None, - batch_size=32, sampler=None, update_every=1, - n_epochs=10, print_every=5, + batch_size=32, sampler=None, drop_last=False, update_every=1, + num_workers=0, n_epochs=10, print_every=5, dev_data=None, metrics=None, metric_key=None, - validate_every=-1, save_path=None, - prefetch=False, use_tqdm=True, device=None, - callbacks=None, - check_code_level=0): + validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False, + callbacks=None, check_code_level=0): + if prefetch and num_workers==0: + num_workers = 1 + if prefetch: + warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.") + super(Trainer, self).__init__() - if not isinstance(train_data, DataSet): - raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.") if not isinstance(model, nn.Module): raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") @@ -430,18 +432,30 @@ class Trainer(object): if metric_key is not None: self.increase_better = False if metric_key[0] == "-" else True self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key - elif len(metrics) > 0: - self.metric_key = metrics[0].__class__.__name__.lower().strip('metric') - + else: + self.metric_key = None # prepare loss losser = _prepare_losser(loss) # sampler check if sampler is not None and not isinstance(sampler, Sampler): raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) - - if check_code_level > -1: - _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, + + if sampler is None: + sampler = RandomSampler() + + if isinstance(train_data, DataSet): + self.data_iterator = DataSetIter( + dataset=train_data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, drop_last=drop_last) + elif isinstance(train_data, BatchIter): + self.data_iterator = train_data + else: + raise TypeError("train_data type {} not support".format(type(train_data))) + + self.model = _move_model_to_device(model, device=device) + + if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): + _check_code(dataset=train_data, model=self.model, losser=losser, metrics=metrics, dev_data=dev_data, metric_key=metric_key, check_level=check_code_level, batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 @@ -460,13 +474,9 @@ class Trainer(object): self.best_dev_epoch = None self.best_dev_step = None self.best_dev_perf = None - self.sampler = sampler if sampler is not None else RandomSampler() - self.prefetch = prefetch self.n_steps = (len(self.train_data) // self.batch_size + 
int( len(self.train_data) % self.batch_size != 0)) * self.n_epochs - - self.model = _move_model_to_device(self.model, device=device) - + if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer elif isinstance(optimizer, Optimizer): @@ -493,13 +503,16 @@ class Trainer(object): self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) - - def train(self, load_best_model=True): + + def train(self, load_best_model=True, on_exception='auto'): """ 使用该函数使Trainer开始训练。 - :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效, - 如果True, trainer将在返回之前重新加载dev表现最好的模型参数。 + :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 + 最好的模型参数。 + :param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。 + 支持'ignore','raise', 'auto': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出; + 'auto'将ignore以下两种Exception: CallbackException与KeyboardInterrupt, raise其它exception. :return dict: 返回一个字典类型的数据, 内含以下内容:: @@ -528,10 +541,16 @@ class Trainer(object): self.callback_manager.on_train_begin() self._train() self.callback_manager.on_train_end() - except (CallbackException, KeyboardInterrupt) as e: + + except BaseException as e: self.callback_manager.on_exception(e) + if on_exception == 'auto': + if not isinstance(e, (CallbackException, KeyboardInterrupt)): + raise e + elif on_exception == 'raise': + raise e - if self.dev_data is not None and hasattr(self, 'best_dev_perf'): + if self.dev_data is not None and self.best_dev_perf is not None: print( "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + self.tester._format_eval_results(self.best_dev_perf), ) @@ -559,12 +578,14 @@ class Trainer(object): self.step = 0 self.epoch = 0 start = time.time() - + if isinstance(self.model, nn.DataParallel): + self._forward_func = self.model.module.forward + else: + self._forward_func = self.model.forward with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: self.pbar = pbar avg_loss = 0 - data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, - prefetch=self.prefetch) + data_iterator = self.data_iterator self.batch_per_epoch = data_iterator.num_batches for epoch in range(1, self.n_epochs + 1): self.epoch = epoch @@ -664,11 +685,11 @@ class Trainer(object): self.optimizer.step() def _data_forward(self, network, x): - x = _build_args(network.forward, **x) + x = _build_args(self._forward_func, **x) y = network(**x) if not isinstance(y, dict): raise TypeError( - f"The return value of {_get_func_signature(network.forward)} should be dict, got {type(y)}.") + f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.") return y def _grad_backward(self, loss): @@ -737,7 +758,9 @@ class Trainer(object): :return bool value: True means current results on dev set is the best. 
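To make the refactored Trainer arguments above concrete, here is a minimal, hedged sketch; `train_data`, `dev_data`, `model` and `tag_vocab` are assumed to already exist and are not part of this patch::

    from fastNLP import Trainer, SpanFPreRecMetric

    metric = SpanFPreRecMetric(tag_vocab=tag_vocab, encoding_type='bioes')
    trainer = Trainer(train_data=train_data, model=model,
                      batch_size=32, num_workers=2, drop_last=False,  # num_workers replaces the deprecated prefetch flag
                      dev_data=dev_data, metrics=metric)
    trainer.train(on_exception='raise')  # 'auto' (default) only swallows CallbackException / KeyboardInterrupt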
""" - indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics) + indicator, indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics) + if self.metric_key is None: + self.metric_key = indicator is_better = True if self.best_metric_indicator is None: # first-time validation @@ -776,15 +799,34 @@ def _get_value_info(_dict): strs.append(_str) return strs - +from numbers import Number +from .batch import _to_tensor def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, metric_key=None, check_level=0): # check get_loss 方法 - model_devcie = model.parameters().__next__().device + model_devcie = _get_model_device(model=model) - batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) - for batch_count, (batch_x, batch_y) in enumerate(batch): + def _iter(): + start_idx = 0 + while start_idx 1 and metric_key is None: - raise RuntimeError( - f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?") else: # metric_key is set if metric_key not in metric_dict: raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}") indicator_val = metric_dict[metric_key] + indicator = metric_key else: raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics))) - return indicator_val + return indicator, indicator_val diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 79af296b..d26df966 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -3,7 +3,8 @@ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户 """ __all__ = [ "cache_results", - "seq_len_to_mask" + "seq_len_to_mask", + "Option", ] import _pickle @@ -21,6 +22,32 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require 'varargs']) +class Option(dict): + """a dict can treat keys as attributes""" + def __getattr__(self, item): + try: + return self.__getitem__(item) + except KeyError: + raise AttributeError(item) + + def __setattr__(self, key, value): + if key.startswith('__') and key.endswith('__'): + raise AttributeError(key) + self.__setitem__(key, value) + + def __delattr__(self, item): + try: + self.pop(item) + except KeyError: + raise AttributeError(item) + + def __getstate__(self): + return self + + def __setstate__(self, state): + self.update(state) + + def _prepare_cache_filepath(filepath): """ 检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径 @@ -258,6 +285,7 @@ def _get_model_device(model): :param model: nn.Module :return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。 """ + # TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding assert isinstance(model, nn.Module) parameters = list(model.parameters()) @@ -268,6 +296,13 @@ def _get_model_device(model): def _build_args(func, **kwargs): + """ + 根据func的初始化参数,从kwargs中选择func需要的参数 + + :param func: callable + :param kwargs: 参数 + :return:dict. 
func中用到的参数 + """ spect = inspect.getfullargspec(func) if spect.varkw is not None: return kwargs @@ -608,7 +643,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): warnings.warn(message=_unused_warn) -def seq_len_to_mask(seq_len): +def seq_len_to_mask(seq_len, max_len=None): """ 将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。 @@ -624,20 +659,26 @@ def seq_len_to_mask(seq_len): >>> mask = seq_len_to_mask(seq_len) >>> print(mask.shape) (14, 15) + >>> seq_len = torch.arange(2, 16) + >>> mask = seq_len_to_mask(seq_len, max_len=100) + >>>print(mask.size()) + torch.Size([14, 100]) :param np.ndarray,torch.LongTensor seq_len: shape将是(B,) + :param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有 + 区别,所以需要传入一个max_len使得mask的长度是pad到该长度。 :return: np.ndarray or torch.Tensor, shape将是(B, max_length)。 元素类似为bool或torch.uint8 """ if isinstance(seq_len, np.ndarray): assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." - max_len = int(seq_len.max()) + max_len = int(max_len) if max_len else int(seq_len.max()) broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) mask = broad_cast_seq_len < seq_len.reshape(-1, 1) elif isinstance(seq_len, torch.Tensor): assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." batch_size = seq_len.size(0) - max_len = seq_len.max().long() + max_len = int(max_len) if max_len else seq_len.max().long() broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len) mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) else: diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index cbde9cba..66aabd3d 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -1,11 +1,27 @@ __all__ = [ - "Vocabulary" + "Vocabulary", + "VocabularyOption", ] from functools import wraps -from collections import Counter - +from collections import Counter, defaultdict from .dataset import DataSet +from .utils import Option +from functools import partial +import numpy as np + +class VocabularyOption(Option): + def __init__(self, + max_size=None, + min_freq=None, + padding='', + unknown=''): + super().__init__( + max_size=max_size, + min_freq=min_freq, + padding=padding, + unknown=unknown + ) def _check_build_vocab(func): @@ -74,7 +90,9 @@ class Vocabulary(object): self.word2idx = None self.idx2word = None self.rebuild = True - + # 用于承载不需要单独创建entry的词语,具体见from_dataset()方法 + self._no_create_word = defaultdict(int) + @_check_build_status def update(self, word_lst): """依次增加序列中词在词典中的出现频率 @@ -133,7 +151,7 @@ class Vocabulary(object): self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) self.build_reverse_vocab() self.rebuild = False - + def build_reverse_vocab(self): """ 基于 "word to index" dict, 构建 "index to word" dict. @@ -225,8 +243,12 @@ class Vocabulary(object): raise e else: raise RuntimeError("Only DataSet type is allowed.") - - def from_dataset(self, *datasets, field_name): + + @property + def _no_create_word_length(self): + return len(self._no_create_word) + + def from_dataset(self, *datasets, field_name, no_create_entry_dataset=None): """ 使用dataset的对应field中词构建词典:: @@ -238,6 +260,13 @@ class Vocabulary(object): 构建词典所使用的 field(s), 支持一个或多个field 若有多个 DataSet, 每个DataSet都必须有这些field. 
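The new `Option` helper (and subclasses such as `VocabularyOption` above) is simply a dict whose keys can also be read and written as attributes. A small usage sketch::

    from fastNLP.core.utils import Option

    opt = Option(max_size=10000, min_freq=2)
    assert opt.max_size == opt['max_size'] == 10000   # attribute access and key access are equivalent
    opt.min_freq = 5                                  # attribute assignment updates the underlying dict
    opt.pop('max_size')                               # regular dict methods still work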
目前仅支持的field结构: ``str`` , ``list(str)`` , ``list(list(str))`` + :param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认),该选项用在接下来的模型会使用pretrain + 的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev + 中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。 + 如果一个词出现在了train中,但是没在预训练模型中,embedding会为它用unk初始化,但它是单独的一个vector,如果 + finetune embedding的话,这个词在更新之后可能会有更好的表示; 而如果这个词仅出现在了dev或test中,那么就不能为它们单独建立vector, + 而应该让它指向unk这个vector的值。所以只位于no_create_entry_dataset中的token,将首先从预训练的词表中寻找它的表示, + 如果找到了,就使用该表示; 如果没有找到,则认为该词的表示应该为unk的表示。 :return self: """ if isinstance(field_name, str): @@ -245,19 +274,28 @@ class Vocabulary(object): elif not isinstance(field_name, list): raise TypeError('invalid argument field_name: {}'.format(field_name)) - def construct_vocab(ins): + def construct_vocab(ins, no_create_entry=False): for fn in field_name: field = ins[fn] if isinstance(field, str): + if no_create_entry and field not in self.word_count: + self._no_create_word[field] += 1 self.add_word(field) - elif isinstance(field, list): - if not isinstance(field[0], list): - self.add_word_lst(field) + elif isinstance(field, (list, np.ndarray)): + if not isinstance(field[0], (list, np.ndarray)): + for word in field: + if no_create_entry and word not in self.word_count: + self._no_create_word[word] += 1 + self.add_word(word) else: - if isinstance(field[0][0], list): + if isinstance(field[0][0], (list, np.ndarray)): raise RuntimeError("Only support field with 2 dimensions.") - [self.add_word_lst(w) for w in field] - + for words in field: + for word in words: + if no_create_entry and word not in self.word_count: + self._no_create_word[word] += 1 + self.add_word(word) + for idx, dataset in enumerate(datasets): if isinstance(dataset, DataSet): try: @@ -266,9 +304,27 @@ class Vocabulary(object): print("When processing the `{}` dataset, the following error occurred.".format(idx)) raise e else: - raise RuntimeError("Only DataSet type is allowed.") + raise TypeError("Only DataSet type is allowed.") + + if no_create_entry_dataset is not None: + partial_construct_vocab = partial(construct_vocab, no_create_entry=True) + if isinstance(no_create_entry_dataset, DataSet): + no_create_entry_dataset.apply(partial_construct_vocab) + elif isinstance(no_create_entry_dataset, list): + for dataset in no_create_entry_dataset: + if not isinstance(dataset, DataSet): + raise TypeError("Only DataSet type is allowed.") + dataset.apply(partial_construct_vocab) return self - + + def _is_word_no_create_entry(self, word): + """ + 判断当前的word是否是不需要创建entry的,具体参见from_dataset的说明 + :param word: str + :return: bool + """ + return word in self._no_create_word + def to_index(self, w): """ 将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 @@ -323,6 +379,7 @@ class Vocabulary(object): self.word2idx = None self.idx2word = None self.rebuild = True + self._no_create_word.clear() def __getstate__(self): """Use to prepare data for pickle. 
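A hedged sketch of the `no_create_entry_dataset` option documented above; `train_data`, `dev_data` and `test_data` are assumed to be existing :class:`~fastNLP.DataSet` objects with a `words` field::

    from fastNLP import Vocabulary

    vocab = Vocabulary()
    vocab.from_dataset(train_data, field_name='words',
                       no_create_entry_dataset=[dev_data, test_data])
    # words seen only in dev/test are recorded as "no-create-entry": they reuse a pretrained
    # vector when one exists, and otherwise share the unknown vector instead of getting a
    # trainable entry of their own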
@@ -344,5 +401,7 @@ class Vocabulary(object): def __repr__(self): return "Vocabulary({}...)".format(list(self.word_count.keys())[:5]) + @_check_build_vocab def __iter__(self): - return iter(list(self.word_count.keys())) + for word, index in self.word2idx.items(): + yield word, index diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py index c8d6a441..28f466a8 100644 --- a/fastNLP/io/__init__.py +++ b/fastNLP/io/__init__.py @@ -26,6 +26,6 @@ __all__ = [ ] from .embed_loader import EmbedLoader -from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \ - PeopleDailyCorpusLoader, Conll2003Loader +from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, \ + SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader from .model_io import ModelLoader, ModelSaver diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py index 4ab1e2d0..465fb7e8 100644 --- a/fastNLP/io/base_loader.py +++ b/fastNLP/io/base_loader.py @@ -1,10 +1,14 @@ __all__ = [ - "BaseLoader" + "BaseLoader", + 'DataInfo', + 'DataSetLoader', ] import _pickle as pickle import os - +from typing import Union, Dict +import os +from ..core.dataset import DataSet class BaseLoader(object): """ @@ -51,24 +55,169 @@ class BaseLoader(object): return obj -class DataLoaderRegister: - _readers = {} - - @classmethod - def set_reader(cls, reader_cls, read_fn_name): - # def wrapper(reader_cls): - if read_fn_name in cls._readers: - raise KeyError( - 'duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, - read_fn_name)) - if hasattr(reader_cls, 'load'): - cls._readers[read_fn_name] = reader_cls().load - return reader_cls - - @classmethod - def get_reader(cls, read_fn_name): - if read_fn_name in cls._readers: - return cls._readers[read_fn_name] - raise AttributeError('no read function: {}'.format(read_fn_name)) - - # TODO 这个类使用在何处? 
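Note that `Vocabulary.__iter__` now yields `(word, index)` pairs and builds the vocabulary on first use, so the iteration pattern changes accordingly::

    for word, index in vocab:        # previously iteration yielded only the words
        print(word, index)
    # list(vocab.word_count.keys()) still gives the raw tokens if only the words are needed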
+ + +def _download_from_url(url, path): + try: + from tqdm.auto import tqdm + except: + from ..core.utils import _pseudo_tqdm as tqdm + import requests + + """Download file""" + r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) + chunk_size = 16 * 1024 + total_size = int(r.headers.get('Content-length', 0)) + with open(path, "wb") as file, \ + tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t: + for chunk in r.iter_content(chunk_size): + if chunk: + file.write(chunk) + t.update(len(chunk)) + + +def _uncompress(src, dst): + import zipfile + import gzip + import tarfile + import os + + def unzip(src, dst): + with zipfile.ZipFile(src, 'r') as f: + f.extractall(dst) + + def ungz(src, dst): + with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: + length = 16 * 1024 # 16KB + buf = f.read(length) + while buf: + uf.write(buf) + buf = f.read(length) + + def untar(src, dst): + with tarfile.open(src, 'r:gz') as f: + f.extractall(dst) + + fn, ext = os.path.splitext(src) + _, ext_2 = os.path.splitext(fn) + if ext == '.zip': + unzip(src, dst) + elif ext == '.gz' and ext_2 != '.tar': + ungz(src, dst) + elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz': + untar(src, dst) + else: + raise ValueError('unsupported file {}'.format(src)) + + +class DataInfo: + """ + 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 + + :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict + :param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader` + :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict + """ + + def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None): + self.vocabs = vocabs or {} + self.embeddings = embeddings or {} + self.datasets = datasets or {} + + def __repr__(self): + _str = 'In total {} datasets:\n'.format(len(self.datasets)) + for name, dataset in self.datasets.items(): + _str += '\t{} has {} instances.\n'.format(name, len(dataset)) + _str += 'In total {} vocabs:\n'.format(len(self.vocabs)) + for name, vocab in self.vocabs.items(): + _str += '\t{} has {} entries.\n'.format(name, len(vocab)) + return _str + +class DataSetLoader: + """ + 别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` + + 定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。 + + 开发者至少应该编写如下内容: + + - _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet` + - load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet` + - process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet` + + **process 函数中可以 调用load 函数或 _load 函数** + + """ + URL = '' + DATA_DIR = '' + + ROOT_DIR = '.fastnlp/datasets/' + UNCOMPRESS = True + + def _download(self, url: str, pdir: str, uncompress=True) -> str: + """ + + 从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 + + :param url: 下载的网站 + :param pdir: 下载到的目录 + :param uncompress: 是否自动解压缩 + :return: 数据的存放路径 + """ + fn = os.path.basename(url) + path = os.path.join(pdir, fn) + """check data exists""" + if not os.path.exists(path): + os.makedirs(pdir, exist_ok=True) + _download_from_url(url, path) + if uncompress: + dst = os.path.join(pdir, 'data') + if not os.path.exists(dst): + _uncompress(path, dst) + return dst + return path + + def download(self): + return self._download( + self.URL, + os.path.join(self.ROOT_DIR, self.DATA_DIR), + uncompress=self.UNCOMPRESS) + + def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]: + """ + 从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。 + 
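To illustrate the `DataSetLoader` contract that was just moved into `base_loader`, here is a minimal, hypothetical subclass; the class name and file paths are placeholders, not part of fastNLP::

    from fastNLP import DataSet, Instance
    from fastNLP.io.base_loader import DataSetLoader

    class MyTxtLoader(DataSetLoader):
        def _load(self, path: str) -> DataSet:
            # read one raw sentence per line into a DataSet
            ds = DataSet()
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    ds.append(Instance(raw_words=line.strip()))
            return ds

    loader = MyTxtLoader()
    datasets = loader.load({'train': 'train.txt', 'dev': 'dev.txt'})  # dict in, dict out, keys preserved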
如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。 + + :param Union[str, Dict[str, str]] paths: 文件路径 + :return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典 + """ + if isinstance(paths, str): + return self._load(paths) + return {name: self._load(path) for name, path in paths.items()} + + def _load(self, path: str) -> DataSet: + """从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象 + + :param str path: 文件路径 + :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 + """ + raise NotImplementedError + + def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo: + """ + 对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。 + + 从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。 + 如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。 + + 返回的 :class:`DataInfo` 对象有如下属性: + + - vocabs: 由从数据集中获取的词表组成的字典,每个词表 + - embeddings: (可选) 数据集对应的词嵌入 + - datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` + + :param paths: 原始数据读取的路径 + :param options: 根据不同的任务和数据集,设计自己的参数 + :return: 返回一个 DataInfo + """ + raise NotImplementedError diff --git a/fastNLP/io/data_loader/sst.py b/fastNLP/io/data_loader/sst.py new file mode 100644 index 00000000..1e1b8bef --- /dev/null +++ b/fastNLP/io/data_loader/sst.py @@ -0,0 +1,95 @@ +from typing import Iterable +from nltk import Tree +from ..base_loader import DataInfo, DataSetLoader +from ...core.vocabulary import VocabularyOption, Vocabulary +from ...core.dataset import DataSet +from ...core.instance import Instance +from ..embed_loader import EmbeddingOption, EmbedLoader + + +class SSTLoader(DataSetLoader): + URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' + DATA_DIR = 'sst/' + + """ + 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` + + 读取SST数据集, DataSet包含fields:: + + words: list(str) 需要分类的文本 + target: str 文本的标签 + + 数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip + + :param subtree: 是否将数据展开为子树,扩充数据量. 
Default: ``False`` + :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` + """ + + def __init__(self, subtree=False, fine_grained=False): + self.subtree = subtree + + tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', + '3': 'positive', '4': 'very positive'} + if not fine_grained: + tag_v['0'] = tag_v['1'] + tag_v['4'] = tag_v['3'] + self.tag_v = tag_v + + def _load(self, path): + """ + + :param str path: 存储数据的路径 + :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 + """ + datalist = [] + with open(path, 'r', encoding='utf-8') as f: + datas = [] + for l in f: + datas.extend([(s, self.tag_v[t]) + for s, t in self._get_one(l, self.subtree)]) + ds = DataSet() + for words, tag in datas: + ds.append(Instance(words=words, target=tag)) + return ds + + @staticmethod + def _get_one(data, subtree): + tree = Tree.fromstring(data) + if subtree: + return [(t.leaves(), t.label()) for t in tree.subtrees()] + return [(tree.leaves(), tree.label())] + + def process(self, + paths, + train_ds: Iterable[str] = None, + src_vocab_op: VocabularyOption = None, + tgt_vocab_op: VocabularyOption = None, + src_embed_op: EmbeddingOption = None): + input_name, target_name = 'words', 'target' + src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) + + info = DataInfo(datasets=self.load(paths)) + _train_ds = [info.datasets[name] + for name in train_ds] if train_ds else info.datasets.values() + src_vocab.from_dataset(*_train_ds, field_name=input_name) + tgt_vocab.from_dataset(*_train_ds, field_name=target_name) + src_vocab.index_dataset( + *info.datasets.values(), + field_name=input_name, new_field_name=input_name) + tgt_vocab.index_dataset( + *info.datasets.values(), + field_name=target_name, new_field_name=target_name) + info.vocabs = { + input_name: src_vocab, + target_name: tgt_vocab + } + + if src_embed_op is not None: + src_embed_op.vocab = src_vocab + init_emb = EmbedLoader.load_with_vocab(**src_embed_op) + info.embeddings[input_name] = init_emb + + return info + diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 6b2a47f3..cde7517a 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -13,8 +13,6 @@ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的 为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。 """ __all__ = [ - 'DataInfo', - 'DataSetLoader', 'CSVLoader', 'JsonLoader', 'ConllLoader', @@ -25,158 +23,18 @@ __all__ = [ 'Conll2003Loader', ] -from nltk.tree import Tree - +import os +from nltk import Tree +from typing import Union, Dict +from ..core.vocabulary import Vocabulary from ..core.dataset import DataSet from ..core.instance import Instance from ..core.vocabulary import Vocabulary from .file_reader import _read_csv, _read_json, _read_conll -from typing import Union, Dict -import os - - -def _download_from_url(url, path): - try: - from tqdm.auto import tqdm - except: - from ..core.utils import _pseudo_tqdm as tqdm - import requests - - """Download file""" - r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) - chunk_size = 16 * 1024 - total_size = int(r.headers.get('Content-length', 0)) - with open(path, "wb") as file, \ - tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t: - for chunk in r.iter_content(chunk_size): - if chunk: - file.write(chunk) - t.update(len(chunk)) - return - - -def _uncompress(src, dst): - import 
zipfile - import gzip - import tarfile - import os - - def unzip(src, dst): - with zipfile.ZipFile(src, 'r') as f: - f.extractall(dst) - - def ungz(src, dst): - with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: - length = 16 * 1024 # 16KB - buf = f.read(length) - while buf: - uf.write(buf) - buf = f.read(length) - - def untar(src, dst): - with tarfile.open(src, 'r:gz') as f: - f.extractall(dst) - - fn, ext = os.path.splitext(src) - _, ext_2 = os.path.splitext(fn) - if ext == '.zip': - unzip(src, dst) - elif ext == '.gz' and ext_2 != '.tar': - ungz(src, dst) - elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz': - untar(src, dst) - else: - raise ValueError('unsupported file {}'.format(src)) - - -class DataInfo: - """ - 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 - - :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict - :param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader` - :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict - """ - - def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None): - self.vocabs = vocabs or {} - self.embeddings = embeddings or {} - self.datasets = datasets or {} - - -class DataSetLoader: - """ - 别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` - - 定义了各种 DataSetLoader (针对特定数据上的特定任务) 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。 - - 开发者至少应该编写如下内容: - - - _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet` - - load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet` - - process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet` - - **process 函数中可以 调用load 函数或 _load 函数** - - """ - - def _download(self, url: str, path: str, uncompress=True) -> str: - """ - - 从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 - - :param url: 下载的网站 - :param path: 下载到的目录 - :param uncompress: 是否自动解压缩 - :return: 数据的存放路径 - """ - pdir = os.path.dirname(path) - os.makedirs(pdir, exist_ok=True) - _download_from_url(url, path) - if uncompress: - dst = os.path.join(pdir, 'data') - _uncompress(path, dst) - return dst - return path - - def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]: - """ - 从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。 - 如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。 - - :param Union[str, Dict[str, str]] paths: 文件路径 - :return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典 - """ - if isinstance(paths, str): - return self._load(paths) - return {name: self._load(path) for name, path in paths.items()} - - def _load(self, path: str) -> DataSet: - """从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象 - - :param str path: 文件路径 - :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 - """ - raise NotImplementedError - - def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo: - """ - 对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。 - - 从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。 - 如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。 - - 返回的 :class:`DataInfo` 对象有如下属性: - - - vocabs: 由从数据集中获取的词表组成的字典,每个词表 - - embeddings: (可选) 数据集对应的词嵌入 - - datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` - - :param paths: 原始数据读取的路径 - :param options: 根据不同的任务和数据集,设计自己的参数 - :return: 返回一个 DataInfo - """ - raise NotImplementedError +from .base_loader import DataSetLoader, DataInfo +from .data_loader.sst import SSTLoader +from ..core.const import Const +from ..modules.encoder._bert import BertTokenizer class 
PeopleDailyCorpusLoader(DataSetLoader): @@ -185,12 +43,12 @@ class PeopleDailyCorpusLoader(DataSetLoader): 读取人民日报数据集 """ - + def __init__(self, pos=True, ner=True): super(PeopleDailyCorpusLoader, self).__init__() self.pos = pos self.ner = ner - + def _load(self, data_path): with open(data_path, "r", encoding="utf-8") as f: sents = f.readlines() @@ -235,7 +93,7 @@ class PeopleDailyCorpusLoader(DataSetLoader): example.append(sent_ner) examples.append(example) return self.convert(examples) - + def convert(self, data): """ @@ -263,7 +121,8 @@ class ConllLoader(DataSetLoader): """ 别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader` - 读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html + 读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 数据中以"-DOCSTART-"开头的行将被忽略,因为 + 该符号在conll 2003中被用为文档分割符。 列号从0开始, 每列对应内容为:: @@ -286,7 +145,7 @@ class ConllLoader(DataSetLoader): :param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` """ - + def __init__(self, headers, indexes=None, dropna=False): super(ConllLoader, self).__init__() if not isinstance(headers, (list, tuple)): @@ -300,7 +159,7 @@ class ConllLoader(DataSetLoader): if len(indexes) != len(headers): raise ValueError self.indexes = indexes - + def _load(self, path): ds = DataSet() for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): @@ -318,7 +177,7 @@ class Conll2003Loader(ConllLoader): 关于数据集的更多信息,参考: https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data """ - + def __init__(self): headers = [ 'tokens', 'pos', 'chunks', 'ner', @@ -356,56 +215,6 @@ def _cut_long_sentence(sent, max_sample_length=200): return cutted_sentence -class SSTLoader(DataSetLoader): - """ - 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` - - 读取SST数据集, DataSet包含fields:: - - words: list(str) 需要分类的文本 - target: str 文本的标签 - - 数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip - - :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` - :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` - """ - - def __init__(self, subtree=False, fine_grained=False): - self.subtree = subtree - - tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', - '3': 'positive', '4': 'very positive'} - if not fine_grained: - tag_v['0'] = tag_v['1'] - tag_v['4'] = tag_v['3'] - self.tag_v = tag_v - - def _load(self, path): - """ - - :param str path: 存储数据的路径 - :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 - """ - datalist = [] - with open(path, 'r', encoding='utf-8') as f: - datas = [] - for l in f: - datas.extend([(s, self.tag_v[t]) - for s, t in self._get_one(l, self.subtree)]) - ds = DataSet() - for words, tag in datas: - ds.append(Instance(words=words, target=tag)) - return ds - - @staticmethod - def _get_one(data, subtree): - tree = Tree.fromstring(data) - if subtree: - return [(t.leaves(), t.label()) for t in tree.subtrees()] - return [(tree.leaves(), tree.label())] - - class JsonLoader(DataSetLoader): """ 别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader` @@ -419,7 +228,7 @@ class JsonLoader(DataSetLoader): :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . 
Default: ``False`` """ - + def __init__(self, fields=None, dropna=False): super(JsonLoader, self).__init__() self.dropna = dropna @@ -430,7 +239,7 @@ class JsonLoader(DataSetLoader): for k, v in fields.items(): self.fields[k] = k if v is None else v self.fields_list = list(self.fields.keys()) - + def _load(self, path): ds = DataSet() for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): @@ -454,27 +263,27 @@ class SNLILoader(JsonLoader): 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip """ - + def __init__(self): fields = { - 'sentence1_parse': 'words1', - 'sentence2_parse': 'words2', - 'gold_label': 'target', + 'sentence1_parse': Const.INPUTS(0), + 'sentence2_parse': Const.INPUTS(1), + 'gold_label': Const.TARGET, } super(SNLILoader, self).__init__(fields=fields) - + def _load(self, path): ds = super(SNLILoader, self)._load(path) - + def parse_tree(x): t = Tree.fromstring(x) return t.leaves() - + ds.apply(lambda ins: parse_tree( - ins['words1']), new_field_name='words1') + ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) ds.apply(lambda ins: parse_tree( - ins['words2']), new_field_name='words2') - ds.drop(lambda x: x['target'] == '-') + ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) + ds.drop(lambda x: x[Const.TARGET] == '-') return ds @@ -562,12 +371,12 @@ class CSVLoader(DataSetLoader): :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . Default: ``False`` """ - + def __init__(self, headers=None, sep=",", dropna=False): self.headers = headers self.sep = sep self.dropna = dropna - + def _load(self, path): ds = DataSet() for idx, data in _read_csv(path, headers=self.headers, @@ -582,7 +391,7 @@ def _add_seg_tag(data): :param data: list of ([word], [pos], [heads], [head_tags]) :return: list of ([word], [pos]) """ - + _processed = [] for word_list, pos_list, _, _ in data: new_sample = [] diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py index fb024e73..e046f1df 100644 --- a/fastNLP/io/embed_loader.py +++ b/fastNLP/io/embed_loader.py @@ -1,5 +1,6 @@ __all__ = [ - "EmbedLoader" + "EmbedLoader", + "EmbeddingOption", ] import os @@ -9,6 +10,21 @@ import numpy as np from ..core.vocabulary import Vocabulary from .base_loader import BaseLoader +from ..core.utils import Option + + +class EmbeddingOption(Option): + def __init__(self, + embed_filepath=None, + dtype=np.float32, + normalize=True, + error='ignore'): + super().__init__( + embed_filepath=embed_filepath, + dtype=dtype, + normalize=normalize, + error=error + ) class EmbedLoader(BaseLoader): @@ -20,9 +36,9 @@ class EmbedLoader(BaseLoader): def __init__(self): super(EmbedLoader, self).__init__() - + @staticmethod - def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'): + def load_with_vocab(embed_filepath, vocab, dtype=np.float32, padding='', unknown='', normalize=True, error='ignore'): """ 从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 word2vec(第一行只有两个元素)还是glove格式的数据。 @@ -31,6 +47,8 @@ class EmbedLoader(BaseLoader): :param vocab: 词表 :class:`~fastNLP.Vocabulary` 类型,读取出现在vocab中的词的embedding。 没有出现在vocab中的词的embedding将通过找到的词的embedding的正态分布采样出来,以使得整个Embedding是同分布的。 :param dtype: 读出的embedding的类型 + :param str padding: 词表中padding的token + :param str unknown: 词表中unknown的token :param bool normalize: 是否将每个vector归一化到norm为1 :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。 这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。 @@ -54,9 +72,16 @@ class EmbedLoader(BaseLoader): 
for idx, line in enumerate(f, start_idx): try: parts = line.strip().split() - if parts[0] in vocab: - index = vocab.to_index(parts[0]) - matrix[index] = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim) + word = ''.join(parts[:-dim]) + nums = parts[-dim:] + # 对齐unk与pad + if word==padding and vocab.padding is not None: + word = vocab.padding + elif word==unknown and vocab.unknown is not None: + word = vocab.unknown + if word in vocab: + index = vocab.to_index(word) + matrix[index] = np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim) hit_flags[index] = True except Exception as e: if error == 'ignore': @@ -87,14 +112,14 @@ class EmbedLoader(BaseLoader): :param str embed_filepath: 预训练的embedding的路径。 :param dtype: 读出的embedding的类型 - :param str padding: the padding tag for vocabulary. - :param str unknown: the unknown tag for vocabulary. + :param str padding: 词表中的padding的token. 并以此用做vocab的padding。 + :param str unknown: 词表中的unknown的token. 并以此用做vocab的unknown。 :param bool normalize: 是否将每个vector归一化到norm为1 :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。这里主要可能出错的地 方在于词表有空行或者词表出现了维度不一致。 - :return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 - :return numpy.ndarray: Vocabulary Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 + :return (numpy.ndarray, Vocabulary): Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 是否使用padding, 以及unknown有没有在词表中找到对应的词。 Vocabulary中的词的顺序与Embedding的顺序是一一对应的。 + """ vocab = Vocabulary(padding=padding, unknown=unknown) vec_dict = {} @@ -111,15 +136,16 @@ class EmbedLoader(BaseLoader): for idx, line in enumerate(f, start=start): try: parts = line.strip().split() - word = parts[0] if dim == -1: dim = len(parts) - 1 - vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim) + word = ''.join(parts[:-dim]) + nums = parts[-dim:] + vec = np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim) vec_dict[word] = vec vocab.add_word(word) if unknown is not None and unknown == word: found_unknown = True - if found_pad is not None and padding == word: + if padding is not None and padding == word: found_pad = True except Exception as e: if error == 'ignore': diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py index 5963bb56..34b5d7c0 100644 --- a/fastNLP/io/file_reader.py +++ b/fastNLP/io/file_reader.py @@ -90,11 +90,12 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): return sample with open(path, 'r', encoding=encoding) as f: sample = [] - start = next(f) - if '-DOCSTART-' not in start: + start = next(f).strip() + if '-DOCSTART-' not in start and start!='': sample.append(start.split()) for line_idx, line in enumerate(f, 1): - if line.startswith('\n'): + line = line.strip() + if line=='': if len(sample): try: res = parse_conll(sample) @@ -107,7 +108,8 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): elif line.startswith('#'): continue else: - sample.append(line.split()) + if not line.startswith('-DOCSTART-'): + sample.append(line.split()) if len(sample) > 0: try: res = parse_conll(sample) @@ -115,4 +117,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): except Exception as e: if dropna: return - raise ValueError('invalid instance at line: {}'.format(line_idx)) + print('invalid instance at line: {}'.format(line_idx)) + raise e diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py new file mode 100644 index 00000000..d178626b --- /dev/null +++ b/fastNLP/io/file_utils.py @@ -0,0 
+1,255 @@ + +import os +from pathlib import Path +from urllib.parse import urlparse +import re +import requests +import tempfile +from tqdm import tqdm +import shutil +import hashlib + + +def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: + """ + 给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 + 将文件放入到cache_dir中 + """ + if cache_dir is None: + dataset_cache = Path(get_defalt_path()) + else: + dataset_cache = cache_dir + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ("http", "https"): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, dataset_cache) + elif parsed.scheme == "" and Path(os.path.join(dataset_cache, url_or_filename)).exists(): + # File, and it exists. + return Path(url_or_filename) + elif parsed.scheme == "": + # File, but it doesn't exist. + raise FileNotFoundError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError( + "unable to parse {} as a URL or as a local path".format(url_or_filename) + ) + +def get_filepath(filepath): + """ + 如果filepath中只有一个文件,则直接返回对应的全路径 + :param filepath: + :return: + """ + if os.path.isdir(filepath): + files = os.listdir(filepath) + if len(files)==1: + return os.path.join(filepath, files[0]) + else: + return filepath + return filepath + +def get_defalt_path(): + """ + 获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 + + :return: + """ + if 'FASTNLP_CACHE_DIR' in os.environ: + fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR') + if os.path.exists(fastnlp_cache_dir): + return fastnlp_cache_dir + raise RuntimeError("Some errors happens on cache directory.") + else: + raise RuntimeError("There function is not available right now.") + fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP")) + return fastnlp_cache_dir + +def _get_base_url(name): + # 返回的URL结尾必须是/ + if 'FASTNLP_BASE_URL' in os.environ: + fastnlp_base_url = os.environ['FASTNLP_BASE_URL'] + return fastnlp_base_url + raise RuntimeError("There function is not available right now.") + +def split_filename_suffix(filepath): + """ + 给定filepath返回对应的name和suffix + :param filepath: + :return: filename, suffix + """ + filename = os.path.basename(filepath) + if filename.endswith('.tar.gz'): + return filename[:-7], '.tar.gz' + return os.path.splitext(filename) + +def get_from_cache(url: str, cache_dir: Path = None) -> Path: + """ + 尝试在cache_dir中寻找url定义的资源; 如果没有找到。则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。 + 如果从url中下载的资源解压后有多个文件,则返回directory的路径; 如果只有一个资源,则返回具体的路径。 + + """ + cache_dir.mkdir(parents=True, exist_ok=True) + + filename = re.sub(r".+/", "", url) + dir_name, suffix = split_filename_suffix(filename) + sep_index = dir_name[::-1].index('-') + if sep_index<0: + check_sum = None + else: + check_sum = dir_name[-sep_index+1:] + sep_index = len(dir_name) if sep_index==-1 else -sep_index-1 + dir_name = dir_name[:sep_index] + + # 寻找与它名字匹配的内容, 而不关心后缀 + match_dir_name = match_file(dir_name, cache_dir) + if match_dir_name: + dir_name = match_dir_name + cache_path = cache_dir / dir_name + + # get cache path to put the file + if cache_path.exists(): + return get_filepath(cache_path) + + # make HEAD request to check ETag TODO ETag可以用来判断资源是否已经更新了,之后需要加上 + response = requests.head(url, headers={"User-Agent": "fastNLP"}) + if response.status_code != 200: + raise IOError( + f"HEAD request failed for url {url} with status code {response.status_code}." 
+ ) + + # add ETag to filename if it exists + # etag = response.headers.get("ETag") + + if not cache_path.exists(): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + fd, temp_filename = tempfile.mkstemp() + print("%s not found in cache, downloading to %s"%(url, temp_filename)) + + # GET file object + req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"}) + content_length = req.headers.get("Content-Length") + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + sha256 = hashlib.sha256() + with open(temp_filename, "wb") as temp_file: + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + sha256.update(chunk) + # check sum + digit = sha256.hexdigest()[:8] + if not check_sum: + assert digit == check_sum, "File corrupted when download." + progress.close() + print(f"Finish download from {url}.") + + # 开始解压 + delete_temp_dir = None + if suffix in ('.zip', '.tar.gz'): + uncompress_temp_dir = tempfile.mkdtemp() + delete_temp_dir = uncompress_temp_dir + print(f"Start to uncompress file to {uncompress_temp_dir}.") + if suffix == '.zip': + unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) + else: + untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) + filenames = os.listdir(uncompress_temp_dir) + if len(filenames)==1: + if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): + uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) + + cache_path.mkdir(parents=True, exist_ok=True) + print("Finish un-compressing file.") + else: + uncompress_temp_dir = temp_filename + cache_path = str(cache_path) + suffix + success = False + try: + # 复制到指定的位置 + print(f"Copy file to {cache_path}.") + if os.path.isdir(uncompress_temp_dir): + for filename in os.listdir(uncompress_temp_dir): + shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename) + else: + shutil.copyfile(uncompress_temp_dir, cache_path) + success = True + except Exception as e: + print(e) + raise e + finally: + if not success: + if cache_path.exists(): + if cache_path.is_file(): + os.remove(cache_path) + else: + shutil.rmtree(cache_path) + if delete_temp_dir: + shutil.rmtree(delete_temp_dir) + os.close(fd) + os.remove(temp_filename) + + return get_filepath(cache_path) + +def unzip_file(file: Path, to: Path): + # unpack and write out in CoNLL column-like format + from zipfile import ZipFile + + with ZipFile(file, "r") as zipObj: + # Extract all the contents of zip file in current directory + zipObj.extractall(to) + +def untar_gz_file(file:Path, to:Path): + import tarfile + + with tarfile.open(file, 'r:gz') as tar: + tar.extractall(to) + +def match_file(dir_name:str, cache_dir:str)->str: + """ + 匹配的原则是,在cache_dir下的文件: (1) 与dir_name完全一致; (2) 除了后缀以外和dir_name完全一致。 + 如果找到了两个匹配的结果将报错. 
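Putting the new loader pieces together, a hedged sketch of the `SSTLoader.process` flow; the file paths are placeholders for wherever the SST trees were downloaded::

    from fastNLP.io import SSTLoader

    loader = SSTLoader(subtree=False, fine_grained=False)
    info = loader.process({'train': 'sst/train.txt',
                           'dev': 'sst/dev.txt',
                           'test': 'sst/test.txt'},
                          train_ds=['train'])      # only the train split builds the vocabularies
    print(info)                                    # DataInfo.__repr__ now reports dataset and vocab sizes
    train_set = info.datasets['train']
    src_vocab = info.vocabs['words']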
如果找到了则返回匹配的文件的名称; 没有找到返回空字符串 + + :param dir_name: 需要匹配的名称 + :param cache_dir: 在该目录下找匹配dir_name是否存在 + :return: str + """ + files = os.listdir(cache_dir) + matched_filenames = [] + for file_name in files: + if re.match(dir_name+'$', file_name) or re.match(dir_name+'\\..*', file_name): + matched_filenames.append(file_name) + if len(matched_filenames)==0: + return '' + elif len(matched_filenames)==1: + return matched_filenames[-1] + else: + raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.") + +if __name__ == '__main__': + cache_dir = Path('caches') + cache_dir = None + # 需要对cache_dir进行测试 + base_url = 'http://0.0.0.0:8888/file/download' + # if True: + # for filename in os.listdir(cache_dir): + # if os.path.isdir(os.path.join(cache_dir, filename)): + # shutil.rmtree(os.path.join(cache_dir, filename)) + # else: + # os.remove(os.path.join(cache_dir, filename)) + # 1. 测试.txt文件 + print(cached_path(base_url + '/{}'.format('txt_test-bcb4fe65.txt'), cache_dir)) + # 2. 测试.zip文件(只有一个文件) + print(cached_path(base_url + '/{}'.format('zip_test-40966d39.zip'), cache_dir)) + # 3. 测试.zip文件(有多个文件) + print(cached_path(base_url + '/{}'.format('zip_pack_test-70c0b20d.zip'), cache_dir)) + # 4. 测试.tar.gz文件 + print(cached_path(base_url + '/{}'.format('tar_gz_test-3e2679cf.tar.gz'), cache_dir)) + # 5. 测试.tar.gz多个文件 + print(cached_path(base_url + '/{}'.format('tar_gz_pack_test-08dfdccd.tar.gz'), cache_dir)) + + # 6. 测试.pkl文件 diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py index 960132ad..4846c7fa 100644 --- a/fastNLP/models/bert.py +++ b/fastNLP/models/bert.py @@ -10,6 +10,35 @@ from ..core.const import Const from ..modules.encoder import BertModel +class BertConfig: + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02 + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + + class BertForSequenceClassification(BaseModel): """BERT model for classification. 
This module is composed of the BERT model with a linear layer on top of @@ -44,14 +73,19 @@ class BertForSequenceClassification(BaseModel): config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) num_labels = 2 - model = BertForSequenceClassification(config, num_labels) + model = BertForSequenceClassification(num_labels, config) logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, num_labels, bert_dir): + def __init__(self, num_labels, config=None, bert_dir=None): super(BertForSequenceClassification, self).__init__() self.num_labels = num_labels - self.bert = BertModel.from_pretrained(bert_dir) + if bert_dir is not None: + self.bert = BertModel.from_pretrained(bert_dir) + else: + if config is None: + config = BertConfig() + self.bert = BertModel(**config.__dict__) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, num_labels) @@ -106,14 +140,19 @@ class BertForMultipleChoice(BaseModel): config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) num_choices = 2 - model = BertForMultipleChoice(config, num_choices, bert_dir) + model = BertForMultipleChoice(num_choices, config, bert_dir) logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, num_choices, bert_dir): + def __init__(self, num_choices, config=None, bert_dir=None): super(BertForMultipleChoice, self).__init__() self.num_choices = num_choices - self.bert = BertModel.from_pretrained(bert_dir) + if bert_dir is not None: + self.bert = BertModel.from_pretrained(bert_dir) + else: + if config is None: + config = BertConfig() + self.bert = BertModel(**config.__dict__) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, 1) @@ -174,14 +213,19 @@ class BertForTokenClassification(BaseModel): num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) num_labels = 2 bert_dir = 'your-bert-file-dir' - model = BertForTokenClassification(config, num_labels, bert_dir) + model = BertForTokenClassification(num_labels, config, bert_dir) logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, num_labels, bert_dir): + def __init__(self, num_labels, config=None, bert_dir=None): super(BertForTokenClassification, self).__init__() self.num_labels = num_labels - self.bert = BertModel.from_pretrained(bert_dir) + if bert_dir is not None: + self.bert = BertModel.from_pretrained(bert_dir) + else: + if config is None: + config = BertConfig() + self.bert = BertModel(**config.__dict__) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, num_labels) @@ -252,9 +296,14 @@ class BertForQuestionAnswering(BaseModel): start_logits, end_logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, bert_dir): + def __init__(self, config=None, bert_dir=None): super(BertForQuestionAnswering, self).__init__() - self.bert = BertModel.from_pretrained(bert_dir) + if bert_dir is not None: + self.bert = BertModel.from_pretrained(bert_dir) + else: + if config is None: + config = BertConfig() + self.bert = BertModel(**config.__dict__) # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version # self.dropout = nn.Dropout(config.hidden_dropout_prob) self.qa_outputs = 
nn.Linear(config.hidden_size, 2) diff --git a/fastNLP/models/cnn_text_classification.py b/fastNLP/models/cnn_text_classification.py index 3a71a80a..081dd510 100644 --- a/fastNLP/models/cnn_text_classification.py +++ b/fastNLP/models/cnn_text_classification.py @@ -7,6 +7,7 @@ import torch.nn as nn from ..core.const import Const as C from ..modules import encoder +from fastNLP import seq_len_to_mask class CNNText(torch.nn.Module): @@ -21,15 +22,13 @@ class CNNText(torch.nn.Module): :param int num_classes: 一共有多少类 :param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 :param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 - :param int padding: 对句子前后的pad的大小, 用0填充。 :param float dropout: Dropout的大小 """ def __init__(self, init_embed, num_classes, - kernel_nums=(3, 4, 5), - kernel_sizes=(3, 4, 5), - padding=0, + kernel_nums=(30, 40, 50), + kernel_sizes=(1, 3, 5), dropout=0.5): super(CNNText, self).__init__() @@ -38,8 +37,7 @@ class CNNText(torch.nn.Module): self.conv_pool = encoder.ConvMaxpool( in_channels=self.embed.embedding_dim, out_channels=kernel_nums, - kernel_sizes=kernel_sizes, - padding=padding) + kernel_sizes=kernel_sizes) self.dropout = nn.Dropout(dropout) self.fc = nn.Linear(sum(kernel_nums), num_classes) @@ -51,7 +49,11 @@ class CNNText(torch.nn.Module): :return output: dict of torch.LongTensor, [batch_size, num_classes] """ x = self.embed(words) # [N,L] -> [N,L,C] - x = self.conv_pool(x) # [N,L,C] -> [N,C] + if seq_len is not None: + mask = seq_len_to_mask(seq_len) + x = self.conv_pool(x, mask) + else: + x = self.conv_pool(x) # [N,L,C] -> [N,C] x = self.dropout(x) x = self.fc(x) # [N,C] -> [N, N_class] return {C.OUTPUT: x} diff --git a/fastNLP/modules/decoder/crf.py b/fastNLP/modules/decoder/crf.py index beb2b9be..c0717d6f 100644 --- a/fastNLP/modules/decoder/crf.py +++ b/fastNLP/modules/decoder/crf.py @@ -9,7 +9,7 @@ from torch import nn from ..utils import initial_parameter -def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): +def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): """ 别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions` @@ -17,7 +17,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): :param dict id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 - :param str encoding_type: 支持"bio", "bmes", "bmeso"。 + :param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。 :param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; 为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 @@ -58,7 +58,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): """ - :param str encoding_type: 支持"BIO", "BMES", "BEMSO"。 + :param str encoding_type: 支持"BIO", "BMES", "BEMSO", 'bioes'。 :param str from_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag :param str from_label: 比如"PER", "LOC"等label :param str to_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag @@ -134,9 +134,19 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label return to_tag in ['b', 's', 'end', 'o'] else: raise ValueError("Unexpect tag type {}. 
Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) - + elif encoding_type == 'bioes': + if from_tag == 'start': + return to_tag in ['b', 's', 'o'] + elif from_tag == 'b': + return to_tag in ['i', 'e'] and from_label == to_label + elif from_tag == 'i': + return to_tag in ['i', 'e'] and from_label == to_label + elif from_tag in ['e', 's', 'o']: + return to_tag in ['b', 's', 'end', 'o'] + else: + raise ValueError("Unexpect tag type {}. Expect only 'B', 'I', 'E', 'S', 'O'.".format(from_tag)) else: - raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type)) + raise ValueError("Only support BIO, BMES, BMESO, BIOES encoding type, got {}.".format(encoding_type)) class ConditionalRandomField(nn.Module): diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py index bdc4cbf3..349bce69 100644 --- a/fastNLP/modules/encoder/__init__.py +++ b/fastNLP/modules/encoder/__init__.py @@ -7,6 +7,12 @@ __all__ = [ "ConvMaxpool", "Embedding", + "StaticEmbedding", + "ElmoEmbedding", + "BertEmbedding", + "StackEmbedding", + "LSTMCharEmbedding", + "CNNCharEmbedding", "LSTM", @@ -18,10 +24,12 @@ __all__ = [ "VarLSTM", "VarGRU" ] -from .bert import BertModel +from ._bert import BertModel +from .bert import BertWordPieceEncoder from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder from .conv_maxpool import ConvMaxpool -from .embedding import Embedding +from .embedding import Embedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, \ + StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding from .lstm import LSTM from .star_transformer import StarTransformer from .transformer import TransformerEncoder diff --git a/fastNLP/modules/encoder/_bert.py b/fastNLP/modules/encoder/_bert.py new file mode 100644 index 00000000..254917e5 --- /dev/null +++ b/fastNLP/modules/encoder/_bert.py @@ -0,0 +1,961 @@ + + + +""" +这个页面的代码很大程度上参考了https://github.com/huggingface/pytorch-pretrained-BERT的代码 +""" + + + +from ...core.vocabulary import Vocabulary +import collections + +import unicodedata +from ...io.file_utils import _get_base_url, cached_path +import numpy as np +from itertools import chain +import copy +import json +import math +import os + +import torch +from torch import nn +import glob + +CONFIG_FILE = 'bert_config.json' +MODEL_WEIGHTS = 'pytorch_model.bin' + + +def gelu(x): + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class BertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class BertEmbeddings(nn.Module): + def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) + self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + 
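# Illustrative check, not part of the patch: the hand-rolled BertLayerNorm above
# performs the same normalisation as torch.nn.LayerNorm over the last dimension
# (both use the biased variance and keep eps inside the square root), and both
# start from the same affine parameters (weight=1, bias=0).
import torch
from torch import nn

x = torch.randn(2, 5, 768)
assert torch.allclose(BertLayerNorm(768, eps=1e-12)(x), nn.LayerNorm(768, eps=1e-12)(x), atol=1e-6)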
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): + super(BertSelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (hidden_size, num_attention_heads)) + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
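# Illustrative shape walk-through, not part of the patch, for the BertSelfAttention
# module defined here, assuming hidden_size=768 and num_attention_heads=12
# (so attention_head_size == 64). A mask of zeros means "attend everywhere",
# since the mask is added to the raw attention scores.
import torch

attn = BertSelfAttention(hidden_size=768, num_attention_heads=12, attention_probs_dropout_prob=0.1)
hidden_states = torch.randn(2, 7, 768)       # [batch_size, seq_len, hidden_size]
additive_mask = torch.zeros(2, 1, 1, 7)      # broadcasts over heads and query positions
assert attn.transpose_for_scores(hidden_states).shape == (2, 12, 7, 64)
assert attn(hidden_states, additive_mask).shape == (2, 7, 768)   # heads merged back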
+ attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfOutput(nn.Module): + def __init__(self, hidden_size, hidden_dropout_prob): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) + self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class BertIntermediate(nn.Module): + def __init__(self, hidden_size, intermediate_size, hidden_act): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + self.intermediate_act_fn = ACT2FN[hidden_act] \ + if isinstance(hidden_act, str) else hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): + super(BertOutput, self).__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, + intermediate_size, hidden_act): + super(BertLayer, self).__init__() + self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob) + self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) + self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob, + intermediate_size, hidden_act): + super(BertEncoder, self).__init__() + layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, + intermediate_size, hidden_act) + self.layer 
= nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): + all_encoder_layers = [] + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertPooler(nn.Module): + def __init__(self, hidden_size): + super(BertPooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertModel(nn.Module): + """BERT(Bidirectional Embedding Representations from Transformers). + + 如果你想使用预训练好的权重矩阵,请在以下网址下载. + sources:: + + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", + + + 用预训练权重矩阵来建立BERT模型:: + + model = BertModel.from_pretrained("path/to/weights/directory") + + 用随机初始化权重矩阵来建立BERT模型:: + + model = BertModel() + + :param int vocab_size: 词表大小,默认值为30522,为BERT English uncase版本的词表大小 + :param int hidden_size: 隐层大小,默认值为768,为BERT base的版本 + :param int num_hidden_layers: 隐藏层数,默认值为12,为BERT base的版本 + :param int num_attention_heads: 多头注意力头数,默认值为12,为BERT base的版本 + :param int intermediate_size: FFN隐藏层大小,默认值是3072,为BERT base的版本 + :param str hidden_act: FFN隐藏层激活函数,默认值为``gelu`` + :param float hidden_dropout_prob: FFN隐藏层dropout,默认值为0.1 + :param float attention_probs_dropout_prob: Attention层的dropout,默认值为0.1 + :param int max_position_embeddings: 最大的序列长度,默认值为512, + :param int type_vocab_size: 最大segment数量,默认值为2 + :param int initializer_range: 初始化权重范围,默认值为0.02 + """ + + def __init__(self, vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02): + super(BertModel, self).__init__() + self.hidden_size = hidden_size + self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, + type_vocab_size, hidden_dropout_prob) + self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, + attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, + hidden_act) + self.pooler = BertPooler(hidden_size) + self.initializer_range = initializer_range + + self.apply(self.init_bert_weights) + + def init_bert_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from 
the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, token_type_ids) + encoded_layers = self.encoder(embedding_output, + extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers) + sequence_output = encoded_layers[-1] + pooled_output = self.pooler(sequence_output) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + @classmethod + def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): + # Load config + config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) + config = json.load(open(config_file, "r")) + # config = BertConfig.from_json_file(config_file) + # logger.info("Model config {}".format(config)) + # Instantiate model. 
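# Minimal smoke test, not part of the patch: a deliberately tiny, randomly
# initialised BertModel (the class defined above) built directly through the
# constructor, checking the output shapes of forward().
import torch

tiny = BertModel(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                 num_attention_heads=4, intermediate_size=64)
input_ids = torch.randint(0, 100, (2, 6))
encoded_layers, pooled = tiny(input_ids, output_all_encoded_layers=True)
assert len(encoded_layers) == 2                  # one tensor per hidden layer
assert encoded_layers[-1].shape == (2, 6, 32)    # [batch_size, seq_len, hidden_size]
assert pooled.shape == (2, 32)                   # pooled first-token representation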
+ model = cls(*inputs, **config, **kwargs) + if state_dict is None: + files = glob.glob(os.path.join(pretrained_model_dir, '*.bin')) + if len(files)==0: + raise FileNotFoundError(f"There is no *.bin file in {pretrained_model_dir}") + elif len(files)>1: + raise FileExistsError(f"There are multiple *.bin files in {pretrained_model_dir}") + weights_path = files[0] + state_dict = torch.load(weights_path) + + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(model, prefix='' if hasattr(model, 'bert') else 'bert.') + if len(missing_keys) > 0: + print("Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + print("Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, unexpected_keys)) + return model + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. 
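# Side note on the checkpoint loading in from_pretrained above (illustrative, with
# hypothetical keys): older TF-converted checkpoints store LayerNorm parameters as
# 'gamma'/'beta', and the renaming loop above maps them onto the PyTorch names
# 'weight'/'bias' before the state dict is loaded.
state_dict = {"embeddings.LayerNorm.gamma": 1.0, "embeddings.LayerNorm.beta": 0.0}
renamed = {k.replace("gamma", "weight").replace("beta", "bias"): v for k, v in state_dict.items()}
assert set(renamed) == {"embeddings.LayerNorm.weight", "embeddings.LayerNorm.bias"}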
+ """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r", encoding="utf-8") as reader: + while True: + token = reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, + do_lower_case=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + self.never_split = never_split + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = self._clean_text(text) + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). 
+ text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in self.never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. 
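# Illustrative run, not part of the patch: BasicTokenizer (defined above) lower-cases,
# strips accents, splits punctuation into separate tokens and puts spaces around
# CJK characters before whitespace tokenization.
bt = BasicTokenizer(do_lower_case=True)
assert bt.tokenize("Héllo, world!") == ["hello", ",", "world", "!"]
assert bt.tokenize("中文BERT") == ["中", "文", "bert"]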
+ if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + + def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BertTokenizer. + + Args: + vocab_file: Path to a one-wordpiece-per-line vocabulary file + do_lower_case: Whether to lower case the input + Only has an effect when do_wordpiece_only=False + do_basic_tokenize: Whether to do basic tokenization before wordpiece. + max_len: An artificial maximum length to truncate tokenized sequences to; + Effective maximum length is always the minimum of this + value (if specified) and the underlying BERT model's + sequence length. + never_split: List of tokens which will never be split during tokenization. + Only has an effect when do_wordpiece_only=False + """ + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " + "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=never_split) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.max_len = max_len if max_len is not None else int(1e12) + + def _reinit_on_new_vocab(self, vocab): + """ + 在load bert之后,可能会对vocab进行重新排列。重新排列之后调用这个函数重新初始化与vocab相关的性质 + + :param vocab: + :return: + """ + self.vocab = vocab + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + + def tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + if len(ids) > self.max_len: + print( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this BERT model ({} > {}). 
Running this" + " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary to a directory or file.""" + index = 0 + if os.path.isdir(vocab_path): + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + else: + vocab_file = vocab_path + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + print("Saving vocabulary to {}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!".format(vocab_file)) + index = token_index + writer.write(token + u'\n') + index += 1 + return vocab_file + + @classmethod + def from_pretrained(cls, model_dir, *inputs, **kwargs): + """ + 给定path,直接读取vocab. + + """ + pretrained_model_name_or_path = os.path.join(model_dir, VOCAB_NAME) + print("loading vocabulary file {}".format(pretrained_model_name_or_path)) + max_len = 512 + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + tokenizer = cls(pretrained_model_name_or_path, *inputs, **kwargs) + return tokenizer + +VOCAB_NAME = 'vocab.txt' + +class _WordBertModel(nn.Module): + def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first', include_cls_sep:bool=False): + super().__init__() + + self.tokenzier = BertTokenizer.from_pretrained(model_dir) + self.encoder = BertModel.from_pretrained(model_dir) + # 检查encoder_layer_number是否合理 + encoder_layer_number = len(self.encoder.encoder.layer) + self.layers = list(map(int, layers.split(','))) + for layer in self.layers: + if layer<0: + assert -layer<=encoder_layer_number, f"The layer index:{layer} is out of scope for " \ + f"a bert model with {encoder_layer_number} layers." + else: + assert layer