diff --git a/docs/source/user/quickstart.rst b/docs/source/user/quickstart.rst index 43056a26..12e541b7 100644 --- a/docs/source/user/quickstart.rst +++ b/docs/source/user/quickstart.rst @@ -49,7 +49,7 @@ .. code-block:: python from fastNLP.models import CNNText - model = CNNText((len(vocab),50), num_classes=5, padding=2, dropout=0.1) + model = CNNText((len(vocab),50), num_classes=5, dropout=0.1) :class:`~fastNLP.models.CNNText` 的网络结构如下:: @@ -121,4 +121,4 @@ In Epoch:6/Step:12, got best dev performance:AccuracyMetric: acc=0.8 Reloaded the best model. -这份教程只是简单地介绍了使用 fastNLP 工作的流程,具体的细节分析见 :doc:`/user/tutorial_one` \ No newline at end of file +这份教程只是简单地介绍了使用 fastNLP 工作的流程,具体的细节分析见 :doc:`/user/tutorial_one` diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 9aab146c..ce1a82f4 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -13,6 +13,7 @@ import numpy as np import torch import torch.multiprocessing as mp import torch.utils.data +from numbers import Number from .sampler import RandomSampler from .dataset import DataSet @@ -150,8 +151,10 @@ class Batch1(object): for field_name, field in self.dataset.get_all_fields().items(): if field.is_target or field.is_input: batch = field.get(indices) - if not self.as_numpy and field.padder is not None: - batch = _to_tensor(batch, field.dtype) + if not self.as_numpy and \ + field.dtype is not None and \ + issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor): + batch = _to_tensor(batch) if field.is_target: batch_y[field_name] = batch if field.is_input: @@ -246,12 +249,12 @@ class Batch1(object): # print('iter done') -def _to_tensor(batch, dtype): +def _to_tensor(batch): try: - if dtype in (int, np.int8, np.int16, np.int32, np.int64): - batch = torch.LongTensor(batch) - if dtype in (float, np.float32, np.float64): - batch = torch.FloatTensor(batch) + if issubclass(batch.dtype.type, np.floating): + batch = torch.as_tensor(batch).float() # 默认使用float32 + else: + batch = torch.as_tensor(batch) # 复用内存地址,避免复制 except: pass return batch diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 9f24adf2..b011d15a 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -285,7 +285,8 @@ from .field import AutoPadder from .field import FieldArray from .instance import Instance from .utils import _get_func_signature - +from .field import AppendToTargetOrInputException +from .field import SetInputOrTargetException class DataSet(object): """ @@ -422,7 +423,7 @@ class DataSet(object): if len(self.field_arrays) == 0: # DataSet has no field yet for name, field in instance.fields.items(): - field = field.tolist() if isinstance(field, np.ndarray) else field + # field = field.tolist() if isinstance(field, np.ndarray) else field self.field_arrays[name] = FieldArray(name, [field]) # 第一个样本,必须用list包装起来 else: if len(self.field_arrays) != len(instance.fields): @@ -431,7 +432,11 @@ class DataSet(object): .format(len(self.field_arrays), len(instance.fields))) for name, field in instance.fields.items(): assert name in self.field_arrays - self.field_arrays[name].append(field) + try: + self.field_arrays[name].append(field) + except AppendToTargetOrInputException as e: + print(f"Cannot append to field:{name}.") + raise e def add_fieldarray(self, field_name, fieldarray): """ @@ -549,6 +554,7 @@ class DataSet(object): self.field_arrays[new_name].name = new_name else: raise KeyError("DataSet has no field named {}.".format(old_name)) + return self def set_target(self, *field_names, flag=True): """ @@ -565,7 +571,11 @@ class DataSet(object): assert isinstance(flag, bool), "Only bool type supported." for name in field_names: if name in self.field_arrays: - self.field_arrays[name].is_target = flag + try: + self.field_arrays[name].is_target = flag + except SetInputOrTargetException as e: + print(f"Cannot set field:{name} as target.") + raise e else: raise KeyError("{} is not a valid field name.".format(name)) @@ -581,7 +591,11 @@ class DataSet(object): """ for name in field_names: if name in self.field_arrays: - self.field_arrays[name].is_input = flag + try: + self.field_arrays[name].is_input = flag + except SetInputOrTargetException as e: + print(f"Cannot set field:{name} as input, exception happens at the {e.index} value.") + raise e else: raise KeyError("{} is not a valid field name.".format(name)) @@ -748,7 +762,20 @@ class DataSet(object): self._add_apply_field(results, new_field_name, kwargs) return results - + + def add_seq_len(self, field_name:str, new_field_name='seq_len'): + """ + 将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。 + + :param field_name: str. + :return: + """ + if self.has_field(field_name=field_name): + self.apply_field(len, field_name, new_field_name=new_field_name) + else: + raise KeyError(f"Field:{field_name} not found.") + return self + def drop(self, func, inplace=True): """ func接受一个Instance,返回bool值。返回值为True时,该Instance会被移除或者加入到返回的DataSet中。 diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index 9ef8d963..faa306f3 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -1,251 +1,164 @@ -""" -field模块实现了 FieldArray 和若干 Padder。 FieldArray 是 :class:`~fastNLP.DataSet` 中一列的存储方式, -原理部分请参考 :doc:`fastNLP.core.dataset` - -""" -__all__ = [ - "FieldArray", - "Padder", - "AutoPadder", - "EngChar2DPadder" -] -from copy import deepcopy +from numbers import Number +import torch import numpy as np - - -class FieldArray(object): - """ - 别名::class:`fastNLP.FieldArray` :class:`fastNLP.core.field.FieldArray` - - FieldArray 是用于保存 :class:`~fastNLP.DataSet` 中一个field的类型。 - - :param str name: FieldArray的名称 - :param list,numpy.ndarray content: 列表的元素可以为list,int,float, - :param bool is_target: 这个field是否是一个target field。 - :param bool is_input: 这个field是否是一个input field。 - :param padder: :class:`~fastNLP.Padder` 类型。赋值给fieldarray的padder的对象会被deepcopy一份,需要修改padder参数必须通过 - fieldarray.set_pad_val()。默认为None,即使用 :class:`~fastNLP.AutoPadder` 。 - :param bool ignore_type: 是否忽略该field的type,一般如果这个field不需要转为torch.FloatTensor或torch.LongTensor, - 就可以设置为True。具体意义请参考 :class:`~fastNLP.DataSet` 。 - """ - - def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False): +from typing import Any +from abc import abstractmethod +from copy import deepcopy +from collections import Counter + +class SetInputOrTargetException(Exception): + def __init__(self, msg, index=None, field_name=None): + super().__init__(msg) + self.msg = msg + self.index = index # 标示在哪个数据遭遇到问题了 + self.field_name = field_name # 标示当前field的名称 + +class AppendToTargetOrInputException(Exception): + def __init__(self, msg, index=None, field_name=None): + super().__init__(msg) + self.msg = msg + self.index = index # 标示在哪个数据遭遇到问题了 + self.field_name = field_name # 标示当前field的名称 + +class FieldArray: + def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False): + if len(content)==0: + raise RuntimeError("Empty fieldarray is not allowed.") + _content = content + try: + _content = list(_content) + except BaseException as e: + print(f"Cannot convert content(of type:{type(content)}) into list.") + raise e self.name = name - if isinstance(content, list): - # 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list - # 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list] - for idx, item in enumerate(content): - # 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field]) - # 将[np.array] 转化为 list of list - # 也可以支持[array, array, array]的情况 - if isinstance(item, np.ndarray): - content[idx] = content[idx].tolist() - elif isinstance(content, np.ndarray): - content = content.tolist() # convert np.ndarray into 2-D list - else: - raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) - if len(content) == 0: - raise RuntimeError("Cannot initialize FieldArray with empty list.") - - self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 - self.content_dim = None # 表示content是多少维的list + self.content = _content + self._ignore_type = ignore_type + # 根据input的情况设置input,target等 + self._cell_ndim = None # 多少维度 + self.dtype = None # 最内层的element都是什么类型的 + self._is_input = False + self._is_target = False + + if is_input: + self.is_input = is_input + if is_target: + self.is_target = is_target + if padder is None: padder = AutoPadder(pad_val=0) else: - assert isinstance(padder, Padder), "padder must be of type Padder." + assert isinstance(padder, Padder), "padder must be of type fastNLP.Padder." padder = deepcopy(padder) self.set_padder(padder) - self.ignore_type = ignore_type - - self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array - - self.pytype = None - self.dtype = None - self._is_input = None - self._is_target = None - - if is_input is not None or is_target is not None: - self.is_input = is_input - self.is_target = is_target - - def _set_dtype(self): - if self.ignore_type is False: - self.pytype = self._type_detection(self.content) - self.dtype = self._map_to_np_type(self.pytype) - + + @property + def ignore_type(self): + return self._ignore_type + + @ignore_type.setter + def ignore_type(self, value): + if value: + self._cell_ndim = None + self.dtype = None + self._ignore_type = value + @property def is_input(self): return self._is_input - + @is_input.setter def is_input(self, value): """ 当 field_array.is_input = True / False 时被调用 """ - if value is True: - self._set_dtype() + # 如果(value为True)且(_is_input和_is_target都是False)且(ignore_type为False) + if value is True and \ + self._is_target is False and \ + self._ignore_type is False: + self._check_dtype_and_ndim() + if value is False and self._is_target is False: + self.dtype = None + self._cell_ndim = None self._is_input = value - + @property def is_target(self): return self._is_target - + @is_target.setter def is_target(self, value): """ 当 field_array.is_target = True / False 时被调用 """ - if value is True: - self._set_dtype() + if value is True and \ + self._is_input is False and \ + self._ignore_type is False: + self._check_dtype_and_ndim() + if value is False and self._is_input is False: + self.dtype = None + self._cell_ndim = None self._is_target = value - - def _type_detection(self, content): - """ - 当该field被设置为is_input或者is_target时被调用 - - """ - if len(content) == 0: - raise RuntimeError("Empty list in Field {}.".format(self.name)) - - type_set = set([type(item) for item in content]) - - if list in type_set: - if len(type_set) > 1: - # list 跟 非list 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) - # >1维list - inner_type_set = set() - for l in content: - [inner_type_set.add(type(obj)) for obj in l] - if list not in inner_type_set: - # 二维list - self.content_dim = 2 - return self._basic_type_detection(inner_type_set) - else: - if len(inner_type_set) == 1: - # >2维list - inner_inner_type_set = set() - for _2d_list in content: - for _1d_list in _2d_list: - [inner_inner_type_set.add(type(obj)) for obj in _1d_list] - if list in inner_inner_type_set: - raise RuntimeError("FieldArray cannot handle 4-D or more-D list.") - # 3维list - self.content_dim = 3 - return self._basic_type_detection(inner_inner_type_set) - else: - # list 跟 非list 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set))) - else: - # 一维list - for content_type in type_set: - if content_type not in self.BASIC_TYPES: - raise RuntimeError("Unexpected data type in Field '{}'. Expect one of {}. Got {}.".format( - self.name, self.BASIC_TYPES, content_type)) - self.content_dim = 1 - return self._basic_type_detection(type_set) - - def _basic_type_detection(self, type_set): - """ - :param type_set: a set of Python types - :return: one of self.BASIC_TYPES - """ - if len(type_set) == 1: - return type_set.pop() - elif len(type_set) == 2: - # 有多个basic type; 可能需要up-cast - if float in type_set and int in type_set: - # up-cast int to float - return float - else: - # str 跟 int 或者 float 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) + + def _check_dtype_and_ndim(self): + """ + 检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 + 通过将直接报错. + + :return: + """ + cell_0 = self.content[0] + index = 0 + try: + type_0, dim_0 = _get_ele_type_and_dim(cell_0) + for cell in self.content[1:]: + index += 1 + type_i, dim_i = _get_ele_type_and_dim(cell) + if type_i!=type_0: + raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." + ".".format(type_i, index, type_0)) + if dim_0!=dim_i: + raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " + "dimension:{}.".format(dim_i, index, dim_0)) + self._cell_ndim = dim_0 + self.dtype = type_0 + except SetInputOrTargetException as e: + e.index = index + raise e + + def append(self, val:Any): + """ + :param val: 把该val append到fieldarray。 + :return: + """ + if (self._is_target or self._is_input) and self._ignore_type is False: + type_, dim_ = _get_ele_type_and_dim(val) + if self.dtype!=type_: + raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with " + f"previous values(type:{self.dtype}).") + if self._cell_ndim!=dim_: + raise AppendToTargetOrInputException(f"Value(dim:{dim_}) are of different dimensions with " + f"previous values(dim:{self._cell_ndim}).") + self.content.append(val) else: - # str, int, float混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) - - def _1d_list_check(self, val): - """如果不是1D list就报错 - """ - type_set = set((type(obj) for obj in val)) - if any(obj not in self.BASIC_TYPES for obj in type_set): - raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) - self._basic_type_detection(type_set) - # otherwise: _basic_type_detection will raise error - return True - - def _2d_list_check(self, val): - """如果不是2D list 就报错 - """ - type_set = set(type(obj) for obj in val) - if list(type_set) != [list]: - raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) - inner_type_set = set() - for l in val: - for obj in l: - inner_type_set.add(type(obj)) - self._basic_type_detection(inner_type_set) - return True - - @staticmethod - def _map_to_np_type(basic_type): - type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} - return type_mapping[basic_type] - - def __repr__(self): - return "FieldArray {}: {}".format(self.name, self.content.__repr__()) - - def append(self, val): - """将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有 - 的内容是匹配的。 - - :param Any val: 需要append的值。 - """ - if self.ignore_type is False: - if isinstance(val, list): - pass - elif isinstance(val, tuple): # 确保最外层是list - val = list(val) - elif isinstance(val, np.ndarray): - val = val.tolist() - elif any((isinstance(val, t) for t in self.BASIC_TYPES)): - pass - else: - raise RuntimeError( - "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) - - if self.is_input is True or self.is_target is True: - if type(val) == list: - if len(val) == 0: - raise ValueError("Cannot append an empty list.") - if self.content_dim == 2 and self._1d_list_check(val): - # 1维list检查 - pass - elif self.content_dim == 3 and self._2d_list_check(val): - # 2维list检查 - pass - else: - raise RuntimeError( - "Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) - elif type(val) in self.BASIC_TYPES and self.content_dim == 1: - # scalar检查 - if type(val) == float and self.pytype == int: - self.pytype = float - self.dtype = self._map_to_np_type(self.pytype) - else: - raise RuntimeError( - "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) - self.content.append(val) - + self.content.append(val) + def __getitem__(self, indices): return self.get(indices, pad=False) - + def __setitem__(self, idx, val): assert isinstance(idx, int) + if (self._is_target or self._is_input) and self.ignore_type is False: # 需要检测类型 + type_, dim_ = _get_ele_type_and_dim(val) + if self.dtype!=type_: + raise RuntimeError(f"Value(type:{type_}) are of different types with " + f"other values(type:{self.dtype}).") + if self._cell_ndim!=dim_: + raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with " + f"previous values(dim:{self._cell_ndim}).") self.content[idx] = val - + def get(self, indices, pad=True): """ 根据给定的indices返回内容 @@ -257,14 +170,14 @@ class FieldArray(object): if isinstance(indices, int): return self.content[indices] if self.is_input is False and self.is_target is False: - raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) - + raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name)) + contents = [self.content[i] for i in indices] if self.padder is None or pad is False: return np.array(contents) else: - return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype) - + return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) + def set_padder(self, padder): """ 设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 @@ -276,7 +189,7 @@ class FieldArray(object): self.padder = deepcopy(padder) else: self.padder = None - + def set_pad_val(self, pad_val): """ 修改padder的pad_val. @@ -286,7 +199,7 @@ class FieldArray(object): if self.padder is not None: self.padder.set_pad_val(pad_val) return self - + def __len__(self): """ Returns the size of FieldArray. @@ -294,7 +207,7 @@ class FieldArray(object): :return int length: """ return len(self.content) - + def to(self, other): """ 将other的属性复制给本FieldArray(other必须为FieldArray类型). @@ -303,22 +216,216 @@ class FieldArray(object): :param other: :class:`~fastNLP.FieldArray` 从哪个field拷贝属性 :return: :class:`~fastNLP.FieldArray` """ - assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other)) - + assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other)) + + self.ignore_type = other.ignore_type self.is_input = other.is_input self.is_target = other.is_target self.padder = other.padder - self.ignore_type = other.ignore_type - + return self + def split(self, sep:str=None, inplace:bool=True): + """ + 依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值 + + :param sep: 分割符,如果为None则直接调用str.split()。 + :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 + :return: List[List[str]] or self + """ + new_contents = [] + for index, cell in enumerate(self.content): + try: + new_contents.append(cell.split(sep)) + except Exception as e: + print(f"Exception happens when process value in index {index}.") + print(e) + return self._after_process(new_contents, inplace=inplace) + + def int(self, inplace:bool=True): + """ + 将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), + (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) + + :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 + :return: List[int], List[List[int]], self + """ + new_contents = [] + for index, cell in enumerate(self.content): + try: + if isinstance(cell, list): + new_contents.append([int(value) for value in cell]) + else: + new_contents.append(int(cell)) + except Exception as e: + print(f"Exception happens when process value in index {index}.") + print(e) + return self._after_process(new_contents, inplace=inplace) + + def float(self, inplace=True): + """ + 将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), + (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) + + :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 + :return: + """ + new_contents = [] + for index, cell in enumerate(self.content): + try: + if isinstance(cell, list): + new_contents.append([float(value) for value in cell]) + else: + new_contents.append(float(cell)) + except Exception as e: + print(f"Exception happens when process value in index {index}.") + print(e) + return self._after_process(new_contents, inplace=inplace) + + def bool(self, inplace=True): + """ + 将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), + (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) + + :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 + :return: + """ + new_contents = [] + for index, cell in enumerate(self.content): + try: + if isinstance(cell, list): + new_contents.append([bool(value) for value in cell]) + else: + new_contents.append(bool(cell)) + except Exception as e: + print(f"Exception happens when process value in index {index}.") + print(e) -def _is_iterable(content): + return self._after_process(new_contents, inplace=inplace) + + def lower(self, inplace=True): + """ + 将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), + (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) + + :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 + :return: List[int], List[List[int]], self + """ + new_contents = [] + for index, cell in enumerate(self.content): + try: + if isinstance(cell, list): + new_contents.append([value.lower() for value in cell]) + else: + new_contents.append(cell.lower()) + except Exception as e: + print(f"Exception happens when process value in index {index}.") + print(e) + return self._after_process(new_contents, inplace=inplace) + + def upper(self, inplace=True): + """ + 将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), + (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) + + :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 + :return: List[int], List[List[int]], self + """ + new_contents = [] + for index, cell in enumerate(self.content): + try: + if isinstance(cell, list): + new_contents.append([value.upper() for value in cell]) + else: + new_contents.append(cell.upper()) + except Exception as e: + print(f"Exception happens when process value in index {index}.") + print(e) + return self._after_process(new_contents, inplace=inplace) + + def value_count(self): + """ + 返回该field下不同value的数量。多用于统计label数量 + + :return: Counter, key是label,value是出现次数 + """ + count = Counter() + for cell in self.content: + count[cell] += 1 + return count + + def _after_process(self, new_contents, inplace): + """ + 当调用处理函数之后,决定是否要替换field。 + + :param new_contents: + :param inplace: + :return: self或者生成的content + """ + if inplace: + self.content = new_contents + try: + self.is_input = self.is_input + self.is_target = self.is_input + except SetInputOrTargetException as e: + print("The newly generated field cannot be set as input or target.") + raise e + return self + else: + return new_contents + + +def _get_ele_type_and_dim(cell:Any, dim=0): + """ + 识别cell的类别与dimension的数量 + + numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html + :param cell: + :param dim: + :return: + """ + if isinstance(cell, (str, Number, np.bool_)): + return type(cell), dim + elif isinstance(cell, list): + dim += 1 + res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] + types = set([i for i,j in res]) + dims = set([j for i,j in res]) + if len(types)>1: + raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) + elif len(types)==0: + raise SetInputOrTargetException("Empty value encountered.") + if len(dims)>1: + raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) + return types.pop(), dims.pop() + elif isinstance(cell, torch.Tensor): + return cell.dtype, cell.dim() + dim # 如果是torch.mean的结果是0 + elif isinstance(cell, np.ndarray): + if cell.dtype != np.dtype('O'): # 如果不是object的话说明是well-formatted的了 + return cell.dtype.type, cell.ndim + dim + # 否则需要继续往下iterate + dim += 1 + res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] + types = set([i for i,j in res]) + dims = set([j for i,j in res]) + if len(types)>1: + raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) + elif len(types)==0: + raise SetInputOrTargetException("Empty value encountered.") + if len(dims)>1: + raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) + return types.pop(), dims.pop() + else: # 包含tuple, set, dict以及其它的类型 + raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") + + +def _is_iterable(value): + # 检查是否是iterable的, duck typing try: - _ = (e for e in content) - except TypeError: + iter(value) + return True + except BaseException as e: return False - return True class Padder: @@ -327,32 +434,35 @@ class Padder: 所有padder都需要继承这个类,并覆盖__call__方法。 用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 - + .. py:function:: __call__(self, contents, field_name, field_ele_dtype): 传入的是List内容。假设有以下的DataSet。 - + :param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 deepcopy一份。 :param str, field_name: field的名称。 :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 :return: np.array([padded_element]) - + """ - + def __init__(self, pad_val=0, **kwargs): self.pad_val = pad_val - + def set_pad_val(self, pad_val): self.pad_val = pad_val - - def __call__(self, contents, field_name, field_ele_dtype): + + @abstractmethod + def __call__(self, contents, field_name, field_ele_dtype, dim:int): """ 传入的是List内容。假设有以下的DataSet。 :param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 deepcopy一份。 :param str, field_name: field的名称。 - :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 + :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True, + 该这个值为None。 + :param dim: 这个field的维度。当ignore_type为True时,该值为None :return: np.array([padded_element]) Example:: @@ -394,50 +504,87 @@ class AutoPadder(Padder): 根据contents的数据自动判定是否需要做padding。 1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 - 型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行pad + 型为str, [[1,2], ...]的元素类型为int)的数据不为数值类型则不会进行pad - 2 如果元素类型为(np.int64, np.float64), + 2 如果元素类型为数值类型,比如np.int64, np.float64, int, float, torch.int64等 - 2.1 如果该field的内容为(np.int64, np.float64),比如为seq_len, 则不进行padding + 2.1 如果该field的内容为数值类型(包括int, float等),比如为seq_len, 则不进行padding - 2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 - 即如果Instance中field形如[1, 2, 3, ...],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad + 2.2 如果该field的内容等价于一维list, 那么会将Batch中的List pad为一样长。 + + 2.3 如果该field的内容等价于二维list,那么会按照英语character padding的方式进行padding。如果是character padding建议使用 + :class: fastNLP.EngChar2DPadder. + + 2.4 如果该field的内容等价于三维list,则如果每个instance在每个维度上相等,会组成一个batch的tensor返回,这种情况应该是为图片 + 的情况。 + + 3 其它情况不进行处理,返回一个np.array类型。 """ - def __init__(self, pad_val=0): - """ - :param pad_val: int, padding的位置使用该index - """ super().__init__(pad_val=pad_val) - - def _is_two_dimension(self, contents): - """ - 判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度 - :param contents: - :return: - """ - value = contents[0] - if isinstance(value, (np.ndarray, list)): - value = value[0] - if isinstance(value, (np.ndarray, list)): - return False - return True - return False - - def __call__(self, contents, field_name, field_ele_dtype): - - if not _is_iterable(contents[0]): - array = np.array([content for content in contents], dtype=field_ele_dtype) - elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): - max_len = max([len(content) for content in contents]) - array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) - for i, content in enumerate(contents): - array[i][:len(content)] = content - elif field_ele_dtype is None: - array = np.array(contents) # 当ignore_type=True时,直接返回contents - else: # should only be str - array = np.array([content for content in contents]) - return array + + def __call__(self, contents, field_name, field_ele_dtype, dim): + if field_ele_dtype: + if dim>3: + return np.array(contents) + if isinstance(field_ele_dtype, np.dtype) or field_ele_dtype in (float, int, bool, str): + if isinstance(field_ele_dtype, np.number) or field_ele_dtype in (float, int, bool): + if dim==0: + array = np.array(contents, dtype=field_ele_dtype) + elif dim==1: + max_len = max(map(len, contents)) + array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + array[i, :len(content_i)] = content_i + elif dim==2: + max_len = max(map(len, contents)) + max_word_len = max([max([len(content_ii) for content_ii in content_i]) for + content_i in contents]) + array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + for j, content_ii in enumerate(content_i): + array[i, j, :len(content_ii)] = content_ii + else: + shape = np.shape(contents) + if len(shape)==4: # 说明各dimension是相同的大小 + array = np.array(contents, dtype=field_ele_dtype) + else: + raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") + return array + return np.array(contents) + elif str(field_ele_dtype).startswith('torch'): + if dim==0: + tensor = torch.tensor(contents).to(field_ele_dtype) + elif dim==1: + max_len = max(map(len, contents)) + tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + tensor[i, :len(content_i)] = torch.tensor(content_i) + elif dim==2: + max_len = max(map(len, contents)) + max_word_len = max([max([len(content_ii) for content_ii in content_i]) for + content_i in contents]) + tensor = torch.full((len(contents), max_len, max_word_len), fill_value=self.pad_val, + dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + for j, content_ii in enumerate(content_i): + tensor[i, j, :len(content_ii)] = torch.tensor(content_ii) + else: + shapes = set([np.shape(content_i) for content_i in contents]) + if len(shapes)>1: + raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") + shape = shapes.pop() + if len(shape)==3: + tensor = torch.full([len(contents)]+list(shape), fill_value=self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype) + else: + raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") + return tensor + else: + return np.array(contents) # 不进行任何操作 + else: + return np.array(contents) class EngChar2DPadder(Padder): @@ -463,7 +610,7 @@ class EngChar2DPadder(Padder): dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder """ - + def __init__(self, pad_val=0, pad_length=0): """ :param pad_val: int, pad的位置使用该index @@ -471,32 +618,10 @@ class EngChar2DPadder(Padder): 都pad或截取到该长度. """ super().__init__(pad_val=pad_val) - + self.pad_length = pad_length - - def _exactly_three_dims(self, contents, field_name): - """ - 检查传入的contents是否刚好是3维,如果不是3维就报错。理论上,第一个维度是batch,第二个维度是word,第三个维度是character - :param contents: - :param field_name: str - :return: - """ - if not isinstance(contents, list): - raise TypeError("contents should be a list, not {}.".format(type(contents))) - value = contents[0] - try: - value = value[0] - except: - raise ValueError("Field:{} only has one dimension.".format(field_name)) - try: - value = value[0] - except: - raise ValueError("Field:{} only has two dimensions.".format(field_name)) - - if _is_iterable(value): - raise ValueError("Field:{} has more than 3 dimension.".format(field_name)) - - def __call__(self, contents, field_name, field_ele_dtype): + + def __call__(self, contents, field_name, field_ele_dtype, dim): """ 期望输入类似于 [ @@ -510,11 +635,11 @@ class EngChar2DPadder(Padder): :param field_ele_dtype :return: """ - if field_ele_dtype not in (np.int64, np.float64): + if field_ele_dtype not in (np.int64, np.float64, int, float): raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format( field_name, field_ele_dtype )) - self._exactly_three_dims(contents, field_name) + assert dim==2, f"Field:{field_name} has {dim}, EngChar2DPadder only supports input with 2 dimensions." if self.pad_length < 1: max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents]) else: @@ -522,12 +647,12 @@ class EngChar2DPadder(Padder): max_sent_length = max(len(word_lst) for word_lst in contents) batch_size = len(contents) dtype = type(contents[0][0][0]) - + padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, dtype=dtype) for b_idx, word_lst in enumerate(contents): for c_idx, char_lst in enumerate(word_lst): chars = char_lst[:max_char_length] padded_array[b_idx, c_idx, :len(chars)] = chars - + return padded_array diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 868d67b1..19c33c86 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -22,7 +22,7 @@ from .utils import _check_arg_dict_list from .utils import _get_func_signature from .utils import seq_len_to_mask from .vocabulary import Vocabulary - +from abc import abstractmethod class MetricBase(object): """ @@ -117,10 +117,12 @@ class MetricBase(object): def __init__(self): self.param_map = {} # key is param in function, value is input param. self._checked = False - + + @abstractmethod def evaluate(self, *args, **kwargs): raise NotImplementedError - + + @abstractmethod def get_metric(self, reset=True): raise NotImplemented diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 57a31a69..d7694e00 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -532,7 +532,7 @@ class Trainer(object): self._train() self.callback_manager.on_train_end() - except Exception as e: + except BaseException as e: self.callback_manager.on_exception(e) if on_exception == 'auto': if not isinstance(e, (CallbackException, KeyboardInterrupt)): diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 9dab47b5..1eb2b70e 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -285,6 +285,7 @@ def _get_model_device(model): :param model: nn.Module :return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。 """ + # TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding assert isinstance(model, nn.Module) parameters = list(model.parameters()) @@ -295,6 +296,13 @@ def _get_model_device(model): def _build_args(func, **kwargs): + """ + 根据func的初始化参数,从kwargs中选择func需要的参数 + + :param func: callable + :param kwargs: 参数 + :return:dict. func中用到的参数 + """ spect = inspect.getfullargspec(func) if spect.varkw is not None: return kwargs diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index bca28e10..1d5d6f32 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -148,7 +148,7 @@ class Vocabulary(object): self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) self.build_reverse_vocab() self.rebuild = False - + def build_reverse_vocab(self): """ 基于 "word to index" dict, 构建 "index to word" dict. @@ -359,5 +359,7 @@ class Vocabulary(object): def __repr__(self): return "Vocabulary({}...)".format(list(self.word_count.keys())[:5]) + @_check_build_vocab def __iter__(self): - return iter(list(self.word_count.keys())) + for word, index in self.word2idx.items(): + yield word, index diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py index bc37777e..5237a8a7 100644 --- a/fastNLP/io/embed_loader.py +++ b/fastNLP/io/embed_loader.py @@ -26,6 +26,7 @@ class EmbeddingOption(Option): error=error ) + class EmbedLoader(BaseLoader): """ 别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader` @@ -35,9 +36,9 @@ class EmbedLoader(BaseLoader): def __init__(self): super(EmbedLoader, self).__init__() - + @staticmethod - def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'): + def load_with_vocab(embed_filepath, vocab, dtype=np.float32, padding='', unknown='', normalize=True, error='ignore'): """ 从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 word2vec(第一行只有两个元素)还是glove格式的数据。 @@ -46,6 +47,8 @@ class EmbedLoader(BaseLoader): :param vocab: 词表 :class:`~fastNLP.Vocabulary` 类型,读取出现在vocab中的词的embedding。 没有出现在vocab中的词的embedding将通过找到的词的embedding的正态分布采样出来,以使得整个Embedding是同分布的。 :param dtype: 读出的embedding的类型 + :param str padding: 词表中padding的token + :param str unknown: 词表中unknown的token :param bool normalize: 是否将每个vector归一化到norm为1 :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。 这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。 @@ -69,8 +72,14 @@ class EmbedLoader(BaseLoader): for idx, line in enumerate(f, start_idx): try: parts = line.strip().split() - if parts[0] in vocab: - index = vocab.to_index(parts[0]) + word = parts[0] + # 对齐unk与pad + if word==padding and vocab.padding is not None: + word = vocab.padding + elif word==unknown and vocab.unknown is not None: + word = vocab.unknown + if word in vocab: + index = vocab.to_index(word) matrix[index] = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim) hit_flags[index] = True except Exception as e: @@ -102,14 +111,14 @@ class EmbedLoader(BaseLoader): :param str embed_filepath: 预训练的embedding的路径。 :param dtype: 读出的embedding的类型 - :param str padding: the padding tag for vocabulary. - :param str unknown: the unknown tag for vocabulary. + :param str padding: 词表中的padding的token. 并以此用做vocab的padding。 + :param str unknown: 词表中的unknown的token. 并以此用做vocab的unknown。 :param bool normalize: 是否将每个vector归一化到norm为1 :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。这里主要可能出错的地 方在于词表有空行或者词表出现了维度不一致。 - :return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 - :return numpy.ndarray: Vocabulary Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 + :return (numpy.ndarray, Vocabulary): Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 是否使用padding, 以及unknown有没有在词表中找到对应的词。 Vocabulary中的词的顺序与Embedding的顺序是一一对应的。 + """ vocab = Vocabulary(padding=padding, unknown=unknown) vec_dict = {} @@ -134,7 +143,7 @@ class EmbedLoader(BaseLoader): vocab.add_word(word) if unknown is not None and unknown == word: found_unknown = True - if found_pad is not None and padding == word: + if padding is not None and padding == word: found_pad = True except Exception as e: if error == 'ignore': diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py new file mode 100644 index 00000000..11c7ab64 --- /dev/null +++ b/fastNLP/io/file_utils.py @@ -0,0 +1,255 @@ + +import os +from pathlib import Path +from urllib.parse import urlparse +import re +import requests +import tempfile +from tqdm import tqdm +import shutil +import hashlib + + +def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: + """ + 给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 + 将文件放入到 + """ + if cache_dir is None: + dataset_cache = Path(get_defalt_path()) + else: + dataset_cache = cache_dir + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ("http", "https"): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, dataset_cache) + elif parsed.scheme == "" and Path(os.path.join(dataset_cache, url_or_filename)).exists(): + # File, and it exists. + return Path(url_or_filename) + elif parsed.scheme == "": + # File, but it doesn't exist. + raise FileNotFoundError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError( + "unable to parse {} as a URL or as a local path".format(url_or_filename) + ) + +def get_filepath(filepath): + """ + 如果filepath中只有一个文件,则直接返回对应的全路径 + :param filepath: + :return: + """ + if os.path.isdir(filepath): + files = os.listdir(filepath) + if len(files)==1: + return os.path.join(filepath, files[0]) + else: + return filepath + return filepath + +def get_defalt_path(): + """ + 获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 + + :return: + """ + if 'FASTNLP_CACHE_DIR' in os.environ: + fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR') + if os.path.exists(fastnlp_cache_dir): + return fastnlp_cache_dir + raise RuntimeError("Some errors happens on cache directory.") + else: + raise RuntimeError("There function is not available right now.") + fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP")) + return fastnlp_cache_dir + +def _get_base_url(name): + # 返回的URL结尾必须是/ + if 'FASTNLP_BASE_URL' in os.environ: + fastnlp_base_url = os.environ['FASTNLP_BASE_URL'] + return fastnlp_base_url + raise RuntimeError("There function is not available right now.") + +def split_filename_suffix(filepath): + """ + 给定filepath返回对应的name和suffix + :param filepath: + :return: filename, suffix + """ + filename = os.path.basename(filepath) + if filename.endswith('.tar.gz'): + return filename[:-7], '.tar.gz' + return os.path.splitext(filename) + +def get_from_cache(url: str, cache_dir: Path = None) -> Path: + """ + 尝试在cache_dir中寻找url定义的资源; 如果没有找到。则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。 + 如果从url中下载的资源解压后有多个文件,则返回directory的路径; 如果只有一个资源,则返回具体的路径 + + """ + cache_dir.mkdir(parents=True, exist_ok=True) + + filename = re.sub(r".+/", "", url) + dir_name, suffix = split_filename_suffix(filename) + sep_index = dir_name[::-1].index('-') + if sep_index<0: + check_sum = None + else: + check_sum = dir_name[-sep_index+1:] + sep_index = len(dir_name) if sep_index==-1 else -sep_index-1 + dir_name = dir_name[:sep_index] + + # 寻找与它名字匹配的内容, 而不关心后缀 + match_dir_name = match_file(dir_name, cache_dir) + if match_dir_name: + dir_name = match_dir_name + cache_path = cache_dir / dir_name + + # get cache path to put the file + if cache_path.exists(): + return get_filepath(cache_path) + + # make HEAD request to check ETag TODO ETag可以用来判断资源是否已经更新了,之后需要加上 + response = requests.head(url, headers={"User-Agent": "fastNLP"}) + if response.status_code != 200: + raise IOError( + f"HEAD request failed for url {url} with status code {response.status_code}." + ) + + # add ETag to filename if it exists + # etag = response.headers.get("ETag") + + if not cache_path.exists(): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + fd, temp_filename = tempfile.mkstemp() + print("%s not found in cache, downloading to %s"%(url, temp_filename)) + + # GET file object + req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"}) + content_length = req.headers.get("Content-Length") + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + sha256 = hashlib.sha256() + with open(temp_filename, "wb") as temp_file: + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + sha256.update(chunk) + # check sum + digit = sha256.hexdigest()[:8] + if not check_sum: + assert digit == check_sum, "File corrupted when download." + progress.close() + print(f"Finish download from {url}.") + + # 开始解压 + delete_temp_dir = None + if suffix in ('.zip', '.tar.gz'): + uncompress_temp_dir = tempfile.mkdtemp() + delete_temp_dir = uncompress_temp_dir + print(f"Start to uncompress file to {uncompress_temp_dir}.") + if suffix == '.zip': + unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) + else: + untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) + filenames = os.listdir(uncompress_temp_dir) + if len(filenames)==1: + if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): + uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) + + cache_path.mkdir(parents=True, exist_ok=True) + print("Finish un-compressing file.") + else: + uncompress_temp_dir = temp_filename + cache_path = str(cache_path) + suffix + success = False + try: + # 复制到指定的位置 + print(f"Copy file to {cache_path}.") + if os.path.isdir(uncompress_temp_dir): + for filename in os.listdir(uncompress_temp_dir): + shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename) + else: + shutil.copyfile(uncompress_temp_dir, cache_path) + success = True + except Exception as e: + print(e) + raise e + finally: + if not success: + if cache_path.exists(): + if cache_path.is_file(): + os.remove(cache_path) + else: + shutil.rmtree(cache_path) + if delete_temp_dir: + shutil.rmtree(delete_temp_dir) + os.close(fd) + os.remove(temp_filename) + + return get_filepath(cache_path) + +def unzip_file(file: Path, to: Path): + # unpack and write out in CoNLL column-like format + from zipfile import ZipFile + + with ZipFile(file, "r") as zipObj: + # Extract all the contents of zip file in current directory + zipObj.extractall(to) + +def untar_gz_file(file:Path, to:Path): + import tarfile + + with tarfile.open(file, 'r:gz') as tar: + tar.extractall(to) + +def match_file(dir_name:str, cache_dir:str)->str: + """ + 匹配的原则是,在cache_dir下的文件: (1) 与dir_name完全一致; (2) 除了后缀以外和dir_name完全一致。 + 如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串 + + :param dir_name: 需要匹配的名称 + :param cache_dir: 在该目录下找匹配dir_name是否存在 + :return: str + """ + files = os.listdir(cache_dir) + matched_filenames = [] + for file_name in files: + if re.match(dir_name+'$', file_name) or re.match(dir_name+'\\..*', file_name): + matched_filenames.append(file_name) + if len(matched_filenames)==0: + return '' + elif len(matched_filenames)==1: + return matched_filenames[-1] + else: + raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.") + +if __name__ == '__main__': + cache_dir = Path('caches') + cache_dir = None + # 需要对cache_dir进行测试 + base_url = 'http://0.0.0.0:8888/file/download' + # if True: + # for filename in os.listdir(cache_dir): + # if os.path.isdir(os.path.join(cache_dir, filename)): + # shutil.rmtree(os.path.join(cache_dir, filename)) + # else: + # os.remove(os.path.join(cache_dir, filename)) + # 1. 测试.txt文件 + print(cached_path(base_url + '/{}'.format('txt_test-bcb4fe65.txt'), cache_dir)) + # 2. 测试.zip文件(只有一个文件) + print(cached_path(base_url + '/{}'.format('zip_test-40966d39.zip'), cache_dir)) + # 3. 测试.zip文件(有多个文件) + print(cached_path(base_url + '/{}'.format('zip_pack_test-70c0b20d.zip'), cache_dir)) + # 4. 测试.tar.gz文件 + print(cached_path(base_url + '/{}'.format('tar_gz_test-3e2679cf.tar.gz'), cache_dir)) + # 5. 测试.tar.gz多个文件 + print(cached_path(base_url + '/{}'.format('tar_gz_pack_test-08dfdccd.tar.gz'), cache_dir)) + + # 6. 测试.pkl文件 diff --git a/fastNLP/models/cnn_text_classification.py b/fastNLP/models/cnn_text_classification.py index 3a71a80a..081dd510 100644 --- a/fastNLP/models/cnn_text_classification.py +++ b/fastNLP/models/cnn_text_classification.py @@ -7,6 +7,7 @@ import torch.nn as nn from ..core.const import Const as C from ..modules import encoder +from fastNLP import seq_len_to_mask class CNNText(torch.nn.Module): @@ -21,15 +22,13 @@ class CNNText(torch.nn.Module): :param int num_classes: 一共有多少类 :param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 :param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 - :param int padding: 对句子前后的pad的大小, 用0填充。 :param float dropout: Dropout的大小 """ def __init__(self, init_embed, num_classes, - kernel_nums=(3, 4, 5), - kernel_sizes=(3, 4, 5), - padding=0, + kernel_nums=(30, 40, 50), + kernel_sizes=(1, 3, 5), dropout=0.5): super(CNNText, self).__init__() @@ -38,8 +37,7 @@ class CNNText(torch.nn.Module): self.conv_pool = encoder.ConvMaxpool( in_channels=self.embed.embedding_dim, out_channels=kernel_nums, - kernel_sizes=kernel_sizes, - padding=padding) + kernel_sizes=kernel_sizes) self.dropout = nn.Dropout(dropout) self.fc = nn.Linear(sum(kernel_nums), num_classes) @@ -51,7 +49,11 @@ class CNNText(torch.nn.Module): :return output: dict of torch.LongTensor, [batch_size, num_classes] """ x = self.embed(words) # [N,L] -> [N,L,C] - x = self.conv_pool(x) # [N,L,C] -> [N,C] + if seq_len is not None: + mask = seq_len_to_mask(seq_len) + x = self.conv_pool(x, mask) + else: + x = self.conv_pool(x) # [N,L,C] -> [N,C] x = self.dropout(x) x = self.fc(x) # [N,C] -> [N, N_class] return {C.OUTPUT: x} diff --git a/fastNLP/modules/encoder/_bert.py b/fastNLP/modules/encoder/_bert.py new file mode 100644 index 00000000..fc62ea9c --- /dev/null +++ b/fastNLP/modules/encoder/_bert.py @@ -0,0 +1,625 @@ + + + +""" +这个页面的代码很大程度上参考了https://github.com/huggingface/pytorch-pretrained-BERT的代码 +""" + + +import torch +from torch import nn + +from ... import Vocabulary +import collections + +import os +import unicodedata +from ...io.file_utils import _get_base_url, cached_path +from .bert import BertModel +import numpy as np +from itertools import chain + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r", encoding="utf-8") as reader: + while True: + token = reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, + do_lower_case=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + self.never_split = never_split + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = self._clean_text(text) + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in self.never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + + def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BertTokenizer. + + Args: + vocab_file: Path to a one-wordpiece-per-line vocabulary file + do_lower_case: Whether to lower case the input + Only has an effect when do_wordpiece_only=False + do_basic_tokenize: Whether to do basic tokenization before wordpiece. + max_len: An artificial maximum length to truncate tokenized sequences to; + Effective maximum length is always the minimum of this + value (if specified) and the underlying BERT model's + sequence length. + never_split: List of tokens which will never be split during tokenization. + Only has an effect when do_wordpiece_only=False + """ + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " + "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=never_split) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.max_len = max_len if max_len is not None else int(1e12) + + def tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + if len(ids) > self.max_len: + print( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this BERT model ({} > {}). Running this" + " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary to a directory or file.""" + index = 0 + if os.path.isdir(vocab_path): + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + print("Saving vocabulary to {}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!".format(vocab_file)) + index = token_index + writer.write(token + u'\n') + index += 1 + return vocab_file + + @classmethod + def from_pretrained(cls, model_dir, *inputs, **kwargs): + """ + 给定path,直接读取vocab. + + """ + pretrained_model_name_or_path = os.path.join(model_dir, VOCAB_NAME) + print("loading vocabulary file {}".format(pretrained_model_name_or_path)) + max_len = 512 + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + tokenizer = cls(pretrained_model_name_or_path, *inputs, **kwargs) + return tokenizer + +VOCAB_NAME = 'vocab.txt' + +class _WordBertModel(nn.Module): + def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first', include_cls_sep:bool=False): + super().__init__() + + self.tokenzier = BertTokenizer.from_pretrained(model_dir) + self.encoder = BertModel.from_pretrained(model_dir) + # 检查encoder_layer_number是否合理 + encoder_layer_number = len(self.encoder.encoder.layer) + self.layers = list(map(int, layers.split(','))) + for layer in self.layers: + if layer<0: + assert -layer<=encoder_layer_number, f"The layer index:{layer} is out of scope for " \ + f"a bert model with {encoder_layer_number} layers." + else: + assert layer