@@ -184,7 +184,7 @@ def add_r(base_path='../fastNLP'): | |||||
for f in files: | for f in files: | ||||
if f.endswith(".py"): | if f.endswith(".py"): | ||||
check_file_r(os.path.abspath(os.path.join(path,f))) | check_file_r(os.path.abspath(os.path.join(path,f))) | ||||
sys.exit(0) | |||||
# sys.exit(0) | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
fastNLP 由 :mod:`~fastNLP.core` 、 :mod:`~fastNLP.io` 、:mod:`~fastNLP.embeddings` 、 :mod:`~fastNLP.modules`、 | fastNLP 由 :mod:`~fastNLP.core` 、 :mod:`~fastNLP.io` 、:mod:`~fastNLP.embeddings` 、 :mod:`~fastNLP.modules`、 | ||||
:mod:`~fastNLP.models` 等子模块组成,你可以查看每个模块的文档。 | :mod:`~fastNLP.models` 等子模块组成,你可以查看每个模块的文档。 | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fastNLP 包中直接 import。当然你也同样可以从 core 模块的子模块中 import, | core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fastNLP 包中直接 import。当然你也同样可以从 core 模块的子模块中 import, | ||||
例如 :class:`~fastNLP.DataSetIter` 组件有两种 import 的方式:: | 例如 :class:`~fastNLP.DataSetIter` 组件有两种 import 的方式:: | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
Logger 是fastNLP中记录日志的模块,logger封装了logging模块的Logger, | Logger 是fastNLP中记录日志的模块,logger封装了logging模块的Logger, | ||||
具体使用方式与直接使用logging.Logger相同,同时也新增一些简单好用的API | 具体使用方式与直接使用logging.Logger相同,同时也新增一些简单好用的API | ||||
使用方式: | 使用方式: | ||||
@@ -124,11 +124,11 @@ class FastNLPLogger(logging.getLoggerClass()): | |||||
super().__init__(name) | super().__init__(name) | ||||
def add_file(self, path='./log.txt', level='INFO'): | def add_file(self, path='./log.txt', level='INFO'): | ||||
"""add log output file and the output level""" | |||||
r"""add log output file and the output level""" | |||||
_add_file_handler(self, path, level) | _add_file_handler(self, path, level) | ||||
def set_stdout(self, stdout='tqdm', level='INFO'): | def set_stdout(self, stdout='tqdm', level='INFO'): | ||||
"""set stdout format and the output level""" | |||||
r"""set stdout format and the output level""" | |||||
_set_stdout_handler(self, stdout, level) | _set_stdout_handler(self, stdout, level) | ||||
@@ -139,7 +139,7 @@ logging.setLoggerClass(FastNLPLogger) | |||||
# print(logging.getLogger()) | # print(logging.getLogger()) | ||||
def _init_logger(path=None, stdout='tqdm', level='INFO'): | def _init_logger(path=None, stdout='tqdm', level='INFO'): | ||||
"""initialize logger""" | |||||
r"""initialize logger""" | |||||
level = _get_level(level) | level = _get_level(level) | ||||
# logger = logging.getLogger() | # logger = logging.getLogger() | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [] | __all__ = [] | ||||
@@ -74,7 +74,7 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None): | |||||
def _data_parallel_wrapper(func_name, device_ids, output_device): | def _data_parallel_wrapper(func_name, device_ids, output_device): | ||||
""" | |||||
r""" | |||||
这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数 | 这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数 | ||||
:param str, func_name: 对network中的这个函数进行多卡运行 | :param str, func_name: 对network中的这个函数进行多卡运行 | ||||
@@ -95,7 +95,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device): | |||||
def _model_contains_inner_module(model): | def _model_contains_inner_module(model): | ||||
""" | |||||
r""" | |||||
:param nn.Module model: 模型文件,判断是否内部包含model.module, 多用于check模型是否是nn.DataParallel, | :param nn.Module model: 模型文件,判断是否内部包含model.module, 多用于check模型是否是nn.DataParallel, | ||||
nn.parallel.DistributedDataParallel。主要是在做形参匹配的时候需要使用最内部的model的function。 | nn.parallel.DistributedDataParallel。主要是在做形参匹配的时候需要使用最内部的model的function。 | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
batch 模块实现了 fastNLP 所需的 :class:`~fastNLP.core.batch.DataSetIter` 类。 | batch 模块实现了 fastNLP 所需的 :class:`~fastNLP.core.batch.DataSetIter` 类。 | ||||
""" | """ | ||||
@@ -49,7 +49,7 @@ def _pad(batch_dict, dataset, as_numpy): | |||||
class DataSetGetter: | class DataSetGetter: | ||||
""" | |||||
r""" | |||||
传递给torch.utils.data.DataLoader获取数据,DataLoder会传入int的idx获取数据(调用这里的__getitem__()函数)。 | 传递给torch.utils.data.DataLoader获取数据,DataLoder会传入int的idx获取数据(调用这里的__getitem__()函数)。 | ||||
""" | """ | ||||
def __init__(self, dataset: DataSet, as_numpy=False): | def __init__(self, dataset: DataSet, as_numpy=False): | ||||
@@ -70,7 +70,7 @@ class DataSetGetter: | |||||
return len(self.dataset) | return len(self.dataset) | ||||
def collate_fn(self, ins_list: list): | def collate_fn(self, ins_list: list): | ||||
""" | |||||
r""" | |||||
:param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]] | :param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]] | ||||
:return: | :return: | ||||
@@ -104,7 +104,7 @@ class DataSetGetter: | |||||
class SamplerAdapter(torch.utils.data.Sampler): | class SamplerAdapter(torch.utils.data.Sampler): | ||||
""" | |||||
r""" | |||||
用于传入torch.utils.data.DataLoader中,DataLoader会调用__iter__()方法获取index(一次只取一个int) | 用于传入torch.utils.data.DataLoader中,DataLoader会调用__iter__()方法获取index(一次只取一个int) | ||||
""" | """ | ||||
@@ -121,7 +121,7 @@ class SamplerAdapter(torch.utils.data.Sampler): | |||||
class BatchIter: | class BatchIter: | ||||
""" | |||||
r""" | |||||
Trainer用于迭代数据的类。继承该类,并实现get_num_batches(), get_batch_indices(), num_batches(), __iter__()方法以及dataset属性。 | Trainer用于迭代数据的类。继承该类,并实现get_num_batches(), get_batch_indices(), num_batches(), __iter__()方法以及dataset属性。 | ||||
""" | """ | ||||
@@ -166,7 +166,7 @@ class BatchIter: | |||||
@staticmethod | @staticmethod | ||||
def get_num_batches(num_samples, batch_size, drop_last): | def get_num_batches(num_samples, batch_size, drop_last): | ||||
""" | |||||
r""" | |||||
计算batch的数量。用于前端显示进度 | 计算batch的数量。用于前端显示进度 | ||||
:param int num_samples: | :param int num_samples: | ||||
@@ -180,7 +180,7 @@ class BatchIter: | |||||
return num_batches | return num_batches | ||||
def get_batch_indices(self): | def get_batch_indices(self): | ||||
""" | |||||
r""" | |||||
获取最近输出的batch的index。用于溯源当前batch的数据 | 获取最近输出的batch的index。用于溯源当前batch的数据 | ||||
:return: | :return: | ||||
@@ -192,7 +192,7 @@ class BatchIter: | |||||
@property | @property | ||||
def dataset(self): | def dataset(self): | ||||
""" | |||||
r""" | |||||
获取正在参与iterate的dataset | 获取正在参与iterate的dataset | ||||
:return: | :return: | ||||
@@ -201,7 +201,7 @@ class BatchIter: | |||||
@abc.abstractmethod | @abc.abstractmethod | ||||
def __iter__(self): | def __iter__(self): | ||||
""" | |||||
r""" | |||||
用于实际数据循环的类,返回值需要为两个dict, 第一个dict中的内容会认为是input, 第二个dict中的内容会认为是target | 用于实际数据循环的类,返回值需要为两个dict, 第一个dict中的内容会认为是input, 第二个dict中的内容会认为是target | ||||
:return: | :return: | ||||
@@ -210,7 +210,7 @@ class BatchIter: | |||||
class DataSetIter(BatchIter): | class DataSetIter(BatchIter): | ||||
""" | |||||
r""" | |||||
DataSetIter 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出, | DataSetIter 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出, | ||||
组成 `x` 和 `y`:: | 组成 `x` 和 `y`:: | ||||
@@ -223,7 +223,7 @@ class DataSetIter(BatchIter): | |||||
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, | def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, | ||||
num_workers=0, pin_memory=False, drop_last=False, | num_workers=0, pin_memory=False, drop_last=False, | ||||
timeout=0, worker_init_fn=None, collate_fn=None): | timeout=0, worker_init_fn=None, collate_fn=None): | ||||
""" | |||||
r""" | |||||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | :param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | ||||
:param int batch_size: 取出的batch大小 | :param int batch_size: 取出的batch大小 | ||||
@@ -258,7 +258,7 @@ class DataSetIter(BatchIter): | |||||
class TorchLoaderIter(BatchIter): | class TorchLoaderIter(BatchIter): | ||||
""" | |||||
r""" | |||||
与DataSetIter类似,但可以用于非fastNLP的数据容器对象,然后将其传入到Trainer中。 | 与DataSetIter类似,但可以用于非fastNLP的数据容器对象,然后将其传入到Trainer中。 | ||||
只需要保证数据容器实现了实现了以下的方法 | 只需要保证数据容器实现了实现了以下的方法 | ||||
@@ -387,7 +387,7 @@ class TorchLoaderIter(BatchIter): | |||||
def __init__(self, dataset, batch_size=1, sampler=None, | def __init__(self, dataset, batch_size=1, sampler=None, | ||||
num_workers=0, pin_memory=False, drop_last=False, | num_workers=0, pin_memory=False, drop_last=False, | ||||
timeout=0, worker_init_fn=None, collate_fn=None): | timeout=0, worker_init_fn=None, collate_fn=None): | ||||
""" | |||||
r""" | |||||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | :param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | ||||
:param int batch_size: 取出的batch大小 | :param int batch_size: 取出的batch大小 | ||||
@@ -421,7 +421,7 @@ class TorchLoaderIter(BatchIter): | |||||
def _to_tensor(batch, field_dtype): | def _to_tensor(batch, field_dtype): | ||||
""" | |||||
r""" | |||||
:param batch: np.array() | :param batch: np.array() | ||||
:param field_dtype: 数据类型 | :param field_dtype: 数据类型 | ||||
@@ -101,7 +101,7 @@ except: | |||||
class Callback(object): | class Callback(object): | ||||
""" | |||||
r""" | |||||
Callback是fastNLP中被设计用于增强 :class:`~fastNLP.Trainer` 的类。 | Callback是fastNLP中被设计用于增强 :class:`~fastNLP.Trainer` 的类。 | ||||
如果Callback被传递给了 Trainer , 则 Trainer 会在对应的阶段调用Callback的函数, | 如果Callback被传递给了 Trainer , 则 Trainer 会在对应的阶段调用Callback的函数, | ||||
具体调用时机可以通过 :mod:`trainer 模块<fastNLP.core.trainer>` 查看。 | 具体调用时机可以通过 :mod:`trainer 模块<fastNLP.core.trainer>` 查看。 | ||||
@@ -116,60 +116,60 @@ class Callback(object): | |||||
@property | @property | ||||
def trainer(self): | def trainer(self): | ||||
""" | |||||
r""" | |||||
该属性可以通过self.trainer获取到,一般情况下不需要使用这个属性。 | 该属性可以通过self.trainer获取到,一般情况下不需要使用这个属性。 | ||||
""" | """ | ||||
return self._trainer | return self._trainer | ||||
@property | @property | ||||
def step(self): | def step(self): | ||||
"""当前运行到的step, 范围为[1, self.n_steps+1)""" | |||||
r"""当前运行到的step, 范围为[1, self.n_steps+1)""" | |||||
return self._trainer.step | return self._trainer.step | ||||
@property | @property | ||||
def n_steps(self): | def n_steps(self): | ||||
"""Trainer一共会采多少个batch。当Trainer中update_every设置为非1的值时,该值不等于update的次数""" | |||||
r"""Trainer一共会采多少个batch。当Trainer中update_every设置为非1的值时,该值不等于update的次数""" | |||||
return self._trainer.n_steps | return self._trainer.n_steps | ||||
@property | @property | ||||
def batch_size(self): | def batch_size(self): | ||||
"""train和evaluate时的batch_size为多大""" | |||||
r"""train和evaluate时的batch_size为多大""" | |||||
return self._trainer.batch_size | return self._trainer.batch_size | ||||
@property | @property | ||||
def epoch(self): | def epoch(self): | ||||
"""当前运行的epoch数,范围是[1, self.n_epochs+1)""" | |||||
r"""当前运行的epoch数,范围是[1, self.n_epochs+1)""" | |||||
return self._trainer.epoch | return self._trainer.epoch | ||||
@property | @property | ||||
def n_epochs(self): | def n_epochs(self): | ||||
"""一共会运行多少个epoch""" | |||||
r"""一共会运行多少个epoch""" | |||||
return self._trainer.n_epochs | return self._trainer.n_epochs | ||||
@property | @property | ||||
def optimizer(self): | def optimizer(self): | ||||
"""初始化Trainer时传递的Optimizer""" | |||||
r"""初始化Trainer时传递的Optimizer""" | |||||
return self._trainer.optimizer | return self._trainer.optimizer | ||||
@property | @property | ||||
def model(self): | def model(self): | ||||
"""正在被Trainer训练的模型""" | |||||
r"""正在被Trainer训练的模型""" | |||||
return self._trainer.model | return self._trainer.model | ||||
@property | @property | ||||
def pbar(self): | def pbar(self): | ||||
"""如果在Callback中需要打印内容,请使用self.pbar.write(str)。否则可能出现命令行显示效果不太好的问题。在 | |||||
r"""如果在Callback中需要打印内容,请使用self.pbar.write(str)。否则可能出现命令行显示效果不太好的问题。在 | |||||
on_train_begin(), on_train_end(), on_exception()中请不要使用该属性,通过print输出即可。""" | on_train_begin(), on_train_end(), on_exception()中请不要使用该属性,通过print输出即可。""" | ||||
return self._trainer.pbar | return self._trainer.pbar | ||||
@property | @property | ||||
def update_every(self): | def update_every(self): | ||||
"""Trainer中的模型多少次反向传播才进行一次梯度更新,在Trainer初始化时传入的。""" | |||||
r"""Trainer中的模型多少次反向传播才进行一次梯度更新,在Trainer初始化时传入的。""" | |||||
return self._trainer.update_every | return self._trainer.update_every | ||||
@property | @property | ||||
def batch_per_epoch(self): | def batch_per_epoch(self): | ||||
"""每个epoch一共有多少个batch,只有在on_epoch_begin之后才能调用该属性。""" | |||||
r"""每个epoch一共有多少个batch,只有在on_epoch_begin之后才能调用该属性。""" | |||||
return self._trainer.batch_per_epoch | return self._trainer.batch_per_epoch | ||||
@property | @property | ||||
@@ -185,7 +185,7 @@ class Callback(object): | |||||
return getattr(self._trainer, 'logger', logger) | return getattr(self._trainer, 'logger', logger) | ||||
def on_train_begin(self): | def on_train_begin(self): | ||||
""" | |||||
r""" | |||||
在Train过程开始之前调用。 | 在Train过程开始之前调用。 | ||||
:return: | :return: | ||||
@@ -193,7 +193,7 @@ class Callback(object): | |||||
pass | pass | ||||
def on_epoch_begin(self): | def on_epoch_begin(self): | ||||
""" | |||||
r""" | |||||
在每个epoch开始之前调用一次 | 在每个epoch开始之前调用一次 | ||||
:return: | :return: | ||||
@@ -201,7 +201,7 @@ class Callback(object): | |||||
pass | pass | ||||
def on_batch_begin(self, batch_x, batch_y, indices): | def on_batch_begin(self, batch_x, batch_y, indices): | ||||
""" | |||||
r""" | |||||
每次采集到一个batch的数据则调用一次。这里对batch_x或batch_y删除添加内容是可以影响到Trainer中内容的。所以在这一步 | 每次采集到一个batch的数据则调用一次。这里对batch_x或batch_y删除添加内容是可以影响到Trainer中内容的。所以在这一步 | ||||
可以进行一些负采样之类的操作 | 可以进行一些负采样之类的操作 | ||||
@@ -214,7 +214,7 @@ class Callback(object): | |||||
pass | pass | ||||
def on_loss_begin(self, batch_y, predict_y): | def on_loss_begin(self, batch_y, predict_y): | ||||
""" | |||||
r""" | |||||
在计算loss前调用,即这里修改batch_y或predict_y的值是可以影响到loss计算的。 | 在计算loss前调用,即这里修改batch_y或predict_y的值是可以影响到loss计算的。 | ||||
:param dict batch_y: 在DataSet中被设置为target的field的batch集合。 | :param dict batch_y: 在DataSet中被设置为target的field的batch集合。 | ||||
@@ -224,7 +224,7 @@ class Callback(object): | |||||
pass | pass | ||||
def on_backward_begin(self, loss): | def on_backward_begin(self, loss): | ||||
""" | |||||
r""" | |||||
在loss得到之后,但在反向传播之前。可能可以进行loss是否为NaN的检查。 | 在loss得到之后,但在反向传播之前。可能可以进行loss是否为NaN的检查。 | ||||
:param torch.Tensor loss: 计算得到的loss值 | :param torch.Tensor loss: 计算得到的loss值 | ||||
@@ -233,7 +233,7 @@ class Callback(object): | |||||
pass | pass | ||||
def on_backward_end(self): | def on_backward_end(self): | ||||
""" | |||||
r""" | |||||
反向梯度传播已完成,但由于update_every的设置,可能并不是每一次调用都有梯度。到这一步,还没有更新参数。 | 反向梯度传播已完成,但由于update_every的设置,可能并不是每一次调用都有梯度。到这一步,还没有更新参数。 | ||||
:return: | :return: | ||||
@@ -241,7 +241,7 @@ class Callback(object): | |||||
pass | pass | ||||
def on_step_end(self): | def on_step_end(self): | ||||
""" | |||||
r""" | |||||
到这里模型的参数已经按照梯度更新。但可能受update_every影响,并不是每次都更新了。 | 到这里模型的参数已经按照梯度更新。但可能受update_every影响,并不是每次都更新了。 | ||||
:return: | :return: | ||||
@@ -249,14 +249,14 @@ class Callback(object): | |||||
pass | pass | ||||
def on_batch_end(self): | def on_batch_end(self): | ||||
""" | |||||
r""" | |||||
这一步与on_step_end是紧接着的。只是为了对称性加上了这一步。 | 这一步与on_step_end是紧接着的。只是为了对称性加上了这一步。 | ||||
""" | """ | ||||
pass | pass | ||||
def on_valid_begin(self): | def on_valid_begin(self): | ||||
""" | |||||
r""" | |||||
如果Trainer中设置了验证,则发生验证前会调用该函数 | 如果Trainer中设置了验证,则发生验证前会调用该函数 | ||||
:return: | :return: | ||||
@@ -264,7 +264,7 @@ class Callback(object): | |||||
pass | pass | ||||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | ||||
""" | |||||
r""" | |||||
每次执行验证集的evaluation后会调用。 | 每次执行验证集的evaluation后会调用。 | ||||
:param Dict[str: Dict[str: float]] eval_result: , evaluation的结果。一个例子为{'AccuracyMetric':{'acc':1.0}},即 | :param Dict[str: Dict[str: float]] eval_result: , evaluation的结果。一个例子为{'AccuracyMetric':{'acc':1.0}},即 | ||||
@@ -277,19 +277,19 @@ class Callback(object): | |||||
pass | pass | ||||
def on_epoch_end(self): | def on_epoch_end(self): | ||||
""" | |||||
r""" | |||||
每个epoch结束将会调用该方法 | 每个epoch结束将会调用该方法 | ||||
""" | """ | ||||
pass | pass | ||||
def on_train_end(self): | def on_train_end(self): | ||||
""" | |||||
r""" | |||||
训练结束,调用该方法 | 训练结束,调用该方法 | ||||
""" | """ | ||||
pass | pass | ||||
def on_exception(self, exception): | def on_exception(self, exception): | ||||
""" | |||||
r""" | |||||
当训练过程出现异常,会触发该方法 | 当训练过程出现异常,会触发该方法 | ||||
:param exception: 某种类型的Exception,比如KeyboardInterrupt等 | :param exception: 某种类型的Exception,比如KeyboardInterrupt等 | ||||
""" | """ | ||||
@@ -297,7 +297,7 @@ class Callback(object): | |||||
def _transfer(func): | def _transfer(func): | ||||
"""装饰器,将对CallbackManager的调用转发到各个Callback子类. | |||||
r"""装饰器,将对CallbackManager的调用转发到各个Callback子类. | |||||
:param func: | :param func: | ||||
:return: | :return: | ||||
@@ -315,11 +315,11 @@ def _transfer(func): | |||||
class CallbackManager(Callback): | class CallbackManager(Callback): | ||||
""" | |||||
r""" | |||||
内部使用的Callback管理类 | 内部使用的Callback管理类 | ||||
""" | """ | ||||
def __init__(self, env, callbacks=None): | def __init__(self, env, callbacks=None): | ||||
""" | |||||
r""" | |||||
:param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | ||||
:param List[Callback] callbacks: | :param List[Callback] callbacks: | ||||
@@ -433,12 +433,12 @@ class DistCallbackManager(CallbackManager): | |||||
class GradientClipCallback(Callback): | class GradientClipCallback(Callback): | ||||
""" | |||||
r""" | |||||
每次backward前,将parameter的gradient clip到某个范围。 | 每次backward前,将parameter的gradient clip到某个范围。 | ||||
""" | """ | ||||
def __init__(self, parameters=None, clip_value=1, clip_type='norm'): | def __init__(self, parameters=None, clip_value=1, clip_type='norm'): | ||||
""" | |||||
r""" | |||||
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 | :param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 | ||||
如果为None则默认对Trainer的model中所有参数进行clip | 如果为None则默认对Trainer的model中所有参数进行clip | ||||
@@ -477,12 +477,12 @@ class GradientClipCallback(Callback): | |||||
class EarlyStopCallback(Callback): | class EarlyStopCallback(Callback): | ||||
""" | |||||
r""" | |||||
多少个epoch没有变好就停止训练,相关类 :class:`~fastNLP.core.callback.EarlyStopError` | 多少个epoch没有变好就停止训练,相关类 :class:`~fastNLP.core.callback.EarlyStopError` | ||||
""" | """ | ||||
def __init__(self, patience): | def __init__(self, patience): | ||||
""" | |||||
r""" | |||||
:param int patience: epoch的数量 | :param int patience: epoch的数量 | ||||
""" | """ | ||||
@@ -508,7 +508,7 @@ class EarlyStopCallback(Callback): | |||||
class FitlogCallback(Callback): | class FitlogCallback(Callback): | ||||
""" | |||||
r""" | |||||
该callback可将loss和progress写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 | 该callback可将loss和progress写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 | ||||
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 | 一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 | ||||
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 | 并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 | ||||
@@ -516,7 +516,7 @@ class FitlogCallback(Callback): | |||||
""" | """ | ||||
def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False): | def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False): | ||||
""" | |||||
r""" | |||||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要 | :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要 | ||||
传入多个DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。data的结果的名称以'data'开头。 | 传入多个DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。data的结果的名称以'data'开头。 | ||||
@@ -608,13 +608,13 @@ class FitlogCallback(Callback): | |||||
class EvaluateCallback(Callback): | class EvaluateCallback(Callback): | ||||
""" | |||||
r""" | |||||
通过使用该Callback可以使得Trainer在evaluate dev之外还可以evaluate其它数据集,比如测试集。每一次验证dev之前都会先验证EvaluateCallback | 通过使用该Callback可以使得Trainer在evaluate dev之外还可以evaluate其它数据集,比如测试集。每一次验证dev之前都会先验证EvaluateCallback | ||||
中的数据。 | 中的数据。 | ||||
""" | """ | ||||
def __init__(self, data=None, tester=None): | def __init__(self, data=None, tester=None): | ||||
""" | |||||
r""" | |||||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用Trainer中的metric对数据进行验证。如果需要传入多个 | :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用Trainer中的metric对数据进行验证。如果需要传入多个 | ||||
DataSet请通过dict的方式传入。 | DataSet请通过dict的方式传入。 | ||||
:param ~fastNLP.Tester,Dict[~fastNLP.DataSet] tester: Tester对象, 通过使用Tester对象,可以使得验证的metric与Trainer中 | :param ~fastNLP.Tester,Dict[~fastNLP.DataSet] tester: Tester对象, 通过使用Tester对象,可以使得验证的metric与Trainer中 | ||||
@@ -668,12 +668,12 @@ class EvaluateCallback(Callback): | |||||
raise e | raise e | ||||
class LRScheduler(Callback): | class LRScheduler(Callback): | ||||
""" | |||||
r""" | |||||
对PyTorch LR Scheduler的包装以使得其可以被Trainer所使用 | 对PyTorch LR Scheduler的包装以使得其可以被Trainer所使用 | ||||
""" | """ | ||||
def __init__(self, lr_scheduler): | def __init__(self, lr_scheduler): | ||||
""" | |||||
r""" | |||||
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler | :param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler | ||||
""" | """ | ||||
super(LRScheduler, self).__init__() | super(LRScheduler, self).__init__() | ||||
@@ -688,12 +688,12 @@ class LRScheduler(Callback): | |||||
class ControlC(Callback): | class ControlC(Callback): | ||||
""" | |||||
r""" | |||||
检测到 control+C 时的反馈 | 检测到 control+C 时的反馈 | ||||
""" | """ | ||||
def __init__(self, quit_all): | def __init__(self, quit_all): | ||||
""" | |||||
r""" | |||||
:param bool quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer | :param bool quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer | ||||
""" | """ | ||||
super(ControlC, self).__init__() | super(ControlC, self).__init__() | ||||
@@ -713,26 +713,26 @@ class ControlC(Callback): | |||||
class SmoothValue(object): | class SmoothValue(object): | ||||
"""work for LRFinder""" | |||||
r"""work for LRFinder""" | |||||
def __init__(self, beta: float): | def __init__(self, beta: float): | ||||
self.beta, self.n, self.mov_avg = beta, 0, 0 | self.beta, self.n, self.mov_avg = beta, 0, 0 | ||||
self.smooth = None | self.smooth = None | ||||
def add_value(self, val: float) -> None: | def add_value(self, val: float) -> None: | ||||
"""Add `val` to calculate updated smoothed value.""" | |||||
r"""Add `val` to calculate updated smoothed value.""" | |||||
self.n += 1 | self.n += 1 | ||||
self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val | self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val | ||||
self.smooth = self.mov_avg / (1 - self.beta ** self.n) | self.smooth = self.mov_avg / (1 - self.beta ** self.n) | ||||
class LRFinder(Callback): | class LRFinder(Callback): | ||||
""" | |||||
r""" | |||||
用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 | 用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 | ||||
""" | """ | ||||
def __init__(self, start_lr=1e-6, end_lr=10): | def __init__(self, start_lr=1e-6, end_lr=10): | ||||
""" | |||||
r""" | |||||
:param float start_lr: 学习率下界 | :param float start_lr: 学习率下界 | ||||
:param float end_lr: 学习率上界 | :param float end_lr: 学习率上界 | ||||
@@ -798,7 +798,7 @@ class LRFinder(Callback): | |||||
class TensorboardCallback(Callback): | class TensorboardCallback(Callback): | ||||
""" | |||||
r""" | |||||
接受以下一个或多个字符串作为参数: | 接受以下一个或多个字符串作为参数: | ||||
- "model" | - "model" | ||||
- "loss" | - "loss" | ||||
@@ -873,7 +873,7 @@ class TensorboardCallback(Callback): | |||||
class CheckPointCallback(Callback): | class CheckPointCallback(Callback): | ||||
def __init__(self, save_path, delete_when_train_finish=True, recovery_fitlog=True): | def __init__(self, save_path, delete_when_train_finish=True, recovery_fitlog=True): | ||||
""" | |||||
r""" | |||||
用于在每个epoch结束的时候保存一下当前的Trainer状态,可以用于恢复之前的运行。使用最近的一个epoch继续训练 | 用于在每个epoch结束的时候保存一下当前的Trainer状态,可以用于恢复之前的运行。使用最近的一个epoch继续训练 | ||||
一段示例代码 | 一段示例代码 | ||||
Example1:: | Example1:: | ||||
@@ -918,7 +918,7 @@ class CheckPointCallback(Callback): | |||||
logger.error("Fail to recovery the fitlog states.") | logger.error("Fail to recovery the fitlog states.") | ||||
def on_train_begin(self): | def on_train_begin(self): | ||||
""" | |||||
r""" | |||||
当train开始时,且需要恢复上次训练时,会做以下的操作 | 当train开始时,且需要恢复上次训练时,会做以下的操作 | ||||
(1) 重新加载model权重 | (1) 重新加载model权重 | ||||
(2) 重新加载optimizer的状态 | (2) 重新加载optimizer的状态 | ||||
@@ -944,7 +944,7 @@ class CheckPointCallback(Callback): | |||||
self.trainer.best_metric_indicator = states['best_metric_indicator'] | self.trainer.best_metric_indicator = states['best_metric_indicator'] | ||||
def on_epoch_end(self): | def on_epoch_end(self): | ||||
""" | |||||
r""" | |||||
保存状态,使得结果可以被恢复 | 保存状态,使得结果可以被恢复 | ||||
:param self: | :param self: | ||||
@@ -984,11 +984,11 @@ class CheckPointCallback(Callback): | |||||
class WarmupCallback(Callback): | class WarmupCallback(Callback): | ||||
""" | |||||
r""" | |||||
learning rate按照一定的速率从0上升到设置的learning rate。 | learning rate按照一定的速率从0上升到设置的learning rate。 | ||||
""" | """ | ||||
def __init__(self, warmup=0.1, schedule='constant'): | def __init__(self, warmup=0.1, schedule='constant'): | ||||
""" | |||||
r""" | |||||
:param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, | :param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, | ||||
如0.1, 则前10%的step是按照schedule策略调整learning rate。 | 如0.1, 则前10%的step是按照schedule策略调整learning rate。 | ||||
@@ -1035,7 +1035,7 @@ class WarmupCallback(Callback): | |||||
class SaveModelCallback(Callback): | class SaveModelCallback(Callback): | ||||
""" | |||||
r""" | |||||
由于Trainer在训练过程中只会保存最佳的模型, 该callback可实现多种方式的结果存储。 | 由于Trainer在训练过程中只会保存最佳的模型, 该callback可实现多种方式的结果存储。 | ||||
会根据训练开始的时间戳在save_dir下建立文件夹,再在文件夹下存放多个模型:: | 会根据训练开始的时间戳在save_dir下建立文件夹,再在文件夹下存放多个模型:: | ||||
@@ -1047,7 +1047,7 @@ class SaveModelCallback(Callback): | |||||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能 | -epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能 | ||||
""" | """ | ||||
def __init__(self, save_dir, top=3, only_param=False, save_on_exception=False): | def __init__(self, save_dir, top=3, only_param=False, save_on_exception=False): | ||||
""" | |||||
r""" | |||||
:param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型。如果save_dir不存在将自动创建 | :param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型。如果save_dir不存在将自动创建 | ||||
:param int top: 保存dev表现top多少模型。-1为保存所有模型。 | :param int top: 保存dev表现top多少模型。-1为保存所有模型。 | ||||
@@ -1116,12 +1116,12 @@ class SaveModelCallback(Callback): | |||||
class CallbackException(BaseException): | class CallbackException(BaseException): | ||||
""" | |||||
r""" | |||||
当需要通过callback跳出训练的时候可以通过抛出CallbackException并在on_exception中捕获这个值。 | 当需要通过callback跳出训练的时候可以通过抛出CallbackException并在on_exception中捕获这个值。 | ||||
""" | """ | ||||
def __init__(self, msg): | def __init__(self, msg): | ||||
""" | |||||
r""" | |||||
:param str msg: Exception的信息。 | :param str msg: Exception的信息。 | ||||
""" | """ | ||||
@@ -1129,7 +1129,7 @@ class CallbackException(BaseException): | |||||
class EarlyStopError(CallbackException): | class EarlyStopError(CallbackException): | ||||
""" | |||||
r""" | |||||
用于EarlyStop时从Trainer训练循环中跳出。 | 用于EarlyStop时从Trainer训练循环中跳出。 | ||||
""" | """ | ||||
@@ -1139,7 +1139,7 @@ class EarlyStopError(CallbackException): | |||||
class EchoCallback(Callback): | class EchoCallback(Callback): | ||||
""" | |||||
r""" | |||||
用于测试分布式训练 | 用于测试分布式训练 | ||||
""" | """ | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
from builtins import sorted | from builtins import sorted | ||||
import torch | import torch | ||||
@@ -37,7 +37,7 @@ def batching(samples, max_len=0, padding_val=0): | |||||
class Collector: | class Collector: | ||||
""" | |||||
r""" | |||||
辅助DataSet管理collect_fn的类 | 辅助DataSet管理collect_fn的类 | ||||
""" | """ | ||||
@@ -45,7 +45,7 @@ class Collector: | |||||
self.collect_fns = {} | self.collect_fns = {} | ||||
def add_fn(self, fn, name=None): | def add_fn(self, fn, name=None): | ||||
""" | |||||
r""" | |||||
向collector新增一个collect_fn函数 | 向collector新增一个collect_fn函数 | ||||
:param callable fn: | :param callable fn: | ||||
@@ -59,7 +59,7 @@ class Collector: | |||||
self.collect_fns[name] = fn | self.collect_fns[name] = fn | ||||
def is_empty(self): | def is_empty(self): | ||||
""" | |||||
r""" | |||||
返回是否包含collect_fn | 返回是否包含collect_fn | ||||
:return: | :return: | ||||
@@ -67,7 +67,7 @@ class Collector: | |||||
return len(self.collect_fns)==0 | return len(self.collect_fns)==0 | ||||
def delete_fn(self, name=None): | def delete_fn(self, name=None): | ||||
""" | |||||
r""" | |||||
删除collect_fn | 删除collect_fn | ||||
:param str,int name: 如果为None就删除最近加入的collect_fn | :param str,int name: 如果为None就删除最近加入的collect_fn | ||||
@@ -100,7 +100,7 @@ class Collector: | |||||
class ConcatCollectFn: | class ConcatCollectFn: | ||||
""" | |||||
r""" | |||||
field拼接collect_fn,将不同field按序拼接后,padding产生数据。 | field拼接collect_fn,将不同field按序拼接后,padding产生数据。 | ||||
:param List[str] inputs: 将哪些field的数据拼接起来, 目前仅支持1d的field | :param List[str] inputs: 将哪些field的数据拼接起来, 目前仅支持1d的field | ||||
@@ -8,7 +8,7 @@ __all__ = [ | |||||
class Const: | class Const: | ||||
""" | |||||
r""" | |||||
fastNLP中field命名常量。 | fastNLP中field命名常量。 | ||||
.. todo:: | .. todo:: | ||||
@@ -37,48 +37,48 @@ class Const: | |||||
@staticmethod | @staticmethod | ||||
def INPUTS(i): | def INPUTS(i): | ||||
"""得到第 i 个 ``INPUT`` 的命名""" | |||||
r"""得到第 i 个 ``INPUT`` 的命名""" | |||||
i = int(i) + 1 | i = int(i) + 1 | ||||
return Const.INPUT + str(i) | return Const.INPUT + str(i) | ||||
@staticmethod | @staticmethod | ||||
def CHAR_INPUTS(i): | def CHAR_INPUTS(i): | ||||
"""得到第 i 个 ``CHAR_INPUT`` 的命名""" | |||||
r"""得到第 i 个 ``CHAR_INPUT`` 的命名""" | |||||
i = int(i) + 1 | i = int(i) + 1 | ||||
return Const.CHAR_INPUT + str(i) | return Const.CHAR_INPUT + str(i) | ||||
@staticmethod | @staticmethod | ||||
def RAW_WORDS(i): | def RAW_WORDS(i): | ||||
"""得到第 i 个 ``RAW_WORDS`` 的命名""" | |||||
r"""得到第 i 个 ``RAW_WORDS`` 的命名""" | |||||
i = int(i) + 1 | i = int(i) + 1 | ||||
return Const.RAW_WORD + str(i) | return Const.RAW_WORD + str(i) | ||||
@staticmethod | @staticmethod | ||||
def RAW_CHARS(i): | def RAW_CHARS(i): | ||||
"""得到第 i 个 ``RAW_CHARS`` 的命名""" | |||||
r"""得到第 i 个 ``RAW_CHARS`` 的命名""" | |||||
i = int(i) + 1 | i = int(i) + 1 | ||||
return Const.RAW_CHAR + str(i) | return Const.RAW_CHAR + str(i) | ||||
@staticmethod | @staticmethod | ||||
def INPUT_LENS(i): | def INPUT_LENS(i): | ||||
"""得到第 i 个 ``INPUT_LEN`` 的命名""" | |||||
r"""得到第 i 个 ``INPUT_LEN`` 的命名""" | |||||
i = int(i) + 1 | i = int(i) + 1 | ||||
return Const.INPUT_LEN + str(i) | return Const.INPUT_LEN + str(i) | ||||
@staticmethod | @staticmethod | ||||
def OUTPUTS(i): | def OUTPUTS(i): | ||||
"""得到第 i 个 ``OUTPUT`` 的命名""" | |||||
r"""得到第 i 个 ``OUTPUT`` 的命名""" | |||||
i = int(i) + 1 | i = int(i) + 1 | ||||
return Const.OUTPUT + str(i) | return Const.OUTPUT + str(i) | ||||
@staticmethod | @staticmethod | ||||
def TARGETS(i): | def TARGETS(i): | ||||
"""得到第 i 个 ``TARGET`` 的命名""" | |||||
r"""得到第 i 个 ``TARGET`` 的命名""" | |||||
i = int(i) + 1 | i = int(i) + 1 | ||||
return Const.TARGET + str(i) | return Const.TARGET + str(i) | ||||
@staticmethod | @staticmethod | ||||
def LOSSES(i): | def LOSSES(i): | ||||
"""得到第 i 个 ``LOSS`` 的命名""" | |||||
r"""得到第 i 个 ``LOSS`` 的命名""" | |||||
i = int(i) + 1 | i = int(i) + 1 | ||||
return Const.LOSS + str(i) | return Const.LOSS + str(i) |
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
:class:`~fastNLP.core.dataset.DataSet` 是fastNLP中用于承载数据的容器。可以将DataSet看做是一个表格, | :class:`~fastNLP.core.dataset.DataSet` 是fastNLP中用于承载数据的容器。可以将DataSet看做是一个表格, | ||||
每一行是一个sample (在fastNLP中被称为 :mod:`~fastNLP.core.instance` ), | 每一行是一个sample (在fastNLP中被称为 :mod:`~fastNLP.core.instance` ), | ||||
每一列是一个feature (在fastNLP中称为 :mod:`~fastNLP.core.field` )。 | 每一列是一个feature (在fastNLP中称为 :mod:`~fastNLP.core.field` )。 | ||||
@@ -380,12 +380,12 @@ class ApplyResultException(Exception): | |||||
self.index = index # 标示在哪个数据遭遇到问题了 | self.index = index # 标示在哪个数据遭遇到问题了 | ||||
class DataSet(object): | class DataSet(object): | ||||
""" | |||||
r""" | |||||
fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset` | fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset` | ||||
""" | """ | ||||
def __init__(self, data=None): | def __init__(self, data=None): | ||||
""" | |||||
r""" | |||||
:param data: 如果为dict类型,则每个key的value应该为等长的list; 如果为list, | :param data: 如果为dict类型,则每个key的value应该为等长的list; 如果为list, | ||||
每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。 | 每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。 | ||||
@@ -447,7 +447,7 @@ class DataSet(object): | |||||
return inner_iter_func() | return inner_iter_func() | ||||
def __getitem__(self, idx): | def __getitem__(self, idx): | ||||
"""给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。 | |||||
r"""给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。 | |||||
:param idx: can be int or slice. | :param idx: can be int or slice. | ||||
:return: If `idx` is int, return an Instance object. | :return: If `idx` is int, return an Instance object. | ||||
@@ -495,7 +495,7 @@ class DataSet(object): | |||||
return self.__dict__ | return self.__dict__ | ||||
def __len__(self): | def __len__(self): | ||||
"""Fetch the length of the dataset. | |||||
r"""Fetch the length of the dataset. | |||||
:return length: | :return length: | ||||
""" | """ | ||||
@@ -508,7 +508,7 @@ class DataSet(object): | |||||
return str(pretty_table_printer(self)) | return str(pretty_table_printer(self)) | ||||
def print_field_meta(self): | def print_field_meta(self): | ||||
""" | |||||
r""" | |||||
输出当前field的meta信息, 形似下列的输出:: | 输出当前field的meta信息, 形似下列的输出:: | ||||
+-------------+-------+-------+ | +-------------+-------+-------+ | ||||
@@ -564,7 +564,7 @@ class DataSet(object): | |||||
return table | return table | ||||
def append(self, instance): | def append(self, instance): | ||||
""" | |||||
r""" | |||||
将一个instance对象append到DataSet后面。 | 将一个instance对象append到DataSet后面。 | ||||
:param ~fastNLP.Instance instance: 若DataSet不为空,则instance应该拥有和DataSet完全一样的field。 | :param ~fastNLP.Instance instance: 若DataSet不为空,则instance应该拥有和DataSet完全一样的field。 | ||||
@@ -589,7 +589,7 @@ class DataSet(object): | |||||
raise e | raise e | ||||
def add_fieldarray(self, field_name, fieldarray): | def add_fieldarray(self, field_name, fieldarray): | ||||
""" | |||||
r""" | |||||
将fieldarray添加到DataSet中. | 将fieldarray添加到DataSet中. | ||||
:param str field_name: 新加入的field的名称 | :param str field_name: 新加入的field的名称 | ||||
@@ -604,7 +604,7 @@ class DataSet(object): | |||||
self.field_arrays[field_name] = fieldarray | self.field_arrays[field_name] = fieldarray | ||||
def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False): | def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False): | ||||
""" | |||||
r""" | |||||
新增一个field | 新增一个field | ||||
:param str field_name: 新增的field的名称 | :param str field_name: 新增的field的名称 | ||||
@@ -623,7 +623,7 @@ class DataSet(object): | |||||
padder=padder, ignore_type=ignore_type) | padder=padder, ignore_type=ignore_type) | ||||
def delete_instance(self, index): | def delete_instance(self, index): | ||||
""" | |||||
r""" | |||||
删除第index个instance | 删除第index个instance | ||||
:param int index: 需要删除的instance的index,序号从0开始。 | :param int index: 需要删除的instance的index,序号从0开始。 | ||||
@@ -639,7 +639,7 @@ class DataSet(object): | |||||
return self | return self | ||||
def delete_field(self, field_name): | def delete_field(self, field_name): | ||||
""" | |||||
r""" | |||||
删除名为field_name的field | 删除名为field_name的field | ||||
:param str field_name: 需要删除的field的名称. | :param str field_name: 需要删除的field的名称. | ||||
@@ -648,7 +648,7 @@ class DataSet(object): | |||||
return self | return self | ||||
def copy_field(self, field_name, new_field_name): | def copy_field(self, field_name, new_field_name): | ||||
""" | |||||
r""" | |||||
深度copy名为field_name的field到new_field_name | 深度copy名为field_name的field到new_field_name | ||||
:param str field_name: 需要copy的field。 | :param str field_name: 需要copy的field。 | ||||
@@ -662,7 +662,7 @@ class DataSet(object): | |||||
return self | return self | ||||
def has_field(self, field_name): | def has_field(self, field_name): | ||||
""" | |||||
r""" | |||||
判断DataSet中是否有名为field_name这个field | 判断DataSet中是否有名为field_name这个field | ||||
:param str field_name: field的名称 | :param str field_name: field的名称 | ||||
@@ -673,7 +673,7 @@ class DataSet(object): | |||||
return False | return False | ||||
def get_field(self, field_name): | def get_field(self, field_name): | ||||
""" | |||||
r""" | |||||
获取field_name这个field | 获取field_name这个field | ||||
:param str field_name: field的名称 | :param str field_name: field的名称 | ||||
@@ -684,7 +684,7 @@ class DataSet(object): | |||||
return self.field_arrays[field_name] | return self.field_arrays[field_name] | ||||
def get_all_fields(self): | def get_all_fields(self): | ||||
""" | |||||
r""" | |||||
返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray` | 返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray` | ||||
:return dict: 返回如上所述的字典 | :return dict: 返回如上所述的字典 | ||||
@@ -692,7 +692,7 @@ class DataSet(object): | |||||
return self.field_arrays | return self.field_arrays | ||||
def get_field_names(self) -> list: | def get_field_names(self) -> list: | ||||
""" | |||||
r""" | |||||
返回一个list,包含所有 field 的名字 | 返回一个list,包含所有 field 的名字 | ||||
:return list: 返回如上所述的列表 | :return list: 返回如上所述的列表 | ||||
@@ -700,7 +700,7 @@ class DataSet(object): | |||||
return sorted(self.field_arrays.keys()) | return sorted(self.field_arrays.keys()) | ||||
def get_length(self): | def get_length(self): | ||||
""" | |||||
r""" | |||||
获取DataSet的元素数量 | 获取DataSet的元素数量 | ||||
:return: int: DataSet中Instance的个数。 | :return: int: DataSet中Instance的个数。 | ||||
@@ -708,7 +708,7 @@ class DataSet(object): | |||||
return len(self) | return len(self) | ||||
def rename_field(self, field_name, new_field_name): | def rename_field(self, field_name, new_field_name): | ||||
""" | |||||
r""" | |||||
将某个field重新命名. | 将某个field重新命名. | ||||
:param str field_name: 原来的field名称。 | :param str field_name: 原来的field名称。 | ||||
@@ -722,7 +722,7 @@ class DataSet(object): | |||||
return self | return self | ||||
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | ||||
""" | |||||
r""" | |||||
将field_names的field设置为target | 将field_names的field设置为target | ||||
Example:: | Example:: | ||||
@@ -749,7 +749,7 @@ class DataSet(object): | |||||
return self | return self | ||||
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | ||||
""" | |||||
r""" | |||||
将field_names的field设置为input:: | 将field_names的field设置为input:: | ||||
dataset.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True | dataset.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True | ||||
@@ -773,7 +773,7 @@ class DataSet(object): | |||||
return self | return self | ||||
def set_ignore_type(self, *field_names, flag=True): | def set_ignore_type(self, *field_names, flag=True): | ||||
""" | |||||
r""" | |||||
将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查, | 将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查, | ||||
默认情况下也不进行pad。如果仍需要pad该field,可通过自定义Padder实现,若该field需要转换为tensor,需要在padder | 默认情况下也不进行pad。如果仍需要pad该field,可通过自定义Padder实现,若该field需要转换为tensor,需要在padder | ||||
中转换,但不需要在padder中移动到gpu。 | 中转换,但不需要在padder中移动到gpu。 | ||||
@@ -791,7 +791,7 @@ class DataSet(object): | |||||
return self | return self | ||||
def set_padder(self, field_name, padder): | def set_padder(self, field_name, padder): | ||||
""" | |||||
r""" | |||||
为field_name设置padder:: | 为field_name设置padder:: | ||||
from fastNLP import EngChar2DPadder | from fastNLP import EngChar2DPadder | ||||
@@ -807,7 +807,7 @@ class DataSet(object): | |||||
return self | return self | ||||
def set_pad_val(self, field_name, pad_val): | def set_pad_val(self, field_name, pad_val): | ||||
""" | |||||
r""" | |||||
为某个field设置对应的pad_val. | 为某个field设置对应的pad_val. | ||||
:param str field_name: 修改该field的pad_val | :param str field_name: 修改该field的pad_val | ||||
@@ -819,7 +819,7 @@ class DataSet(object): | |||||
return self | return self | ||||
def get_input_name(self): | def get_input_name(self): | ||||
""" | |||||
r""" | |||||
返回所有is_input被设置为True的field名称 | 返回所有is_input被设置为True的field名称 | ||||
:return list: 里面的元素为被设置为input的field名称 | :return list: 里面的元素为被设置为input的field名称 | ||||
@@ -827,7 +827,7 @@ class DataSet(object): | |||||
return [name for name, field in self.field_arrays.items() if field.is_input] | return [name for name, field in self.field_arrays.items() if field.is_input] | ||||
def get_target_name(self): | def get_target_name(self): | ||||
""" | |||||
r""" | |||||
返回所有is_target被设置为True的field名称 | 返回所有is_target被设置为True的field名称 | ||||
:return list: 里面的元素为被设置为target的field名称 | :return list: 里面的元素为被设置为target的field名称 | ||||
@@ -835,7 +835,7 @@ class DataSet(object): | |||||
return [name for name, field in self.field_arrays.items() if field.is_target] | return [name for name, field in self.field_arrays.items() if field.is_target] | ||||
def apply_field(self, func, field_name, new_field_name=None, **kwargs): | def apply_field(self, func, field_name, new_field_name=None, **kwargs): | ||||
""" | |||||
r""" | |||||
将DataSet中的每个instance中的名为 `field_name` 的field传给func,并获取它的返回值。 | 将DataSet中的每个instance中的名为 `field_name` 的field传给func,并获取它的返回值。 | ||||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | :param callable func: input是instance中名为 `field_name` 的field的内容。 | ||||
@@ -858,7 +858,7 @@ class DataSet(object): | |||||
return self.apply(func, new_field_name, _apply_field=field_name, **kwargs) | return self.apply(func, new_field_name, _apply_field=field_name, **kwargs) | ||||
def apply_field_more(self, func, field_name, modify_fields=True, **kwargs): | def apply_field_more(self, func, field_name, modify_fields=True, **kwargs): | ||||
""" | |||||
r""" | |||||
将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。 | 将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。 | ||||
func 可以返回一个或多个 field 上的结果。 | func 可以返回一个或多个 field 上的结果。 | ||||
@@ -885,7 +885,7 @@ class DataSet(object): | |||||
return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs) | return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs) | ||||
def _add_apply_field(self, results, new_field_name, kwargs): | def _add_apply_field(self, results, new_field_name, kwargs): | ||||
""" | |||||
r""" | |||||
将results作为加入到新的field中,field名称为new_field_name | 将results作为加入到新的field中,field名称为new_field_name | ||||
:param List[str] results: 一般是apply*()之后的结果 | :param List[str] results: 一般是apply*()之后的结果 | ||||
@@ -917,7 +917,7 @@ class DataSet(object): | |||||
ignore_type=extra_param.get("ignore_type", False)) | ignore_type=extra_param.get("ignore_type", False)) | ||||
def apply_more(self, func, modify_fields=True, **kwargs): | def apply_more(self, func, modify_fields=True, **kwargs): | ||||
""" | |||||
r""" | |||||
将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。 | 将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。 | ||||
.. note:: | .. note:: | ||||
@@ -978,7 +978,7 @@ class DataSet(object): | |||||
return results | return results | ||||
def apply(self, func, new_field_name=None, **kwargs): | def apply(self, func, new_field_name=None, **kwargs): | ||||
""" | |||||
r""" | |||||
将DataSet中每个instance传入到func中,并获取它的返回值. | 将DataSet中每个instance传入到func中,并获取它的返回值. | ||||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` | :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` | ||||
@@ -1015,7 +1015,7 @@ class DataSet(object): | |||||
return results | return results | ||||
def add_seq_len(self, field_name: str, new_field_name=Const.INPUT_LEN): | def add_seq_len(self, field_name: str, new_field_name=Const.INPUT_LEN): | ||||
""" | |||||
r""" | |||||
将使用len()直接对field_name中每个元素作用,将其结果作为sequence length, 并放入seq_len这个field。 | 将使用len()直接对field_name中每个元素作用,将其结果作为sequence length, 并放入seq_len这个field。 | ||||
:param field_name: str. | :param field_name: str. | ||||
@@ -1029,7 +1029,7 @@ class DataSet(object): | |||||
return self | return self | ||||
def drop(self, func, inplace=True): | def drop(self, func, inplace=True): | ||||
""" | |||||
r""" | |||||
func接受一个Instance,返回bool值。返回值为True时,该Instance会被移除或者不会包含在返回的DataSet中。 | func接受一个Instance,返回bool值。返回值为True时,该Instance会被移除或者不会包含在返回的DataSet中。 | ||||
:param callable func: 接受一个Instance作为参数,返回bool值。为True时删除该instance | :param callable func: 接受一个Instance作为参数,返回bool值。为True时删除该instance | ||||
@@ -1053,7 +1053,7 @@ class DataSet(object): | |||||
return DataSet() | return DataSet() | ||||
def split(self, ratio, shuffle=True): | def split(self, ratio, shuffle=True): | ||||
""" | |||||
r""" | |||||
将DataSet按照ratio的比例拆分,返回两个DataSet | 将DataSet按照ratio的比例拆分,返回两个DataSet | ||||
:param float ratio: 0<ratio<1, 返回的第一个DataSet拥有 `(1-ratio)` 这么多数据,第二个DataSet拥有`ratio`这么多数据 | :param float ratio: 0<ratio<1, 返回的第一个DataSet拥有 `(1-ratio)` 这么多数据,第二个DataSet拥有`ratio`这么多数据 | ||||
@@ -1088,7 +1088,7 @@ class DataSet(object): | |||||
return train_set, dev_set | return train_set, dev_set | ||||
def save(self, path): | def save(self, path): | ||||
""" | |||||
r""" | |||||
保存DataSet. | 保存DataSet. | ||||
:param str path: 将DataSet存在哪个路径 | :param str path: 将DataSet存在哪个路径 | ||||
@@ -1110,7 +1110,7 @@ class DataSet(object): | |||||
return d | return d | ||||
def add_collect_fn(self, fn, name=None): | def add_collect_fn(self, fn, name=None): | ||||
""" | |||||
r""" | |||||
添加 CollectFn,collect_fn允许在生成的batch的过程中动态生成一些数据(在DataSetIter作为迭代器的情况下有效,默认情况下就是用的 | 添加 CollectFn,collect_fn允许在生成的batch的过程中动态生成一些数据(在DataSetIter作为迭代器的情况下有效,默认情况下就是用的 | ||||
这个)。支持依次添加多个collect_fn, 如果相同的key,后面的collect_fn的结果覆盖前面的collect_fn的结果。 | 这个)。支持依次添加多个collect_fn, 如果相同的key,后面的collect_fn的结果覆盖前面的collect_fn的结果。 | ||||
@@ -1126,7 +1126,7 @@ class DataSet(object): | |||||
self.collector.add_fn(fn, name=name) | self.collector.add_fn(fn, name=name) | ||||
def delete_collect_fn(self, name=None): | def delete_collect_fn(self, name=None): | ||||
""" | |||||
r""" | |||||
删除某个collect_fn | 删除某个collect_fn | ||||
:param str,int name: 如果为None,则删除最近加入的collect_fn | :param str,int name: 如果为None,则删除最近加入的collect_fn | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
分布式 Trainer | 分布式 Trainer | ||||
使用步骤 | 使用步骤 | ||||
1. 在代码中调用 DistTrainer,类似 Trainer,传入模型和数据等等参数 | 1. 在代码中调用 DistTrainer,类似 Trainer,传入模型和数据等等参数 | ||||
@@ -41,7 +41,7 @@ __all__ = [ | |||||
] | ] | ||||
def get_local_rank(): | def get_local_rank(): | ||||
""" | |||||
r""" | |||||
返回当前进程的 local rank, 0 到 N-1 ,N为当前分布式总进程数 | 返回当前进程的 local rank, 0 到 N-1 ,N为当前分布式总进程数 | ||||
""" | """ | ||||
if 'LOCAL_RANK' in os.environ: | if 'LOCAL_RANK' in os.environ: | ||||
@@ -57,7 +57,7 @@ def get_local_rank(): | |||||
class DistTrainer(): | class DistTrainer(): | ||||
""" | |||||
r""" | |||||
分布式的 Trainer,支持分布式训练和混合精度的训练。具体实现原理请阅读 pytorch 官方文档。 | 分布式的 Trainer,支持分布式训练和混合精度的训练。具体实现原理请阅读 pytorch 官方文档。 | ||||
Note: 使用分布式 Trainer 时会同时有多个进程执行训练代码。因此将单进程的训练代码改为多进程之前, | Note: 使用分布式 Trainer 时会同时有多个进程执行训练代码。因此将单进程的训练代码改为多进程之前, | ||||
@@ -71,7 +71,7 @@ class DistTrainer(): | |||||
update_every=1, print_every=10, validate_every=-1, | update_every=1, print_every=10, validate_every=-1, | ||||
save_every=-1, save_path=None, device='auto', | save_every=-1, save_path=None, device='auto', | ||||
fp16='', backend=None, init_method=None, use_tqdm=True): | fp16='', backend=None, init_method=None, use_tqdm=True): | ||||
""" | |||||
r""" | |||||
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | :param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | ||||
:param nn.modules model: 待训练的模型 | :param nn.modules model: 待训练的模型 | ||||
@@ -227,11 +227,11 @@ class DistTrainer(): | |||||
@property | @property | ||||
def is_master(self): | def is_master(self): | ||||
"""是否是主进程""" | |||||
r"""是否是主进程""" | |||||
return self.rank == 0 | return self.rank == 0 | ||||
def train(self, load_best_model=True, on_exception='auto'): | def train(self, load_best_model=True, on_exception='auto'): | ||||
""" | |||||
r""" | |||||
使用该函数使Trainer开始训练。 | 使用该函数使Trainer开始训练。 | ||||
:param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。 | :param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。 | ||||
@@ -374,7 +374,7 @@ class DistTrainer(): | |||||
# ============ tqdm end ============== # | # ============ tqdm end ============== # | ||||
def _update(self): | def _update(self): | ||||
"""Perform weight update on a model. | |||||
r"""Perform weight update on a model. | |||||
""" | """ | ||||
if self.step % self.update_every == 0: | if self.step % self.update_every == 0: | ||||
@@ -390,7 +390,7 @@ class DistTrainer(): | |||||
return y | return y | ||||
def _compute_loss(self, predict, truth): | def _compute_loss(self, predict, truth): | ||||
"""Compute loss given prediction and ground truth. | |||||
r"""Compute loss given prediction and ground truth. | |||||
:param predict: prediction dict, produced by model.forward | :param predict: prediction dict, produced by model.forward | ||||
:param truth: ground truth dict, produced by batch_y | :param truth: ground truth dict, produced by batch_y | ||||
@@ -404,7 +404,7 @@ class DistTrainer(): | |||||
return loss | return loss | ||||
def save_check_point(self, name=None, only_params=False): | def save_check_point(self, name=None, only_params=False): | ||||
"""保存当前模型""" | |||||
r"""保存当前模型""" | |||||
# only master save models | # only master save models | ||||
if self.is_master: | if self.is_master: | ||||
if name is None: | if name is None: | ||||
@@ -446,5 +446,5 @@ class DistTrainer(): | |||||
dist.barrier() | dist.barrier() | ||||
def close(self): | def close(self): | ||||
"""关闭Trainer,销毁进程""" | |||||
r"""关闭Trainer,销毁进程""" | |||||
dist.destroy_process_group() | dist.destroy_process_group() |
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -88,7 +88,7 @@ class FieldArray: | |||||
@is_input.setter | @is_input.setter | ||||
def is_input(self, value): | def is_input(self, value): | ||||
""" | |||||
r""" | |||||
当 field_array.is_input = True / False 时被调用 | 当 field_array.is_input = True / False 时被调用 | ||||
""" | """ | ||||
# 如果(value为True)且(_is_input和_is_target都是False)且(ignore_type为False) | # 如果(value为True)且(_is_input和_is_target都是False)且(ignore_type为False) | ||||
@@ -107,7 +107,7 @@ class FieldArray: | |||||
@is_target.setter | @is_target.setter | ||||
def is_target(self, value): | def is_target(self, value): | ||||
""" | |||||
r""" | |||||
当 field_array.is_target = True / False 时被调用 | 当 field_array.is_target = True / False 时被调用 | ||||
""" | """ | ||||
if value is True and \ | if value is True and \ | ||||
@@ -120,7 +120,7 @@ class FieldArray: | |||||
self._is_target = value | self._is_target = value | ||||
def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True): | def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True): | ||||
""" | |||||
r""" | |||||
检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 | 检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 | ||||
通过将直接报错. | 通过将直接报错. | ||||
@@ -150,7 +150,7 @@ class FieldArray: | |||||
raise e | raise e | ||||
def append(self, val: Any): | def append(self, val: Any): | ||||
""" | |||||
r""" | |||||
:param val: 把该val append到fieldarray。 | :param val: 把该val append到fieldarray。 | ||||
:return: | :return: | ||||
""" | """ | ||||
@@ -167,7 +167,7 @@ class FieldArray: | |||||
self.content.append(val) | self.content.append(val) | ||||
def pop(self, index): | def pop(self, index): | ||||
""" | |||||
r""" | |||||
删除该field中index处的元素 | 删除该field中index处的元素 | ||||
:param int index: 从0开始的数据下标。 | :param int index: 从0开始的数据下标。 | ||||
:return: | :return: | ||||
@@ -190,7 +190,7 @@ class FieldArray: | |||||
self.content[idx] = val | self.content[idx] = val | ||||
def get(self, indices, pad=True): | def get(self, indices, pad=True): | ||||
""" | |||||
r""" | |||||
根据给定的indices返回内容。 | 根据给定的indices返回内容。 | ||||
:param int,List[int] indices: 获取indices对应的内容。 | :param int,List[int] indices: 获取indices对应的内容。 | ||||
@@ -210,7 +210,7 @@ class FieldArray: | |||||
return np.array(contents) | return np.array(contents) | ||||
def pad(self, contents): | def pad(self, contents): | ||||
""" | |||||
r""" | |||||
传入list的contents,将contents使用padder进行padding,contents必须为从本FieldArray中取出的。 | 传入list的contents,将contents使用padder进行padding,contents必须为从本FieldArray中取出的。 | ||||
:param list contents: | :param list contents: | ||||
@@ -219,7 +219,7 @@ class FieldArray: | |||||
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) | return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) | ||||
def set_padder(self, padder): | def set_padder(self, padder): | ||||
""" | |||||
r""" | |||||
设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 | 设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 | ||||
:param padder: :class:`~fastNLP.Padder` 类型,设置为None即删除padder。 | :param padder: :class:`~fastNLP.Padder` 类型,设置为None即删除padder。 | ||||
@@ -231,7 +231,7 @@ class FieldArray: | |||||
self.padder = None | self.padder = None | ||||
def set_pad_val(self, pad_val): | def set_pad_val(self, pad_val): | ||||
""" | |||||
r""" | |||||
修改padder的pad_val. | 修改padder的pad_val. | ||||
:param int pad_val: 该field的pad值设置为该值。 | :param int pad_val: 该field的pad值设置为该值。 | ||||
@@ -241,7 +241,7 @@ class FieldArray: | |||||
return self | return self | ||||
def __len__(self): | def __len__(self): | ||||
""" | |||||
r""" | |||||
Returns the size of FieldArray. | Returns the size of FieldArray. | ||||
:return int length: | :return int length: | ||||
@@ -249,7 +249,7 @@ class FieldArray: | |||||
return len(self.content) | return len(self.content) | ||||
def to(self, other): | def to(self, other): | ||||
""" | |||||
r""" | |||||
将other的属性复制给本FieldArray(other必须为FieldArray类型). | 将other的属性复制给本FieldArray(other必须为FieldArray类型). | ||||
属性包括 is_input, is_target, padder, ignore_type | 属性包括 is_input, is_target, padder, ignore_type | ||||
@@ -266,7 +266,7 @@ class FieldArray: | |||||
return self | return self | ||||
def split(self, sep: str = None, inplace: bool = True): | def split(self, sep: str = None, inplace: bool = True): | ||||
""" | |||||
r""" | |||||
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值 | 依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值 | ||||
:param sep: 分割符,如果为None则直接调用str.split()。 | :param sep: 分割符,如果为None则直接调用str.split()。 | ||||
@@ -283,7 +283,7 @@ class FieldArray: | |||||
return self._after_process(new_contents, inplace=inplace) | return self._after_process(new_contents, inplace=inplace) | ||||
def int(self, inplace: bool = True): | def int(self, inplace: bool = True): | ||||
""" | |||||
r""" | |||||
将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | 将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | ||||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | ||||
@@ -303,7 +303,7 @@ class FieldArray: | |||||
return self._after_process(new_contents, inplace=inplace) | return self._after_process(new_contents, inplace=inplace) | ||||
def float(self, inplace=True): | def float(self, inplace=True): | ||||
""" | |||||
r""" | |||||
将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | 将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | ||||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | ||||
@@ -323,7 +323,7 @@ class FieldArray: | |||||
return self._after_process(new_contents, inplace=inplace) | return self._after_process(new_contents, inplace=inplace) | ||||
def bool(self, inplace=True): | def bool(self, inplace=True): | ||||
""" | |||||
r""" | |||||
将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | 将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | ||||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | ||||
@@ -344,7 +344,7 @@ class FieldArray: | |||||
return self._after_process(new_contents, inplace=inplace) | return self._after_process(new_contents, inplace=inplace) | ||||
def lower(self, inplace=True): | def lower(self, inplace=True): | ||||
""" | |||||
r""" | |||||
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | 将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | ||||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | ||||
@@ -364,7 +364,7 @@ class FieldArray: | |||||
return self._after_process(new_contents, inplace=inplace) | return self._after_process(new_contents, inplace=inplace) | ||||
def upper(self, inplace=True): | def upper(self, inplace=True): | ||||
""" | |||||
r""" | |||||
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | 将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | ||||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | ||||
@@ -384,7 +384,7 @@ class FieldArray: | |||||
return self._after_process(new_contents, inplace=inplace) | return self._after_process(new_contents, inplace=inplace) | ||||
def value_count(self): | def value_count(self): | ||||
""" | |||||
r""" | |||||
返回该field下不同value的数量。多用于统计label数量 | 返回该field下不同value的数量。多用于统计label数量 | ||||
:return: Counter, key是label,value是出现次数 | :return: Counter, key是label,value是出现次数 | ||||
@@ -403,7 +403,7 @@ class FieldArray: | |||||
return count | return count | ||||
def _after_process(self, new_contents, inplace): | def _after_process(self, new_contents, inplace): | ||||
""" | |||||
r""" | |||||
当调用处理函数之后,决定是否要替换field。 | 当调用处理函数之后,决定是否要替换field。 | ||||
:param new_contents: | :param new_contents: | ||||
@@ -424,7 +424,7 @@ class FieldArray: | |||||
def _get_ele_type_and_dim(cell: Any, dim=0): | def _get_ele_type_and_dim(cell: Any, dim=0): | ||||
""" | |||||
r""" | |||||
识别cell的类别与dimension的数量 | 识别cell的类别与dimension的数量 | ||||
numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html | numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html | ||||
@@ -470,7 +470,7 @@ def _get_ele_type_and_dim(cell: Any, dim=0): | |||||
class Padder: | class Padder: | ||||
""" | |||||
r""" | |||||
所有padder都需要继承这个类,并覆盖__call__方法。 | 所有padder都需要继承这个类,并覆盖__call__方法。 | ||||
用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 | 用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 | ||||
@@ -479,7 +479,7 @@ class Padder: | |||||
""" | """ | ||||
def __init__(self, pad_val=0, **kwargs): | def __init__(self, pad_val=0, **kwargs): | ||||
""" | |||||
r""" | |||||
:param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | :param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | ||||
deepcopy一份。 | deepcopy一份。 | ||||
@@ -497,7 +497,7 @@ class Padder: | |||||
@abstractmethod | @abstractmethod | ||||
def __call__(self, contents, field_name, field_ele_dtype, dim: int): | def __call__(self, contents, field_name, field_ele_dtype, dim: int): | ||||
""" | |||||
r""" | |||||
传入的是List内容。假设有以下的DataSet。 | 传入的是List内容。假设有以下的DataSet。 | ||||
:param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | :param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | ||||
@@ -541,7 +541,7 @@ class Padder: | |||||
class AutoPadder(Padder): | class AutoPadder(Padder): | ||||
""" | |||||
r""" | |||||
根据contents的数据自动判定是否需要做padding。 | 根据contents的数据自动判定是否需要做padding。 | ||||
1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | 1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | ||||
@@ -633,7 +633,7 @@ class AutoPadder(Padder): | |||||
class EngChar2DPadder(Padder): | class EngChar2DPadder(Padder): | ||||
""" | |||||
r""" | |||||
用于为英语执行character级别的2D padding操作。对应的field内容应该类似[['T', 'h', 'i', 's'], ['a'], ['d', 'e', 'm', 'o']], | 用于为英语执行character级别的2D padding操作。对应的field内容应该类似[['T', 'h', 'i', 's'], ['a'], ['d', 'e', 'm', 'o']], | ||||
但这个Padder只能处理index为int的情况。 | 但这个Padder只能处理index为int的情况。 | ||||
@@ -655,7 +655,7 @@ class EngChar2DPadder(Padder): | |||||
""" | """ | ||||
def __init__(self, pad_val=0, pad_length=0): | def __init__(self, pad_val=0, pad_length=0): | ||||
""" | |||||
r""" | |||||
:param pad_val: int, pad的位置使用该index | :param pad_val: int, pad的位置使用该index | ||||
:param pad_length: int, 如果为0则取一个batch中最大的单词长度作为padding长度。如果为大于0的数,则将所有单词的长度 | :param pad_length: int, 如果为0则取一个batch中最大的单词长度作为padding长度。如果为大于0的数,则将所有单词的长度 | ||||
都pad或截取到该长度. | 都pad或截取到该长度. | ||||
@@ -665,7 +665,7 @@ class EngChar2DPadder(Padder): | |||||
self.pad_length = pad_length | self.pad_length = pad_length | ||||
def __call__(self, contents, field_name, field_ele_dtype, dim): | def __call__(self, contents, field_name, field_ele_dtype, dim): | ||||
""" | |||||
r""" | |||||
期望输入类似于 | 期望输入类似于 | ||||
[ | [ | ||||
[[0, 2], [2, 3, 4], ..], | [[0, 2], [2, 3, 4], ..], | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可以认为是一个Instance类型的对象。 | instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可以认为是一个Instance类型的对象。 | ||||
便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 中的表格 | 便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 中的表格 | ||||
@@ -12,7 +12,7 @@ from .utils import pretty_table_printer | |||||
class Instance(object): | class Instance(object): | ||||
""" | |||||
r""" | |||||
Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。 | Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。 | ||||
Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示:: | Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示:: | ||||
@@ -29,7 +29,7 @@ class Instance(object): | |||||
self.fields = fields | self.fields = fields | ||||
def add_field(self, field_name, field): | def add_field(self, field_name, field): | ||||
""" | |||||
r""" | |||||
向Instance中增加一个field | 向Instance中增加一个field | ||||
:param str field_name: 新增field的名称 | :param str field_name: 新增field的名称 | ||||
@@ -38,7 +38,7 @@ class Instance(object): | |||||
self.fields[field_name] = field | self.fields[field_name] = field | ||||
def items(self): | def items(self): | ||||
""" | |||||
r""" | |||||
返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value | 返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value | ||||
:return: 一个迭代器 | :return: 一个迭代器 | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
losses 模块定义了 fastNLP 中所需的各种损失函数,一般做为 :class:`~fastNLP.Trainer` 的参数使用。 | losses 模块定义了 fastNLP 中所需的各种损失函数,一般做为 :class:`~fastNLP.Trainer` 的参数使用。 | ||||
""" | """ | ||||
@@ -34,7 +34,7 @@ from ..core.const import Const | |||||
class LossBase(object): | class LossBase(object): | ||||
""" | |||||
r""" | |||||
所有loss的基类。如果想了解其中的原理,请查看源码。 | 所有loss的基类。如果想了解其中的原理,请查看源码。 | ||||
""" | """ | ||||
@@ -55,7 +55,7 @@ class LossBase(object): | |||||
raise NotImplementedError | raise NotImplementedError | ||||
def _init_param_map(self, key_map=None, **kwargs): | def _init_param_map(self, key_map=None, **kwargs): | ||||
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map | |||||
r"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map | |||||
:param dict key_map: 表示key的映射关系 | :param dict key_map: 表示key的映射关系 | ||||
:param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 | :param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 | ||||
@@ -102,7 +102,7 @@ class LossBase(object): | |||||
# f"positional argument.).") | # f"positional argument.).") | ||||
def __call__(self, pred_dict, target_dict, check=False): | def __call__(self, pred_dict, target_dict, check=False): | ||||
""" | |||||
r""" | |||||
:param dict pred_dict: 模型的forward函数返回的dict | :param dict pred_dict: 模型的forward函数返回的dict | ||||
:param dict target_dict: DataSet.batch_y里的键-值对所组成的dict | :param dict target_dict: DataSet.batch_y里的键-值对所组成的dict | ||||
:param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查 | :param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查 | ||||
@@ -168,7 +168,7 @@ class LossBase(object): | |||||
class LossFunc(LossBase): | class LossFunc(LossBase): | ||||
""" | |||||
r""" | |||||
提供给用户使用自定义损失函数的类 | 提供给用户使用自定义损失函数的类 | ||||
:param func: 用户自行定义的损失函数,应当为一个函数或者callable(func)为True的ojbect | :param func: 用户自行定义的损失函数,应当为一个函数或者callable(func)为True的ojbect | ||||
@@ -199,7 +199,7 @@ class LossFunc(LossBase): | |||||
class CrossEntropyLoss(LossBase): | class CrossEntropyLoss(LossBase): | ||||
""" | |||||
r""" | |||||
交叉熵损失函数 | 交叉熵损失函数 | ||||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | ||||
@@ -246,7 +246,7 @@ class CrossEntropyLoss(LossBase): | |||||
class L1Loss(LossBase): | class L1Loss(LossBase): | ||||
""" | |||||
r""" | |||||
L1损失函数 | L1损失函数 | ||||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | ||||
@@ -266,7 +266,7 @@ class L1Loss(LossBase): | |||||
class BCELoss(LossBase): | class BCELoss(LossBase): | ||||
""" | |||||
r""" | |||||
二分类交叉熵损失函数 | 二分类交叉熵损失函数 | ||||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | ||||
@@ -285,12 +285,12 @@ class BCELoss(LossBase): | |||||
class NLLLoss(LossBase): | class NLLLoss(LossBase): | ||||
""" | |||||
r""" | |||||
负对数似然损失函数 | 负对数似然损失函数 | ||||
""" | """ | ||||
def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'): | def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'): | ||||
""" | |||||
r""" | |||||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | ||||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | ||||
@@ -309,12 +309,12 @@ class NLLLoss(LossBase): | |||||
class LossInForward(LossBase): | class LossInForward(LossBase): | ||||
""" | |||||
r""" | |||||
从forward()函数返回结果中获取loss | 从forward()函数返回结果中获取loss | ||||
""" | """ | ||||
def __init__(self, loss_key=Const.LOSS): | def __init__(self, loss_key=Const.LOSS): | ||||
""" | |||||
r""" | |||||
:param str loss_key: 在forward函数中loss的键名,默认为loss | :param str loss_key: 在forward函数中loss的键名,默认为loss | ||||
""" | """ | ||||
@@ -349,7 +349,7 @@ class LossInForward(LossBase): | |||||
class CMRC2018Loss(LossBase): | class CMRC2018Loss(LossBase): | ||||
""" | |||||
r""" | |||||
用于计算CMRC2018中文问答任务。 | 用于计算CMRC2018中文问答任务。 | ||||
""" | """ | ||||
@@ -364,7 +364,7 @@ class CMRC2018Loss(LossBase): | |||||
self.reduction = reduction | self.reduction = reduction | ||||
def get_loss(self, target_start, target_end, context_len, pred_start, pred_end): | def get_loss(self, target_start, target_end, context_len, pred_start, pred_end): | ||||
""" | |||||
r""" | |||||
:param target_start: batch_size | :param target_start: batch_size | ||||
:param target_end: batch_size | :param target_end: batch_size | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
metrics 模块实现了 fastNLP 所需的各种常用衡量指标,一般做为 :class:`~fastNLP.Trainer` 的参数使用。 | metrics 模块实现了 fastNLP 所需的各种常用衡量指标,一般做为 :class:`~fastNLP.Trainer` 的参数使用。 | ||||
""" | """ | ||||
@@ -33,7 +33,7 @@ from .utils import ConfusionMatrix | |||||
class MetricBase(object): | class MetricBase(object): | ||||
""" | |||||
r""" | |||||
所有metrics的基类,所有的传入到Trainer, Tester的Metric需要继承自该对象,需要覆盖写入evaluate(), get_metric()方法。 | 所有metrics的基类,所有的传入到Trainer, Tester的Metric需要继承自该对象,需要覆盖写入evaluate(), get_metric()方法。 | ||||
evaluate(xxx)中传入的是一个batch的数据。 | evaluate(xxx)中传入的是一个batch的数据。 | ||||
@@ -145,7 +145,7 @@ class MetricBase(object): | |||||
raise NotImplemented | raise NotImplemented | ||||
def set_metric_name(self, name: str): | def set_metric_name(self, name: str): | ||||
""" | |||||
r""" | |||||
设置metric的名称,默认是Metric的class name. | 设置metric的名称,默认是Metric的class name. | ||||
:param str name: | :param str name: | ||||
@@ -155,7 +155,7 @@ class MetricBase(object): | |||||
return self | return self | ||||
def get_metric_name(self): | def get_metric_name(self): | ||||
""" | |||||
r""" | |||||
返回metric的名称 | 返回metric的名称 | ||||
:return: | :return: | ||||
@@ -163,7 +163,7 @@ class MetricBase(object): | |||||
return self._metric_name | return self._metric_name | ||||
def _init_param_map(self, key_map=None, **kwargs): | def _init_param_map(self, key_map=None, **kwargs): | ||||
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map | |||||
r"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map | |||||
:param dict key_map: 表示key的映射关系 | :param dict key_map: 表示key的映射关系 | ||||
:param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 | :param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 | ||||
@@ -205,7 +205,7 @@ class MetricBase(object): | |||||
f"initialization parameters, or change its signature.") | f"initialization parameters, or change its signature.") | ||||
def __call__(self, pred_dict, target_dict): | def __call__(self, pred_dict, target_dict): | ||||
""" | |||||
r""" | |||||
这个方法会调用self.evaluate 方法. | 这个方法会调用self.evaluate 方法. | ||||
在调用之前,会进行以下检测: | 在调用之前,会进行以下检测: | ||||
1. self.evaluate当中是否有varargs, 这是不支持的. | 1. self.evaluate当中是否有varargs, 这是不支持的. | ||||
@@ -315,7 +315,7 @@ class ConfusionMatrixMetric(MetricBase): | |||||
seq_len=None, | seq_len=None, | ||||
print_ratio=False | print_ratio=False | ||||
): | ): | ||||
""" | |||||
r""" | |||||
:param vocab: vocab词表类,要求有to_word()方法。 | :param vocab: vocab词表类,要求有to_word()方法。 | ||||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | ||||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | ||||
@@ -330,7 +330,7 @@ class ConfusionMatrixMetric(MetricBase): | |||||
) | ) | ||||
def evaluate(self, pred, target, seq_len=None): | def evaluate(self, pred, target, seq_len=None): | ||||
""" | |||||
r""" | |||||
evaluate函数将针对一个批次的预测结果做评价指标的累计 | evaluate函数将针对一个批次的预测结果做评价指标的累计 | ||||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | ||||
@@ -379,7 +379,7 @@ class ConfusionMatrixMetric(MetricBase): | |||||
target.tolist()) | target.tolist()) | ||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
""" | |||||
r""" | |||||
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | ||||
:param bool reset: 在调用完get_metric后是否清空评价指标统计量. | :param bool reset: 在调用完get_metric后是否清空评价指标统计量. | ||||
:return dict evaluate_result: {"confusion_matrix": ConfusionMatrix} | :return dict evaluate_result: {"confusion_matrix": ConfusionMatrix} | ||||
@@ -394,12 +394,12 @@ class ConfusionMatrixMetric(MetricBase): | |||||
class AccuracyMetric(MetricBase): | class AccuracyMetric(MetricBase): | ||||
""" | |||||
r""" | |||||
准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) | 准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) | ||||
""" | """ | ||||
def __init__(self, pred=None, target=None, seq_len=None): | def __init__(self, pred=None, target=None, seq_len=None): | ||||
""" | |||||
r""" | |||||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | ||||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | ||||
@@ -414,7 +414,7 @@ class AccuracyMetric(MetricBase): | |||||
self.acc_count = 0 | self.acc_count = 0 | ||||
def evaluate(self, pred, target, seq_len=None): | def evaluate(self, pred, target, seq_len=None): | ||||
""" | |||||
r""" | |||||
evaluate函数将针对一个批次的预测结果做评价指标的累计 | evaluate函数将针对一个批次的预测结果做评价指标的累计 | ||||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | ||||
@@ -463,7 +463,7 @@ class AccuracyMetric(MetricBase): | |||||
self.total += np.prod(list(pred.size())) | self.total += np.prod(list(pred.size())) | ||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
""" | |||||
r""" | |||||
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | ||||
:param bool reset: 在调用完get_metric后是否清空评价指标统计量. | :param bool reset: 在调用完get_metric后是否清空评价指标统计量. | ||||
@@ -537,7 +537,7 @@ class ClassifyFPreRecMetric(MetricBase): | |||||
# tp: truth=T, classify=T; fp: truth=T, classify=F; fn: truth=F, classify=T | # tp: truth=T, classify=T; fp: truth=T, classify=F; fn: truth=F, classify=T | ||||
def evaluate(self, pred, target, seq_len=None): | def evaluate(self, pred, target, seq_len=None): | ||||
""" | |||||
r""" | |||||
evaluate函数将针对一个批次的预测结果做评价指标的累计 | evaluate函数将针对一个批次的预测结果做评价指标的累计 | ||||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | ||||
@@ -586,7 +586,7 @@ class ClassifyFPreRecMetric(MetricBase): | |||||
self._fn[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target == target_idx, 0).masked_fill(masks, 0)).item() | self._fn[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target == target_idx, 0).masked_fill(masks, 0)).item() | ||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
""" | |||||
r""" | |||||
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | ||||
:param bool reset: 在调用完get_metric后是否清空评价指标统计量. | :param bool reset: 在调用完get_metric后是否清空评价指标统计量. | ||||
@@ -646,7 +646,7 @@ class ClassifyFPreRecMetric(MetricBase): | |||||
def _bmes_tag_to_spans(tags, ignore_labels=None): | def _bmes_tag_to_spans(tags, ignore_labels=None): | ||||
""" | |||||
r""" | |||||
给定一个tags的lis,比如['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。 | 给定一个tags的lis,比如['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。 | ||||
返回[('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (左闭右开区间) | 返回[('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (左闭右开区间) | ||||
也可以是单纯的['S', 'B', 'M', 'E', 'B', 'M', 'M',...]序列 | 也可以是单纯的['S', 'B', 'M', 'E', 'B', 'M', 'M',...]序列 | ||||
@@ -676,7 +676,7 @@ def _bmes_tag_to_spans(tags, ignore_labels=None): | |||||
def _bmeso_tag_to_spans(tags, ignore_labels=None): | def _bmeso_tag_to_spans(tags, ignore_labels=None): | ||||
""" | |||||
r""" | |||||
给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 | 给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 | ||||
返回[('singer', (1, 4))] (左闭右开区间) | 返回[('singer', (1, 4))] (左闭右开区间) | ||||
@@ -707,7 +707,7 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None): | |||||
def _bioes_tag_to_spans(tags, ignore_labels=None): | def _bioes_tag_to_spans(tags, ignore_labels=None): | ||||
""" | |||||
r""" | |||||
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。 | 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。 | ||||
返回[('singer', (1, 4))] (左闭右开区间) | 返回[('singer', (1, 4))] (左闭右开区间) | ||||
@@ -738,7 +738,7 @@ def _bioes_tag_to_spans(tags, ignore_labels=None): | |||||
def _bio_tag_to_spans(tags, ignore_labels=None): | def _bio_tag_to_spans(tags, ignore_labels=None): | ||||
""" | |||||
r""" | |||||
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 | 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 | ||||
返回[('singer', (1, 4))] (左闭右开区间) | 返回[('singer', (1, 4))] (左闭右开区间) | ||||
@@ -766,7 +766,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||||
def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str: | def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str: | ||||
""" | |||||
r""" | |||||
给定Vocabulary自动判断是哪种类型的encoding, 支持判断bmes, bioes, bmeso, bio | 给定Vocabulary自动判断是哪种类型的encoding, 支持判断bmes, bioes, bmeso, bio | ||||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | :param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | ||||
@@ -802,7 +802,7 @@ def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str | |||||
def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str): | def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str): | ||||
""" | |||||
r""" | |||||
检查vocab中的tag是否与encoding_type是匹配的 | 检查vocab中的tag是否与encoding_type是匹配的 | ||||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | :param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | ||||
@@ -913,7 +913,7 @@ class SpanFPreRecMetric(MetricBase): | |||||
self._false_negatives = defaultdict(int) | self._false_negatives = defaultdict(int) | ||||
def evaluate(self, pred, target, seq_len): | def evaluate(self, pred, target, seq_len): | ||||
"""evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||||
r"""evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||||
:param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果 | :param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果 | ||||
:param target: [batch, seq_len], 真实值 | :param target: [batch, seq_len], 真实值 | ||||
@@ -967,7 +967,7 @@ class SpanFPreRecMetric(MetricBase): | |||||
self._false_negatives[span[0]] += 1 | self._false_negatives[span[0]] += 1 | ||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" | |||||
r"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" | |||||
evaluate_result = {} | evaluate_result = {} | ||||
if not self.only_gross or self.f_type == 'macro': | if not self.only_gross or self.f_type == 'macro': | ||||
tags = set(self._false_negatives.keys()) | tags = set(self._false_negatives.keys()) | ||||
@@ -1018,7 +1018,7 @@ class SpanFPreRecMetric(MetricBase): | |||||
def _compute_f_pre_rec(beta_square, tp, fn, fp): | def _compute_f_pre_rec(beta_square, tp, fn, fp): | ||||
""" | |||||
r""" | |||||
:param tp: int, true positive | :param tp: int, true positive | ||||
:param fn: int, false negative | :param fn: int, false negative | ||||
@@ -1033,7 +1033,7 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||||
def _prepare_metrics(metrics): | def _prepare_metrics(metrics): | ||||
""" | |||||
r""" | |||||
Prepare list of Metric based on input | Prepare list of Metric based on input | ||||
:param metrics: | :param metrics: | ||||
@@ -1064,7 +1064,7 @@ def _prepare_metrics(metrics): | |||||
def _accuracy_topk(y_true, y_prob, k=1): | def _accuracy_topk(y_true, y_prob, k=1): | ||||
"""Compute accuracy of y_true matching top-k probable labels in y_prob. | |||||
r"""Compute accuracy of y_true matching top-k probable labels in y_prob. | |||||
:param y_true: ndarray, true label, [n_samples] | :param y_true: ndarray, true label, [n_samples] | ||||
:param y_prob: ndarray, label probabilities, [n_samples, n_classes] | :param y_prob: ndarray, label probabilities, [n_samples, n_classes] | ||||
@@ -1080,7 +1080,7 @@ def _accuracy_topk(y_true, y_prob, k=1): | |||||
def _pred_topk(y_prob, k=1): | def _pred_topk(y_prob, k=1): | ||||
"""Return top-k predicted labels and corresponding probabilities. | |||||
r"""Return top-k predicted labels and corresponding probabilities. | |||||
:param y_prob: ndarray, size [n_samples, n_classes], probabilities on labels | :param y_prob: ndarray, size [n_samples, n_classes], probabilities on labels | ||||
:param k: int, k of top-k | :param k: int, k of top-k | ||||
@@ -1110,7 +1110,7 @@ class CMRC2018Metric(MetricBase): | |||||
self.f1 = 0 | self.f1 = 0 | ||||
def evaluate(self, answers, raw_chars, pred_start, pred_end, context_len=None): | def evaluate(self, answers, raw_chars, pred_start, pred_end, context_len=None): | ||||
""" | |||||
r""" | |||||
:param list[str] answers: 如[["答案1", "答案2", "答案3"], [...], ...] | :param list[str] answers: 如[["答案1", "答案2", "答案3"], [...], ...] | ||||
:param list[str] raw_chars: [["这", "是", ...], [...]] | :param list[str] raw_chars: [["这", "是", ...], [...]] | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
optimizer 模块定义了 fastNLP 中所需的各种优化器,一般做为 :class:`~fastNLP.Trainer` 的参数使用。 | optimizer 模块定义了 fastNLP 中所需的各种优化器,一般做为 :class:`~fastNLP.Trainer` 的参数使用。 | ||||
""" | """ | ||||
@@ -16,12 +16,12 @@ from torch.optim.optimizer import Optimizer as TorchOptimizer | |||||
class Optimizer(object): | class Optimizer(object): | ||||
""" | |||||
r""" | |||||
Optimizer | Optimizer | ||||
""" | """ | ||||
def __init__(self, model_params, **kwargs): | def __init__(self, model_params, **kwargs): | ||||
""" | |||||
r""" | |||||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | ||||
:param kwargs: additional parameters. | :param kwargs: additional parameters. | ||||
@@ -36,7 +36,7 @@ class Optimizer(object): | |||||
@staticmethod | @staticmethod | ||||
def _get_require_grads_param(params): | def _get_require_grads_param(params): | ||||
""" | |||||
r""" | |||||
将params中不需要gradient的删除 | 将params中不需要gradient的删除 | ||||
:param iterable params: parameters | :param iterable params: parameters | ||||
@@ -46,7 +46,7 @@ class Optimizer(object): | |||||
class NullOptimizer(Optimizer): | class NullOptimizer(Optimizer): | ||||
""" | |||||
r""" | |||||
当不希望Trainer更新optimizer时,传入本optimizer,但请确保通过callback的方式对参数进行了更新。 | 当不希望Trainer更新optimizer时,传入本optimizer,但请确保通过callback的方式对参数进行了更新。 | ||||
""" | """ | ||||
@@ -64,12 +64,12 @@ class NullOptimizer(Optimizer): | |||||
class SGD(Optimizer): | class SGD(Optimizer): | ||||
""" | |||||
r""" | |||||
SGD | SGD | ||||
""" | """ | ||||
def __init__(self, lr=0.001, momentum=0, model_params=None): | def __init__(self, lr=0.001, momentum=0, model_params=None): | ||||
""" | |||||
r""" | |||||
:param float lr: learning rate. Default: 0.01 | :param float lr: learning rate. Default: 0.01 | ||||
:param float momentum: momentum. Default: 0 | :param float momentum: momentum. Default: 0 | ||||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | ||||
@@ -87,12 +87,12 @@ class SGD(Optimizer): | |||||
class Adam(Optimizer): | class Adam(Optimizer): | ||||
""" | |||||
r""" | |||||
Adam | Adam | ||||
""" | """ | ||||
def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | ||||
""" | |||||
r""" | |||||
:param float lr: learning rate | :param float lr: learning rate | ||||
:param float weight_decay: | :param float weight_decay: | ||||
@@ -133,7 +133,7 @@ class AdamW(TorchOptimizer): | |||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, | ||||
weight_decay=1e-2, amsgrad=False): | weight_decay=1e-2, amsgrad=False): | ||||
""" | |||||
r""" | |||||
:param params (iterable): iterable of parameters to optimize or dicts defining | :param params (iterable): iterable of parameters to optimize or dicts defining | ||||
parameter groups | parameter groups | ||||
@@ -164,7 +164,7 @@ class AdamW(TorchOptimizer): | |||||
group.setdefault('amsgrad', False) | group.setdefault('amsgrad', False) | ||||
def step(self, closure=None): | def step(self, closure=None): | ||||
"""Performs a single optimization step. | |||||
r"""Performs a single optimization step. | |||||
:param closure: (callable, optional) A closure that reevaluates the model | :param closure: (callable, optional) A closure that reevaluates the model | ||||
and returns the loss. | and returns the loss. | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"Predictor" | "Predictor" | ||||
@@ -15,7 +15,7 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device | |||||
class Predictor(object): | class Predictor(object): | ||||
""" | |||||
r""" | |||||
一个根据训练模型预测输出的预测器(Predictor) | 一个根据训练模型预测输出的预测器(Predictor) | ||||
与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 | 与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 | ||||
@@ -23,7 +23,7 @@ class Predictor(object): | |||||
""" | """ | ||||
def __init__(self, network): | def __init__(self, network): | ||||
""" | |||||
r""" | |||||
:param torch.nn.Module network: 用来完成预测任务的模型 | :param torch.nn.Module network: 用来完成预测任务的模型 | ||||
""" | """ | ||||
@@ -35,7 +35,7 @@ class Predictor(object): | |||||
self.batch_output = [] | self.batch_output = [] | ||||
def predict(self, data: DataSet, seq_len_field_name=None): | def predict(self, data: DataSet, seq_len_field_name=None): | ||||
"""用已经训练好的模型进行inference. | |||||
r"""用已经训练好的模型进行inference. | |||||
:param fastNLP.DataSet data: 待预测的数据集 | :param fastNLP.DataSet data: 待预测的数据集 | ||||
:param str seq_len_field_name: 表示序列长度信息的field名字 | :param str seq_len_field_name: 表示序列长度信息的field名字 | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
sampler 子类实现了 fastNLP 所需的各种采样器。 | sampler 子类实现了 fastNLP 所需的各种采样器。 | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
@@ -14,14 +14,14 @@ import numpy as np | |||||
class Sampler(object): | class Sampler(object): | ||||
""" | |||||
r""" | |||||
`Sampler` 类的基类. 规定以何种顺序取出data中的元素 | `Sampler` 类的基类. 规定以何种顺序取出data中的元素 | ||||
子类必须实现 ``__call__`` 方法. 输入 `DataSet` 对象, 返回其中元素的下标序列 | 子类必须实现 ``__call__`` 方法. 输入 `DataSet` 对象, 返回其中元素的下标序列 | ||||
""" | """ | ||||
def __call__(self, data_set): | def __call__(self, data_set): | ||||
""" | |||||
r""" | |||||
:param DataSet data_set: `DataSet` 对象, 需要Sample的数据 | :param DataSet data_set: `DataSet` 对象, 需要Sample的数据 | ||||
:return result: list(int) 其中元素的下标序列, ``data_set`` 中元素会按 ``result`` 中顺序取出 | :return result: list(int) 其中元素的下标序列, ``data_set`` 中元素会按 ``result`` 中顺序取出 | ||||
""" | """ | ||||
@@ -29,7 +29,7 @@ class Sampler(object): | |||||
class SequentialSampler(Sampler): | class SequentialSampler(Sampler): | ||||
""" | |||||
r""" | |||||
顺序取出元素的 `Sampler` | 顺序取出元素的 `Sampler` | ||||
""" | """ | ||||
@@ -39,7 +39,7 @@ class SequentialSampler(Sampler): | |||||
class RandomSampler(Sampler): | class RandomSampler(Sampler): | ||||
""" | |||||
r""" | |||||
随机化取元素的 `Sampler` | 随机化取元素的 `Sampler` | ||||
""" | """ | ||||
@@ -49,12 +49,12 @@ class RandomSampler(Sampler): | |||||
class BucketSampler(Sampler): | class BucketSampler(Sampler): | ||||
""" | |||||
r""" | |||||
带Bucket的 `Random Sampler`. 可以随机地取出长度相似的元素 | 带Bucket的 `Random Sampler`. 可以随机地取出长度相似的元素 | ||||
""" | """ | ||||
def __init__(self, num_buckets=10, batch_size=None, seq_len_field_name='seq_len'): | def __init__(self, num_buckets=10, batch_size=None, seq_len_field_name='seq_len'): | ||||
""" | |||||
r""" | |||||
:param int num_buckets: bucket的数量 | :param int num_buckets: bucket的数量 | ||||
:param int batch_size: batch的大小. 默认为None,Trainer在调用BucketSampler时,会将该值正确设置,如果是非Trainer场景使用,需 | :param int batch_size: batch的大小. 默认为None,Trainer在调用BucketSampler时,会将该值正确设置,如果是非Trainer场景使用,需 | ||||
@@ -66,7 +66,7 @@ class BucketSampler(Sampler): | |||||
self.seq_len_field_name = seq_len_field_name | self.seq_len_field_name = seq_len_field_name | ||||
def set_batch_size(self, batch_size): | def set_batch_size(self, batch_size): | ||||
""" | |||||
r""" | |||||
:param int batch_size: 每个batch的大小 | :param int batch_size: 每个batch的大小 | ||||
:return: | :return: | ||||
@@ -111,7 +111,7 @@ class BucketSampler(Sampler): | |||||
def simple_sort_bucketing(lengths): | def simple_sort_bucketing(lengths): | ||||
""" | |||||
r""" | |||||
:param lengths: list of int, the lengths of all examples. | :param lengths: list of int, the lengths of all examples. | ||||
:return data: 2-level list | :return data: 2-level list | ||||
@@ -131,7 +131,7 @@ def simple_sort_bucketing(lengths): | |||||
def k_means_1d(x, k, max_iter=100): | def k_means_1d(x, k, max_iter=100): | ||||
"""Perform k-means on 1-D data. | |||||
r"""Perform k-means on 1-D data. | |||||
:param x: list of int, representing points in 1-D. | :param x: list of int, representing points in 1-D. | ||||
:param k: the number of clusters required. | :param k: the number of clusters required. | ||||
@@ -161,7 +161,7 @@ def k_means_1d(x, k, max_iter=100): | |||||
def k_means_bucketing(lengths, buckets): | def k_means_bucketing(lengths, buckets): | ||||
"""Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths. | |||||
r"""Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths. | |||||
:param lengths: list of int, the length of all samples. | :param lengths: list of int, the length of all samples. | ||||
:param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length | :param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
tester模块实现了 fastNLP 所需的Tester类,能在提供数据、模型以及metric的情况下进行性能测试。 | tester模块实现了 fastNLP 所需的Tester类,能在提供数据、模型以及metric的情况下进行性能测试。 | ||||
.. code-block:: | .. code-block:: | ||||
@@ -64,12 +64,12 @@ __all__ = [ | |||||
class Tester(object): | class Tester(object): | ||||
""" | |||||
r""" | |||||
Tester是在提供数据,模型以及metric的情况下进行性能测试的类。需要传入模型,数据以及metric进行验证。 | Tester是在提供数据,模型以及metric的情况下进行性能测试的类。需要传入模型,数据以及metric进行验证。 | ||||
""" | """ | ||||
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True): | def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True): | ||||
""" | |||||
r""" | |||||
:param ~fastNLP.DataSet,~fastNLP.BatchIter data: 需要测试的数据集 | :param ~fastNLP.DataSet,~fastNLP.BatchIter data: 需要测试的数据集 | ||||
:param torch.nn.Module model: 使用的模型 | :param torch.nn.Module model: 使用的模型 | ||||
@@ -196,7 +196,7 @@ class Tester(object): | |||||
return eval_results | return eval_results | ||||
def _mode(self, model, is_test=False): | def _mode(self, model, is_test=False): | ||||
"""Train mode or Test mode. This is for PyTorch currently. | |||||
r"""Train mode or Test mode. This is for PyTorch currently. | |||||
:param model: a PyTorch model | :param model: a PyTorch model | ||||
:param is_test: bool, whether in test mode or not. | :param is_test: bool, whether in test mode or not. | ||||
@@ -208,13 +208,13 @@ class Tester(object): | |||||
model.train() | model.train() | ||||
def _data_forward(self, func, x): | def _data_forward(self, func, x): | ||||
"""A forward pass of the model. """ | |||||
r"""A forward pass of the model. """ | |||||
x = _build_args(func, **x) | x = _build_args(func, **x) | ||||
y = self._predict_func_wrapper(**x) | y = self._predict_func_wrapper(**x) | ||||
return y | return y | ||||
def _format_eval_results(self, results): | def _format_eval_results(self, results): | ||||
"""Override this method to support more print formats. | |||||
r"""Override this method to support more print formats. | |||||
:param results: dict, (str: float) is (metrics name: value) | :param results: dict, (str: float) is (metrics name: value) | ||||
@@ -357,7 +357,7 @@ from ._logger import logger | |||||
class Trainer(object): | class Trainer(object): | ||||
""" | |||||
r""" | |||||
Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在不同训练任务中重复撰写 | Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在不同训练任务中重复撰写 | ||||
(1) epoch循环; | (1) epoch循环; | ||||
(2) 将数据分成不同的Batch; | (2) 将数据分成不同的Batch; | ||||
@@ -572,7 +572,7 @@ class Trainer(object): | |||||
callbacks=callbacks) | callbacks=callbacks) | ||||
def train(self, load_best_model=True, on_exception='auto'): | def train(self, load_best_model=True, on_exception='auto'): | ||||
""" | |||||
r""" | |||||
使用该函数使Trainer开始训练。 | 使用该函数使Trainer开始训练。 | ||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | ||||
@@ -728,7 +728,7 @@ class Trainer(object): | |||||
return res | return res | ||||
def _mode(self, model, is_test=False): | def _mode(self, model, is_test=False): | ||||
"""Train mode or Test mode. This is for PyTorch currently. | |||||
r"""Train mode or Test mode. This is for PyTorch currently. | |||||
:param model: a PyTorch model | :param model: a PyTorch model | ||||
:param bool is_test: whether in test mode or not. | :param bool is_test: whether in test mode or not. | ||||
@@ -740,7 +740,7 @@ class Trainer(object): | |||||
model.train() | model.train() | ||||
def _update(self): | def _update(self): | ||||
"""Perform weight update on a model. | |||||
r"""Perform weight update on a model. | |||||
""" | """ | ||||
if self.step % self.update_every == 0: | if self.step % self.update_every == 0: | ||||
@@ -755,7 +755,7 @@ class Trainer(object): | |||||
return y | return y | ||||
def _grad_backward(self, loss): | def _grad_backward(self, loss): | ||||
"""Compute gradient with link rules. | |||||
r"""Compute gradient with link rules. | |||||
:param loss: a scalar where back-prop starts | :param loss: a scalar where back-prop starts | ||||
@@ -766,7 +766,7 @@ class Trainer(object): | |||||
loss.backward() | loss.backward() | ||||
def _compute_loss(self, predict, truth): | def _compute_loss(self, predict, truth): | ||||
"""Compute loss given prediction and ground truth. | |||||
r"""Compute loss given prediction and ground truth. | |||||
:param predict: prediction dict, produced by model.forward | :param predict: prediction dict, produced by model.forward | ||||
:param truth: ground truth dict, produced by batch_y | :param truth: ground truth dict, produced by batch_y | ||||
@@ -775,7 +775,7 @@ class Trainer(object): | |||||
return self.losser(predict, truth) | return self.losser(predict, truth) | ||||
def _save_model(self, model, model_name, only_param=False): | def _save_model(self, model, model_name, only_param=False): | ||||
""" 存储不含有显卡信息的state_dict或model | |||||
r""" 存储不含有显卡信息的state_dict或model | |||||
:param model: | :param model: | ||||
:param model_name: | :param model_name: | ||||
:param only_param: | :param only_param: | ||||
@@ -816,7 +816,7 @@ class Trainer(object): | |||||
return True | return True | ||||
def _better_eval_result(self, metrics): | def _better_eval_result(self, metrics): | ||||
"""Check if the current epoch yields better validation results. | |||||
r"""Check if the current epoch yields better validation results. | |||||
:return bool value: True means current results on dev set is the best. | :return bool value: True means current results on dev set is the best. | ||||
""" | """ | ||||
@@ -842,7 +842,7 @@ class Trainer(object): | |||||
@property | @property | ||||
def is_master(self): | def is_master(self): | ||||
"""是否是主进程""" | |||||
r"""是否是主进程""" | |||||
return True | return True | ||||
DEFAULT_CHECK_BATCH_SIZE = 2 | DEFAULT_CHECK_BATCH_SIZE = 2 | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 | utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 | ||||
""" | """ | ||||
@@ -35,9 +35,9 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require | |||||
class ConfusionMatrix: | class ConfusionMatrix: | ||||
"""a dict can provide Confusion Matrix""" | |||||
r"""a dict can provide Confusion Matrix""" | |||||
def __init__(self, vocab=None, print_ratio=False): | def __init__(self, vocab=None, print_ratio=False): | ||||
""" | |||||
r""" | |||||
:param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。 | :param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。 | ||||
:param print_ratio: 限制print的输出,False只输出数量Confusion Matrix, True还会输出百分比Confusion Matrix, 分别为行/列 | :param print_ratio: 限制print的输出,False只输出数量Confusion Matrix, True还会输出百分比Confusion Matrix, 分别为行/列 | ||||
""" | """ | ||||
@@ -52,7 +52,7 @@ class ConfusionMatrix: | |||||
self.print_ratio = print_ratio | self.print_ratio = print_ratio | ||||
def add_pred_target(self, pred, target): # 一组结果 | def add_pred_target(self, pred, target): # 一组结果 | ||||
""" | |||||
r""" | |||||
通过这个函数向ConfusionMatrix加入一组预测结果 | 通过这个函数向ConfusionMatrix加入一组预测结果 | ||||
:param list pred: 预测的标签列表 | :param list pred: 预测的标签列表 | ||||
:param list target: 真实值的标签列表 | :param list target: 真实值的标签列表 | ||||
@@ -80,7 +80,7 @@ class ConfusionMatrix: | |||||
return self.confusiondict | return self.confusiondict | ||||
def clear(self): | def clear(self): | ||||
""" | |||||
r""" | |||||
清空ConfusionMatrix,等待再次新加入 | 清空ConfusionMatrix,等待再次新加入 | ||||
:return: | :return: | ||||
""" | """ | ||||
@@ -89,7 +89,7 @@ class ConfusionMatrix: | |||||
self.predcount = {} | self.predcount = {} | ||||
def get_result(self): | def get_result(self): | ||||
""" | |||||
r""" | |||||
:return list output: ConfusionMatrix content,具体值与汇总统计 | :return list output: ConfusionMatrix content,具体值与汇总统计 | ||||
""" | """ | ||||
row2idx = {} | row2idx = {} | ||||
@@ -121,7 +121,7 @@ class ConfusionMatrix: | |||||
return output | return output | ||||
def get_percent(self, dim=0): | def get_percent(self, dim=0): | ||||
""" | |||||
r""" | |||||
:param dim int: 0/1, 0 for row,1 for column | :param dim int: 0/1, 0 for row,1 for column | ||||
:return list output: ConfusionMatrix content,具体值与汇总统计 | :return list output: ConfusionMatrix content,具体值与汇总统计 | ||||
""" | """ | ||||
@@ -139,7 +139,7 @@ class ConfusionMatrix: | |||||
return tmp.tolist() | return tmp.tolist() | ||||
def get_aligned_table(self, data, flag="result"): | def get_aligned_table(self, data, flag="result"): | ||||
""" | |||||
r""" | |||||
:param data: highly recommend use get_percent/ get_result return as dataset here, or make sure data is a n*n list type data | :param data: highly recommend use get_percent/ get_result return as dataset here, or make sure data is a n*n list type data | ||||
:param flag: only difference between result and other words is whether "%" is in output string | :param flag: only difference between result and other words is whether "%" is in output string | ||||
:return: an aligned_table ready to print out | :return: an aligned_table ready to print out | ||||
@@ -197,7 +197,7 @@ class ConfusionMatrix: | |||||
return "\n" + out | return "\n" + out | ||||
def __repr__(self): | def __repr__(self): | ||||
""" | |||||
r""" | |||||
:return string output: ConfusionMatrix的格式化输出,包括表头各标签字段,具体值与汇总统计。 | :return string output: ConfusionMatrix的格式化输出,包括表头各标签字段,具体值与汇总统计。 | ||||
""" | """ | ||||
result = self.get_result() | result = self.get_result() | ||||
@@ -218,7 +218,7 @@ class ConfusionMatrix: | |||||
class Option(dict): | class Option(dict): | ||||
"""a dict can treat keys as attributes""" | |||||
r"""a dict can treat keys as attributes""" | |||||
def __getattr__(self, item): | def __getattr__(self, item): | ||||
try: | try: | ||||
@@ -245,7 +245,7 @@ class Option(dict): | |||||
def _prepare_cache_filepath(filepath): | def _prepare_cache_filepath(filepath): | ||||
""" | |||||
r""" | |||||
检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径 | 检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径 | ||||
:param filepath: str. | :param filepath: str. | ||||
:return: None, if not, this function will raise error | :return: None, if not, this function will raise error | ||||
@@ -259,7 +259,7 @@ def _prepare_cache_filepath(filepath): | |||||
def cache_results(_cache_fp, _refresh=False, _verbose=1): | def cache_results(_cache_fp, _refresh=False, _verbose=1): | ||||
""" | |||||
r""" | |||||
cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: | cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: | ||||
import time | import time | ||||
@@ -358,7 +358,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||||
def _save_model(model, model_name, save_dir, only_param=False): | def _save_model(model, model_name, save_dir, only_param=False): | ||||
""" 存储不含有显卡信息的state_dict或model | |||||
r""" 存储不含有显卡信息的state_dict或model | |||||
:param model: | :param model: | ||||
:param model_name: | :param model_name: | ||||
:param save_dir: 保存的directory | :param save_dir: 保存的directory | ||||
@@ -383,7 +383,7 @@ def _save_model(model, model_name, save_dir, only_param=False): | |||||
def _move_model_to_device(model, device): | def _move_model_to_device(model, device): | ||||
""" | |||||
r""" | |||||
将model移动到device | 将model移动到device | ||||
:param model: torch.nn.DataParallel or torch.nn.Module. 当为torch.nn.DataParallel, 则只是调用一次cuda。device必须为 | :param model: torch.nn.DataParallel or torch.nn.Module. 当为torch.nn.DataParallel, 则只是调用一次cuda。device必须为 | ||||
@@ -454,7 +454,7 @@ def _move_model_to_device(model, device): | |||||
def _get_model_device(model): | def _get_model_device(model): | ||||
""" | |||||
r""" | |||||
传入一个nn.Module的模型,获取它所在的device | 传入一个nn.Module的模型,获取它所在的device | ||||
:param model: nn.Module | :param model: nn.Module | ||||
@@ -471,7 +471,7 @@ def _get_model_device(model): | |||||
def _build_args(func, **kwargs): | def _build_args(func, **kwargs): | ||||
""" | |||||
r""" | |||||
根据func的初始化参数,从kwargs中选择func需要的参数 | 根据func的初始化参数,从kwargs中选择func需要的参数 | ||||
:param func: callable | :param func: callable | ||||
@@ -555,7 +555,7 @@ def _check_arg_dict_list(func, args): | |||||
def _get_func_signature(func): | def _get_func_signature(func): | ||||
""" | |||||
r""" | |||||
Given a function or method, return its signature. | Given a function or method, return its signature. | ||||
For example: | For example: | ||||
@@ -596,7 +596,7 @@ def _get_func_signature(func): | |||||
def _is_function_or_method(func): | def _is_function_or_method(func): | ||||
""" | |||||
r""" | |||||
:param func: | :param func: | ||||
:return: | :return: | ||||
@@ -612,7 +612,7 @@ def _check_function_or_method(func): | |||||
def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): | def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): | ||||
""" | |||||
r""" | |||||
move data to model's device, element in *args should be dict. This is a inplace change. | move data to model's device, element in *args should be dict. This is a inplace change. | ||||
:param device: torch.device | :param device: torch.device | ||||
@@ -636,7 +636,7 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): | |||||
class _CheckError(Exception): | class _CheckError(Exception): | ||||
""" | |||||
r""" | |||||
_CheckError. Used in losses.LossBase, metrics.MetricBase. | _CheckError. Used in losses.LossBase, metrics.MetricBase. | ||||
""" | """ | ||||
@@ -807,7 +807,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||||
def seq_len_to_mask(seq_len, max_len=None): | def seq_len_to_mask(seq_len, max_len=None): | ||||
""" | |||||
r""" | |||||
将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。 | 将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。 | ||||
转变 1-d seq_len到2-d mask. | 转变 1-d seq_len到2-d mask. | ||||
@@ -851,7 +851,7 @@ def seq_len_to_mask(seq_len, max_len=None): | |||||
class _pseudo_tqdm: | class _pseudo_tqdm: | ||||
""" | |||||
r""" | |||||
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | 当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | ||||
""" | """ | ||||
@@ -878,7 +878,7 @@ class _pseudo_tqdm: | |||||
def iob2(tags: List[str]) -> List[str]: | def iob2(tags: List[str]) -> List[str]: | ||||
""" | |||||
r""" | |||||
检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两者的差异见 | 检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两者的差异见 | ||||
https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format | https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format | ||||
@@ -902,7 +902,7 @@ def iob2(tags: List[str]) -> List[str]: | |||||
def iob2bioes(tags: List[str]) -> List[str]: | def iob2bioes(tags: List[str]) -> List[str]: | ||||
""" | |||||
r""" | |||||
将iob的tag转换为bioes编码 | 将iob的tag转换为bioes编码 | ||||
:param tags: List[str]. 编码需要是大写的。 | :param tags: List[str]. 编码需要是大写的。 | ||||
:return: | :return: | ||||
@@ -938,7 +938,7 @@ def _is_iterable(value): | |||||
def get_seq_len(words, pad_value=0): | def get_seq_len(words, pad_value=0): | ||||
""" | |||||
r""" | |||||
给定batch_size x max_len的words矩阵,返回句子长度 | 给定batch_size x max_len的words矩阵,返回句子长度 | ||||
:param words: batch_size x max_len | :param words: batch_size x max_len | ||||
@@ -949,7 +949,7 @@ def get_seq_len(words, pad_value=0): | |||||
def pretty_table_printer(dataset_or_ins) -> PrettyTable: | def pretty_table_printer(dataset_or_ins) -> PrettyTable: | ||||
""" | |||||
r""" | |||||
:param dataset_or_ins: 传入一个dataSet或者instance | :param dataset_or_ins: 传入一个dataSet或者instance | ||||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) | ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) | ||||
+-----------+-----------+-----------------+ | +-----------+-----------+-----------------+ | ||||
@@ -990,7 +990,7 @@ def pretty_table_printer(dataset_or_ins) -> PrettyTable: | |||||
def sub_column(string: str, c: int, c_size: int, title: str) -> str: | def sub_column(string: str, c: int, c_size: int, title: str) -> str: | ||||
""" | |||||
r""" | |||||
:param string: 要被截断的字符串 | :param string: 要被截断的字符串 | ||||
:param c: 命令行列数 | :param c: 命令行列数 | ||||
:param c_size: instance或dataset field数 | :param c_size: instance或dataset field数 | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -33,7 +33,7 @@ class VocabularyOption(Option): | |||||
def _check_build_vocab(func): | def _check_build_vocab(func): | ||||
"""A decorator to make sure the indexing is built before used. | |||||
r"""A decorator to make sure the indexing is built before used. | |||||
""" | """ | ||||
@@ -47,7 +47,7 @@ def _check_build_vocab(func): | |||||
def _check_build_status(func): | def _check_build_status(func): | ||||
"""A decorator to check whether the vocabulary updates after the last build. | |||||
r"""A decorator to check whether the vocabulary updates after the last build. | |||||
""" | """ | ||||
@@ -65,7 +65,7 @@ def _check_build_status(func): | |||||
class Vocabulary(object): | class Vocabulary(object): | ||||
""" | |||||
r""" | |||||
用于构建, 存储和使用 `str` 到 `int` 的一一映射:: | 用于构建, 存储和使用 `str` 到 `int` 的一一映射:: | ||||
vocab = Vocabulary() | vocab = Vocabulary() | ||||
@@ -76,7 +76,7 @@ class Vocabulary(object): | |||||
""" | """ | ||||
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | ||||
""" | |||||
r""" | |||||
:param int max_size: `Vocabulary` 的最大大小, 即能存储词的最大数量 | :param int max_size: `Vocabulary` 的最大大小, 即能存储词的最大数量 | ||||
若为 ``None`` , 则不限制大小. Default: ``None`` | 若为 ``None`` , 则不限制大小. Default: ``None`` | ||||
@@ -121,7 +121,7 @@ class Vocabulary(object): | |||||
@_check_build_status | @_check_build_status | ||||
def update(self, word_lst, no_create_entry=False): | def update(self, word_lst, no_create_entry=False): | ||||
"""依次增加序列中词在词典中的出现频率 | |||||
r"""依次增加序列中词在词典中的出现频率 | |||||
:param list word_lst: a list of strings | :param list word_lst: a list of strings | ||||
:param bool no_create_entry: 在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | :param bool no_create_entry: 在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | ||||
@@ -137,7 +137,7 @@ class Vocabulary(object): | |||||
@_check_build_status | @_check_build_status | ||||
def add(self, word, no_create_entry=False): | def add(self, word, no_create_entry=False): | ||||
""" | |||||
r""" | |||||
增加一个新词在词典中的出现频率 | 增加一个新词在词典中的出现频率 | ||||
:param str word: 新词 | :param str word: 新词 | ||||
@@ -153,7 +153,7 @@ class Vocabulary(object): | |||||
return self | return self | ||||
def _add_no_create_entry(self, word, no_create_entry): | def _add_no_create_entry(self, word, no_create_entry): | ||||
""" | |||||
r""" | |||||
在新加入word时,检查_no_create_word的设置。 | 在新加入word时,检查_no_create_word的设置。 | ||||
:param str List[str] word: | :param str List[str] word: | ||||
@@ -170,7 +170,7 @@ class Vocabulary(object): | |||||
@_check_build_status | @_check_build_status | ||||
def add_word(self, word, no_create_entry=False): | def add_word(self, word, no_create_entry=False): | ||||
""" | |||||
r""" | |||||
增加一个新词在词典中的出现频率 | 增加一个新词在词典中的出现频率 | ||||
:param str word: 新词 | :param str word: 新词 | ||||
@@ -185,7 +185,7 @@ class Vocabulary(object): | |||||
@_check_build_status | @_check_build_status | ||||
def add_word_lst(self, word_lst, no_create_entry=False): | def add_word_lst(self, word_lst, no_create_entry=False): | ||||
""" | |||||
r""" | |||||
依次增加序列中词在词典中的出现频率 | 依次增加序列中词在词典中的出现频率 | ||||
:param list[str] word_lst: 词的序列 | :param list[str] word_lst: 词的序列 | ||||
@@ -200,7 +200,7 @@ class Vocabulary(object): | |||||
return self | return self | ||||
def build_vocab(self): | def build_vocab(self): | ||||
""" | |||||
r""" | |||||
根据已经出现的词和出现频率构建词典. 注意: 重复构建可能会改变词典的大小, | 根据已经出现的词和出现频率构建词典. 注意: 重复构建可能会改变词典的大小, | ||||
但已经记录在词典中的词, 不会改变对应的 `int` | 但已经记录在词典中的词, 不会改变对应的 `int` | ||||
@@ -225,7 +225,7 @@ class Vocabulary(object): | |||||
return self | return self | ||||
def build_reverse_vocab(self): | def build_reverse_vocab(self): | ||||
""" | |||||
r""" | |||||
基于 `word to index` dict, 构建 `index to word` dict. | 基于 `word to index` dict, 构建 `index to word` dict. | ||||
""" | """ | ||||
@@ -238,7 +238,7 @@ class Vocabulary(object): | |||||
@_check_build_vocab | @_check_build_vocab | ||||
def __contains__(self, item): | def __contains__(self, item): | ||||
""" | |||||
r""" | |||||
检查词是否被记录 | 检查词是否被记录 | ||||
:param item: the word | :param item: the word | ||||
@@ -247,7 +247,7 @@ class Vocabulary(object): | |||||
return item in self._word2idx | return item in self._word2idx | ||||
def has_word(self, w): | def has_word(self, w): | ||||
""" | |||||
r""" | |||||
检查词是否被记录:: | 检查词是否被记录:: | ||||
has_abc = vocab.has_word('abc') | has_abc = vocab.has_word('abc') | ||||
@@ -261,7 +261,7 @@ class Vocabulary(object): | |||||
@_check_build_vocab | @_check_build_vocab | ||||
def __getitem__(self, w): | def __getitem__(self, w): | ||||
""" | |||||
r""" | |||||
To support usage like:: | To support usage like:: | ||||
vocab[w] | vocab[w] | ||||
@@ -275,7 +275,7 @@ class Vocabulary(object): | |||||
@_check_build_vocab | @_check_build_vocab | ||||
def index_dataset(self, *datasets, field_name, new_field_name=None): | def index_dataset(self, *datasets, field_name, new_field_name=None): | ||||
""" | |||||
r""" | |||||
将DataSet中对应field的词转为数字,Example:: | 将DataSet中对应field的词转为数字,Example:: | ||||
# remember to use `field_name` | # remember to use `field_name` | ||||
@@ -289,7 +289,7 @@ class Vocabulary(object): | |||||
""" | """ | ||||
def index_instance(field): | def index_instance(field): | ||||
""" | |||||
r""" | |||||
有几种情况, str, 1d-list, 2d-list | 有几种情况, str, 1d-list, 2d-list | ||||
:param ins: | :param ins: | ||||
:return: | :return: | ||||
@@ -333,7 +333,7 @@ class Vocabulary(object): | |||||
return len(self._no_create_word) | return len(self._no_create_word) | ||||
def from_dataset(self, *datasets, field_name, no_create_entry_dataset=None): | def from_dataset(self, *datasets, field_name, no_create_entry_dataset=None): | ||||
""" | |||||
r""" | |||||
使用dataset的对应field中词构建词典:: | 使用dataset的对应field中词构建词典:: | ||||
# remember to use `field_name` | # remember to use `field_name` | ||||
@@ -395,7 +395,7 @@ class Vocabulary(object): | |||||
return self | return self | ||||
def _is_word_no_create_entry(self, word): | def _is_word_no_create_entry(self, word): | ||||
""" | |||||
r""" | |||||
判断当前的word是否是不需要创建entry的,具体参见from_dataset的说明 | 判断当前的word是否是不需要创建entry的,具体参见from_dataset的说明 | ||||
:param word: str | :param word: str | ||||
:return: bool | :return: bool | ||||
@@ -403,7 +403,7 @@ class Vocabulary(object): | |||||
return word in self._no_create_word | return word in self._no_create_word | ||||
def to_index(self, w): | def to_index(self, w): | ||||
""" | |||||
r""" | |||||
将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 ``ValueError`` :: | 将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 ``ValueError`` :: | ||||
index = vocab.to_index('abc') | index = vocab.to_index('abc') | ||||
@@ -418,7 +418,7 @@ class Vocabulary(object): | |||||
@property | @property | ||||
@_check_build_vocab | @_check_build_vocab | ||||
def unknown_idx(self): | def unknown_idx(self): | ||||
""" | |||||
r""" | |||||
unknown 对应的数字. | unknown 对应的数字. | ||||
""" | """ | ||||
if self.unknown is None: | if self.unknown is None: | ||||
@@ -428,7 +428,7 @@ class Vocabulary(object): | |||||
@property | @property | ||||
@_check_build_vocab | @_check_build_vocab | ||||
def padding_idx(self): | def padding_idx(self): | ||||
""" | |||||
r""" | |||||
padding 对应的数字 | padding 对应的数字 | ||||
""" | """ | ||||
if self.padding is None: | if self.padding is None: | ||||
@@ -437,7 +437,7 @@ class Vocabulary(object): | |||||
@_check_build_vocab | @_check_build_vocab | ||||
def to_word(self, idx): | def to_word(self, idx): | ||||
""" | |||||
r""" | |||||
给定一个数字, 将其转为对应的词. | 给定一个数字, 将其转为对应的词. | ||||
:param int idx: the index | :param int idx: the index | ||||
@@ -446,7 +446,7 @@ class Vocabulary(object): | |||||
return self._idx2word[idx] | return self._idx2word[idx] | ||||
def clear(self): | def clear(self): | ||||
""" | |||||
r""" | |||||
删除Vocabulary中的词表数据。相当于重新初始化一下。 | 删除Vocabulary中的词表数据。相当于重新初始化一下。 | ||||
:return: | :return: | ||||
@@ -459,7 +459,7 @@ class Vocabulary(object): | |||||
return self | return self | ||||
def __getstate__(self): | def __getstate__(self): | ||||
"""Use to prepare data for pickle. | |||||
r"""Use to prepare data for pickle. | |||||
""" | """ | ||||
len(self) # make sure vocab has been built | len(self) # make sure vocab has been built | ||||
@@ -469,7 +469,7 @@ class Vocabulary(object): | |||||
return state | return state | ||||
def __setstate__(self, state): | def __setstate__(self, state): | ||||
"""Use to restore state from pickle. | |||||
r"""Use to restore state from pickle. | |||||
""" | """ | ||||
self.__dict__.update(state) | self.__dict__.update(state) | ||||
@@ -484,7 +484,7 @@ class Vocabulary(object): | |||||
yield word, index | yield word, index | ||||
def save(self, filepath): | def save(self, filepath): | ||||
""" | |||||
r""" | |||||
:param str filepath: Vocabulary的储存路径 | :param str filepath: Vocabulary的储存路径 | ||||
:return: | :return: | ||||
@@ -508,7 +508,7 @@ class Vocabulary(object): | |||||
@staticmethod | @staticmethod | ||||
def load(filepath): | def load(filepath): | ||||
""" | |||||
r""" | |||||
:param str filepath: Vocabulary的读取路径 | :param str filepath: Vocabulary的读取路径 | ||||
:return: Vocabulary | :return: Vocabulary | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented | |||||
r"""undocumented | |||||
用于辅助生成 fastNLP 文档的代码 | 用于辅助生成 fastNLP 文档的代码 | ||||
""" | """ | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
embeddings 模块主要用于从各种预训练的模型中获取词语的分布式表示,目前支持的预训练模型包括word2vec, glove, ELMO, BERT等。这里所有 | embeddings 模块主要用于从各种预训练的模型中获取词语的分布式表示,目前支持的预训练模型包括word2vec, glove, ELMO, BERT等。这里所有 | ||||
embedding的forward输入都是形状为 ``(batch_size, max_len)`` 的torch.LongTensor,输出都是 ``(batch_size, max_len, embedding_dim)`` 的 | embedding的forward输入都是形状为 ``(batch_size, max_len)`` 的torch.LongTensor,输出都是 ``(batch_size, max_len, embedding_dim)`` 的 | ||||
torch.FloatTensor。所有的embedding都可以使用 `self.num_embedding` 获取最大的输入index范围, 用 `self.embeddig_dim` 或 `self.embed_size` 获取embedding的 | torch.FloatTensor。所有的embedding都可以使用 `self.num_embedding` 获取最大的输入index范围, 用 `self.embeddig_dim` 或 `self.embed_size` 获取embedding的 | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -24,7 +24,7 @@ from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer | |||||
class BertEmbedding(ContextualEmbedding): | class BertEmbedding(ContextualEmbedding): | ||||
""" | |||||
r""" | |||||
使用BERT对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于 | 使用BERT对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于 | ||||
预训练的bert模型长度限制为512个token,而因为输入的word是未进行word piece分割的(word piece的分割有BertEmbedding在输入word | 预训练的bert模型长度限制为512个token,而因为输入的word是未进行word piece分割的(word piece的分割有BertEmbedding在输入word | ||||
时切分),在分割之后长度可能会超过最大长度限制。 | 时切分),在分割之后长度可能会超过最大长度限制。 | ||||
@@ -57,7 +57,7 @@ class BertEmbedding(ContextualEmbedding): | |||||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', | def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', | ||||
pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False, | pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False, | ||||
pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False, **kwargs): | pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False, **kwargs): | ||||
""" | |||||
r""" | |||||
:param ~fastNLP.Vocabulary vocab: 词表 | :param ~fastNLP.Vocabulary vocab: 词表 | ||||
:param str model_dir_or_name: 模型所在目录或者模型的名称。当传入模型所在目录时,目录中应该包含一个词表文件(以.txt作为后缀名), | :param str model_dir_or_name: 模型所在目录或者模型的名称。当传入模型所在目录时,目录中应该包含一个词表文件(以.txt作为后缀名), | ||||
@@ -112,7 +112,7 @@ class BertEmbedding(ContextualEmbedding): | |||||
del self.model | del self.model | ||||
def forward(self, words): | def forward(self, words): | ||||
""" | |||||
r""" | |||||
计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要 | 计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要 | ||||
删除这两个token的表示。 | 删除这两个token的表示。 | ||||
@@ -129,7 +129,7 @@ class BertEmbedding(ContextualEmbedding): | |||||
return self.dropout(outputs) | return self.dropout(outputs) | ||||
def drop_word(self, words): | def drop_word(self, words): | ||||
""" | |||||
r""" | |||||
按照设定随机将words设置为unknown_index。 | 按照设定随机将words设置为unknown_index。 | ||||
:param torch.LongTensor words: batch_size x max_len | :param torch.LongTensor words: batch_size x max_len | ||||
@@ -151,7 +151,7 @@ class BertEmbedding(ContextualEmbedding): | |||||
class BertWordPieceEncoder(nn.Module): | class BertWordPieceEncoder(nn.Module): | ||||
""" | |||||
r""" | |||||
读取bert模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。 | 读取bert模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。 | ||||
BertWordPieceEncoder可以支持自动下载权重,当前支持的模型: | BertWordPieceEncoder可以支持自动下载权重,当前支持的模型: | ||||
@@ -170,7 +170,7 @@ class BertWordPieceEncoder(nn.Module): | |||||
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, | def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, | ||||
word_dropout=0, dropout=0, requires_grad: bool = True): | word_dropout=0, dropout=0, requires_grad: bool = True): | ||||
""" | |||||
r""" | |||||
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` | :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` | ||||
:param str layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | :param str layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | ||||
@@ -204,7 +204,7 @@ class BertWordPieceEncoder(nn.Module): | |||||
return self.model.encoder.config.vocab_size | return self.model.encoder.config.vocab_size | ||||
def index_datasets(self, *datasets, field_name, add_cls_sep=True): | def index_datasets(self, *datasets, field_name, add_cls_sep=True): | ||||
""" | |||||
r""" | |||||
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了 | 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了 | ||||
bert的pad value。 | bert的pad value。 | ||||
@@ -216,7 +216,7 @@ class BertWordPieceEncoder(nn.Module): | |||||
self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep) | self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep) | ||||
def forward(self, word_pieces, token_type_ids=None): | def forward(self, word_pieces, token_type_ids=None): | ||||
""" | |||||
r""" | |||||
计算words的bert embedding表示。传入的words中应该自行包含[CLS]与[SEP]的tag。 | 计算words的bert embedding表示。传入的words中应该自行包含[CLS]与[SEP]的tag。 | ||||
:param words: batch_size x max_len | :param words: batch_size x max_len | ||||
@@ -239,7 +239,7 @@ class BertWordPieceEncoder(nn.Module): | |||||
return self.dropout_layer(outputs) | return self.dropout_layer(outputs) | ||||
def drop_word(self, words): | def drop_word(self, words): | ||||
""" | |||||
r""" | |||||
按照设定随机将words设置为unknown_index。 | 按照设定随机将words设置为unknown_index。 | ||||
:param torch.LongTensor words: batch_size x max_len | :param torch.LongTensor words: batch_size x max_len | ||||
@@ -353,7 +353,7 @@ class _WordBertModel(nn.Module): | |||||
logger.debug("Successfully generate word pieces.") | logger.debug("Successfully generate word pieces.") | ||||
def forward(self, words): | def forward(self, words): | ||||
""" | |||||
r""" | |||||
:param words: torch.LongTensor, batch_size x max_len | :param words: torch.LongTensor, batch_size x max_len | ||||
:return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size | :return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
该文件中主要包含的是character的Embedding,包括基于CNN与LSTM的character Embedding。与其它Embedding一样,这里的Embedding输入也是 | 该文件中主要包含的是character的Embedding,包括基于CNN与LSTM的character Embedding。与其它Embedding一样,这里的Embedding输入也是 | ||||
词的index而不需要使用词语中的char的index来获取表达。 | 词的index而不需要使用词语中的char的index来获取表达。 | ||||
""" | """ | ||||
@@ -24,7 +24,7 @@ from ..modules.encoder.lstm import LSTM | |||||
class CNNCharEmbedding(TokenEmbedding): | class CNNCharEmbedding(TokenEmbedding): | ||||
""" | |||||
r""" | |||||
使用CNN生成character embedding。CNN的结构为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout. | 使用CNN生成character embedding。CNN的结构为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout. | ||||
不同的kernel大小的fitler结果是concat起来然后通过一层fully connected layer, 然后输出word的表示。 | 不同的kernel大小的fitler结果是concat起来然后通过一层fully connected layer, 然后输出word的表示。 | ||||
@@ -46,7 +46,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||||
dropout: float = 0, filter_nums: List[int] = (40, 30, 20), kernel_sizes: List[int] = (5, 3, 1), | dropout: float = 0, filter_nums: List[int] = (40, 30, 20), kernel_sizes: List[int] = (5, 3, 1), | ||||
pool_method: str = 'max', activation='relu', min_char_freq: int = 2, pre_train_char_embed: str = None, | pool_method: str = 'max', activation='relu', min_char_freq: int = 2, pre_train_char_embed: str = None, | ||||
requires_grad:bool=True, include_word_start_end:bool=True): | requires_grad:bool=True, include_word_start_end:bool=True): | ||||
""" | |||||
r""" | |||||
:param vocab: 词表 | :param vocab: 词表 | ||||
:param embed_size: 该CNNCharEmbedding的输出维度大小,默认值为50. | :param embed_size: 该CNNCharEmbedding的输出维度大小,默认值为50. | ||||
@@ -122,7 +122,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||||
self.requires_grad = requires_grad | self.requires_grad = requires_grad | ||||
def forward(self, words): | def forward(self, words): | ||||
""" | |||||
r""" | |||||
输入words的index后,生成对应的words的表示。 | 输入words的index后,生成对应的words的表示。 | ||||
:param words: [batch_size, max_len] | :param words: [batch_size, max_len] | ||||
@@ -155,7 +155,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||||
class LSTMCharEmbedding(TokenEmbedding): | class LSTMCharEmbedding(TokenEmbedding): | ||||
""" | |||||
r""" | |||||
使用LSTM的方式对character进行encode. embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool -> Dropout | 使用LSTM的方式对character进行encode. embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool -> Dropout | ||||
Example:: | Example:: | ||||
@@ -176,7 +176,7 @@ class LSTMCharEmbedding(TokenEmbedding): | |||||
dropout: float = 0, hidden_size=50, pool_method: str = 'max', activation='relu', | dropout: float = 0, hidden_size=50, pool_method: str = 'max', activation='relu', | ||||
min_char_freq: int = 2, bidirectional=True, pre_train_char_embed: str = None, | min_char_freq: int = 2, bidirectional=True, pre_train_char_embed: str = None, | ||||
requires_grad:bool=True, include_word_start_end:bool=True): | requires_grad:bool=True, include_word_start_end:bool=True): | ||||
""" | |||||
r""" | |||||
:param vocab: 词表 | :param vocab: 词表 | ||||
:param embed_size: LSTMCharEmbedding的输出维度。默认值为50. | :param embed_size: LSTMCharEmbedding的输出维度。默认值为50. | ||||
@@ -250,7 +250,7 @@ class LSTMCharEmbedding(TokenEmbedding): | |||||
self.requires_grad = requires_grad | self.requires_grad = requires_grad | ||||
def forward(self, words): | def forward(self, words): | ||||
""" | |||||
r""" | |||||
输入words的index后,生成对应的words的表示。 | 输入words的index后,生成对应的words的表示。 | ||||
:param words: [batch_size, max_len] | :param words: [batch_size, max_len] | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -28,7 +28,7 @@ class ContextualEmbedding(TokenEmbedding): | |||||
super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | ||||
def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool = True): | def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool = True): | ||||
""" | |||||
r""" | |||||
由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。 | 由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。 | ||||
:param datasets: DataSet对象 | :param datasets: DataSet对象 | ||||
@@ -77,7 +77,7 @@ class ContextualEmbedding(TokenEmbedding): | |||||
self._delete_model_weights() | self._delete_model_weights() | ||||
def _get_sent_reprs(self, words): | def _get_sent_reprs(self, words): | ||||
""" | |||||
r""" | |||||
获取sentence的表示,如果有缓存,则返回缓存的值; 没有缓存则返回None | 获取sentence的表示,如果有缓存,则返回缓存的值; 没有缓存则返回None | ||||
:param words: torch.LongTensor | :param words: torch.LongTensor | ||||
@@ -101,11 +101,11 @@ class ContextualEmbedding(TokenEmbedding): | |||||
@abstractmethod | @abstractmethod | ||||
def _delete_model_weights(self): | def _delete_model_weights(self): | ||||
"""删除计算表示的模型以节省资源""" | |||||
r"""删除计算表示的模型以节省资源""" | |||||
raise NotImplementedError | raise NotImplementedError | ||||
def remove_sentence_cache(self): | def remove_sentence_cache(self): | ||||
""" | |||||
r""" | |||||
删除缓存的句子表示. 删除之后如果模型权重没有被删除,将开始使用动态计算权重。 | 删除缓存的句子表示. 删除之后如果模型权重没有被删除,将开始使用动态计算权重。 | ||||
:return: | :return: | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -23,7 +23,7 @@ from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder | |||||
class ElmoEmbedding(ContextualEmbedding): | class ElmoEmbedding(ContextualEmbedding): | ||||
""" | |||||
r""" | |||||
使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。 | 使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。 | ||||
当前支持的使用名称初始化的模型: | 当前支持的使用名称初始化的模型: | ||||
@@ -56,7 +56,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str = '2', requires_grad: bool = True, | def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str = '2', requires_grad: bool = True, | ||||
word_dropout=0.0, dropout=0.0, cache_word_reprs: bool = False): | word_dropout=0.0, dropout=0.0, cache_word_reprs: bool = False): | ||||
""" | |||||
r""" | |||||
:param vocab: 词表 | :param vocab: 词表 | ||||
:param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo所在文件夹,该文件夹下面应该有两个文件, | :param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo所在文件夹,该文件夹下面应该有两个文件, | ||||
@@ -110,7 +110,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
return self.gamma.to(outputs) * outputs | return self.gamma.to(outputs) * outputs | ||||
def set_mix_weights_requires_grad(self, flag=True): | def set_mix_weights_requires_grad(self, flag=True): | ||||
""" | |||||
r""" | |||||
当初始化ElmoEmbedding时layers被设置为mix时,可以通过调用该方法设置mix weights是否可训练。如果layers不是mix,调用 | 当初始化ElmoEmbedding时layers被设置为mix时,可以通过调用该方法设置mix weights是否可训练。如果layers不是mix,调用 | ||||
该方法没有用。 | 该方法没有用。 | ||||
@@ -130,7 +130,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
return outputs | return outputs | ||||
def forward(self, words: torch.LongTensor): | def forward(self, words: torch.LongTensor): | ||||
""" | |||||
r""" | |||||
计算words的elmo embedding表示。根据elmo文章中介绍的ELMO实际上是有2L+1层结果,但是为了让结果比较容易拆分,token的 | 计算words的elmo embedding表示。根据elmo文章中介绍的ELMO实际上是有2L+1层结果,但是为了让结果比较容易拆分,token的 | ||||
被重复了一次,使得实际上layer=0的结果是[token_embedding;token_embedding], 而layer=1的结果是[forward_hiddens; | 被重复了一次,使得实际上layer=0的结果是[token_embedding;token_embedding], 而layer=1的结果是[forward_hiddens; | ||||
backward_hiddens]. | backward_hiddens]. | ||||
@@ -153,7 +153,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
class _ElmoModel(nn.Module): | class _ElmoModel(nn.Module): | ||||
""" | |||||
r""" | |||||
该Module是ElmoEmbedding中进行所有的heavy lifting的地方。做的工作,包括 | 该Module是ElmoEmbedding中进行所有的heavy lifting的地方。做的工作,包括 | ||||
(1) 根据配置,加载模型; | (1) 根据配置,加载模型; | ||||
(2) 根据vocab,对模型中的embedding进行调整. 并将其正确初始化 | (2) 根据vocab,对模型中的embedding进行调整. 并将其正确初始化 | ||||
@@ -295,7 +295,7 @@ class _ElmoModel(nn.Module): | |||||
logger.info("There is no need to cache word representations, since no character information is used.") | logger.info("There is no need to cache word representations, since no character information is used.") | ||||
def forward(self, words): | def forward(self, words): | ||||
""" | |||||
r""" | |||||
:param words: batch_size x max_len | :param words: batch_size x max_len | ||||
:return: num_layers x batch_size x max_len x hidden_size | :return: num_layers x batch_size x max_len x hidden_size | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
该模块中的Embedding主要用于随机初始化的embedding(更推荐使用 :class:`fastNLP.embeddings.StaticEmbedding` ),或按照预训练权重初始化Embedding。 | 该模块中的Embedding主要用于随机初始化的embedding(更推荐使用 :class:`fastNLP.embeddings.StaticEmbedding` ),或按照预训练权重初始化Embedding。 | ||||
""" | """ | ||||
@@ -17,7 +17,7 @@ from .utils import get_embeddings | |||||
class Embedding(nn.Module): | class Embedding(nn.Module): | ||||
""" | |||||
r""" | |||||
词向量嵌入,支持输入多种方式初始化. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度. | 词向量嵌入,支持输入多种方式初始化. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度. | ||||
Example:: | Example:: | ||||
@@ -32,7 +32,7 @@ class Embedding(nn.Module): | |||||
""" | """ | ||||
def __init__(self, init_embed, word_dropout=0, dropout=0.0, unk_index=None): | def __init__(self, init_embed, word_dropout=0, dropout=0.0, unk_index=None): | ||||
""" | |||||
r""" | |||||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: 支持传入Embedding的大小(传入tuple(int, int), | :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: 支持传入Embedding的大小(传入tuple(int, int), | ||||
第一个int为vocab_zie, 第二个int为embed_dim); 或传入Tensor, Embedding, numpy.ndarray等则直接使用该值初始化Embedding; | 第一个int为vocab_zie, 第二个int为embed_dim); 或传入Tensor, Embedding, numpy.ndarray等则直接使用该值初始化Embedding; | ||||
@@ -62,7 +62,7 @@ class Embedding(nn.Module): | |||||
self.word_dropout = word_dropout | self.word_dropout = word_dropout | ||||
def forward(self, words): | def forward(self, words): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch, seq_len] | :param torch.LongTensor words: [batch, seq_len] | ||||
:return: torch.Tensor : [batch, seq_len, embed_dim] | :return: torch.Tensor : [batch, seq_len, embed_dim] | ||||
""" | """ | ||||
@@ -93,7 +93,7 @@ class Embedding(nn.Module): | |||||
@property | @property | ||||
def requires_grad(self): | def requires_grad(self): | ||||
""" | |||||
r""" | |||||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | ||||
:return: | :return: | ||||
""" | """ | ||||
@@ -118,7 +118,7 @@ class Embedding(nn.Module): | |||||
class TokenEmbedding(nn.Module): | class TokenEmbedding(nn.Module): | ||||
""" | |||||
r""" | |||||
fastNLP中各种Embedding的基类 | fastNLP中各种Embedding的基类 | ||||
""" | """ | ||||
@@ -136,7 +136,7 @@ class TokenEmbedding(nn.Module): | |||||
self.dropout_layer = nn.Dropout(dropout) | self.dropout_layer = nn.Dropout(dropout) | ||||
def drop_word(self, words): | def drop_word(self, words): | ||||
""" | |||||
r""" | |||||
按照设定随机将words设置为unknown_index。 | 按照设定随机将words设置为unknown_index。 | ||||
:param torch.LongTensor words: batch_size x max_len | :param torch.LongTensor words: batch_size x max_len | ||||
@@ -151,7 +151,7 @@ class TokenEmbedding(nn.Module): | |||||
return words | return words | ||||
def dropout(self, words): | def dropout(self, words): | ||||
""" | |||||
r""" | |||||
对embedding后的word表示进行drop。 | 对embedding后的word表示进行drop。 | ||||
:param torch.FloatTensor words: batch_size x max_len x embed_size | :param torch.FloatTensor words: batch_size x max_len x embed_size | ||||
@@ -161,7 +161,7 @@ class TokenEmbedding(nn.Module): | |||||
@property | @property | ||||
def requires_grad(self): | def requires_grad(self): | ||||
""" | |||||
r""" | |||||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | ||||
:return: | :return: | ||||
""" | """ | ||||
@@ -189,14 +189,14 @@ class TokenEmbedding(nn.Module): | |||||
@property | @property | ||||
def num_embedding(self) -> int: | def num_embedding(self) -> int: | ||||
""" | |||||
r""" | |||||
这个值可能会大于实际的embedding矩阵的大小。 | 这个值可能会大于实际的embedding矩阵的大小。 | ||||
:return: | :return: | ||||
""" | """ | ||||
return len(self._word_vocab) | return len(self._word_vocab) | ||||
def get_word_vocab(self): | def get_word_vocab(self): | ||||
""" | |||||
r""" | |||||
返回embedding的词典。 | 返回embedding的词典。 | ||||
:return: Vocabulary | :return: Vocabulary | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -16,7 +16,7 @@ from .embedding import TokenEmbedding | |||||
class StackEmbedding(TokenEmbedding): | class StackEmbedding(TokenEmbedding): | ||||
""" | |||||
r""" | |||||
支持将多个embedding集合成一个embedding。 | 支持将多个embedding集合成一个embedding。 | ||||
Example:: | Example:: | ||||
@@ -31,7 +31,7 @@ class StackEmbedding(TokenEmbedding): | |||||
""" | """ | ||||
def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0): | def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0): | ||||
""" | |||||
r""" | |||||
:param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致 | :param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致 | ||||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。不同embedidng会在相同的位置 | :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。不同embedidng会在相同的位置 | ||||
@@ -54,7 +54,7 @@ class StackEmbedding(TokenEmbedding): | |||||
self._embed_size = sum([embed.embed_size for embed in self.embeds]) | self._embed_size = sum([embed.embed_size for embed in self.embeds]) | ||||
def append(self, embed: TokenEmbedding): | def append(self, embed: TokenEmbedding): | ||||
""" | |||||
r""" | |||||
添加一个embedding到结尾。 | 添加一个embedding到结尾。 | ||||
:param embed: | :param embed: | ||||
:return: | :return: | ||||
@@ -65,7 +65,7 @@ class StackEmbedding(TokenEmbedding): | |||||
return self | return self | ||||
def pop(self): | def pop(self): | ||||
""" | |||||
r""" | |||||
弹出最后一个embed | 弹出最后一个embed | ||||
:return: | :return: | ||||
""" | """ | ||||
@@ -75,14 +75,14 @@ class StackEmbedding(TokenEmbedding): | |||||
@property | @property | ||||
def embed_size(self): | def embed_size(self): | ||||
""" | |||||
r""" | |||||
该Embedding输出的vector的最后一维的维度。 | 该Embedding输出的vector的最后一维的维度。 | ||||
:return: | :return: | ||||
""" | """ | ||||
return self._embed_size | return self._embed_size | ||||
def forward(self, words): | def forward(self, words): | ||||
""" | |||||
r""" | |||||
得到多个embedding的结果,并把结果按照顺序concat起来。 | 得到多个embedding的结果,并把结果按照顺序concat起来。 | ||||
:param words: batch_size x max_len | :param words: batch_size x max_len | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -23,7 +23,7 @@ from ..modules.utils import _get_file_name_base_on_postfix | |||||
class StaticEmbedding(TokenEmbedding): | class StaticEmbedding(TokenEmbedding): | ||||
""" | |||||
r""" | |||||
StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的数据(只会将出现在vocab中的词抽取出来, | StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的数据(只会将出现在vocab中的词抽取出来, | ||||
如果没有找到,则会随机初始化一个值(但如果该word是被标记为no_create_entry的话,则不会单独创建一个值,而是会被指向unk的index))。 | 如果没有找到,则会随机初始化一个值(但如果该word是被标记为no_create_entry的话,则不会单独创建一个值,而是会被指向unk的index))。 | ||||
当前支持自动下载的预训练vector有: | 当前支持自动下载的预训练vector有: | ||||
@@ -72,7 +72,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True, | def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True, | ||||
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs): | init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs): | ||||
""" | |||||
r""" | |||||
:param vocab: Vocabulary. 若该项为None则会读取所有的embedding。 | :param vocab: Vocabulary. 若该项为None则会读取所有的embedding。 | ||||
:param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个 | :param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个 | ||||
@@ -204,7 +204,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
return self.embedding.weight | return self.embedding.weight | ||||
def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None): | def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None): | ||||
""" | |||||
r""" | |||||
:param int num_embedding: embedding的entry的数量 | :param int num_embedding: embedding的entry的数量 | ||||
:param int embedding_dim: embedding的维度大小 | :param int embedding_dim: embedding的维度大小 | ||||
@@ -222,7 +222,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', | def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', | ||||
error='ignore', init_method=None): | error='ignore', init_method=None): | ||||
""" | |||||
r""" | |||||
从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 | 从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 | ||||
word2vec(第一行只有两个元素)还是glove格式的数据。 | word2vec(第一行只有两个元素)还是glove格式的数据。 | ||||
@@ -309,7 +309,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
return vectors | return vectors | ||||
def forward(self, words): | def forward(self, words): | ||||
""" | |||||
r""" | |||||
传入words的index | 传入words的index | ||||
:param words: torch.LongTensor, [batch_size, max_len] | :param words: torch.LongTensor, [batch_size, max_len] | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -14,7 +14,7 @@ __all__ = [ | |||||
def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, include_word_start_end=True): | def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, include_word_start_end=True): | ||||
""" | |||||
r""" | |||||
给定一个word的vocabulary生成character的vocabulary. | 给定一个word的vocabulary生成character的vocabulary. | ||||
:param vocab: 从vocab | :param vocab: 从vocab | ||||
@@ -32,7 +32,7 @@ def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, inclu | |||||
def get_embeddings(init_embed): | def get_embeddings(init_embed): | ||||
""" | |||||
r""" | |||||
根据输入的init_embed返回Embedding对象。如果输入是tuple, 则随机初始化一个nn.Embedding; 如果输入是numpy.ndarray, 则按照ndarray | 根据输入的init_embed返回Embedding对象。如果输入是tuple, 则随机初始化一个nn.Embedding; 如果输入是numpy.ndarray, 则按照ndarray | ||||
的值将nn.Embedding初始化; 如果输入是torch.Tensor, 则按该值初始化nn.Embedding; 如果输入是fastNLP中的embedding将不做处理 | 的值将nn.Embedding初始化; 如果输入是torch.Tensor, 则按该值初始化nn.Embedding; 如果输入是fastNLP中的embedding将不做处理 | ||||
返回原对象。 | 返回原对象。 | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
用于IO的模块, 具体包括: | 用于IO的模块, 具体包括: | ||||
1. 用于读入 embedding 的 :mod:`EmbedLoader <fastNLP.io.embed_loader>` 类, | 1. 用于读入 embedding 的 :mod:`EmbedLoader <fastNLP.io.embed_loader>` 类, | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -14,7 +14,7 @@ from ..core._logger import logger | |||||
class DataBundle: | class DataBundle: | ||||
""" | |||||
r""" | |||||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种 | 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种 | ||||
Loader的load函数生成,可以通过以下的方法获取里面的内容 | Loader的load函数生成,可以通过以下的方法获取里面的内容 | ||||
@@ -28,7 +28,7 @@ class DataBundle: | |||||
""" | """ | ||||
def __init__(self, vocabs: dict = None, datasets: dict = None): | def __init__(self, vocabs: dict = None, datasets: dict = None): | ||||
""" | |||||
r""" | |||||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | ||||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | ||||
@@ -37,7 +37,7 @@ class DataBundle: | |||||
self.datasets = datasets or {} | self.datasets = datasets or {} | ||||
def set_vocab(self, vocab, field_name): | def set_vocab(self, vocab, field_name): | ||||
""" | |||||
r""" | |||||
向DataBunlde中增加vocab | 向DataBunlde中增加vocab | ||||
:param ~fastNLP.Vocabulary vocab: 词表 | :param ~fastNLP.Vocabulary vocab: 词表 | ||||
@@ -49,7 +49,7 @@ class DataBundle: | |||||
return self | return self | ||||
def set_dataset(self, dataset, name: str): | def set_dataset(self, dataset, name: str): | ||||
""" | |||||
r""" | |||||
:param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet | :param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet | ||||
:param str name: dataset的名称 | :param str name: dataset的名称 | ||||
@@ -60,7 +60,7 @@ class DataBundle: | |||||
return self | return self | ||||
def get_dataset(self, name: str) -> DataSet: | def get_dataset(self, name: str) -> DataSet: | ||||
""" | |||||
r""" | |||||
获取名为name的dataset | 获取名为name的dataset | ||||
:param str name: dataset的名称,一般为'train', 'dev', 'test' | :param str name: dataset的名称,一般为'train', 'dev', 'test' | ||||
@@ -75,7 +75,7 @@ class DataBundle: | |||||
raise KeyError(error_msg) | raise KeyError(error_msg) | ||||
def delete_dataset(self, name: str): | def delete_dataset(self, name: str): | ||||
""" | |||||
r""" | |||||
删除名为name的DataSet | 删除名为name的DataSet | ||||
:param str name: | :param str name: | ||||
@@ -85,7 +85,7 @@ class DataBundle: | |||||
return self | return self | ||||
def get_vocab(self, field_name: str) -> Vocabulary: | def get_vocab(self, field_name: str) -> Vocabulary: | ||||
""" | |||||
r""" | |||||
获取field名为field_name对应的vocab | 获取field名为field_name对应的vocab | ||||
:param str field_name: 名称 | :param str field_name: 名称 | ||||
@@ -100,7 +100,7 @@ class DataBundle: | |||||
raise KeyError(error_msg) | raise KeyError(error_msg) | ||||
def delete_vocab(self, field_name: str): | def delete_vocab(self, field_name: str): | ||||
""" | |||||
r""" | |||||
删除vocab | 删除vocab | ||||
:param str field_name: | :param str field_name: | ||||
:return: self | :return: self | ||||
@@ -117,7 +117,7 @@ class DataBundle: | |||||
return len(self.vocabs) | return len(self.vocabs) | ||||
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): | def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): | ||||
""" | |||||
r""" | |||||
将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: | 将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: | ||||
data_bundle.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True | data_bundle.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True | ||||
@@ -142,7 +142,7 @@ class DataBundle: | |||||
return self | return self | ||||
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): | def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): | ||||
""" | |||||
r""" | |||||
将field_names中的field设置为target, 对data_bundle中所有的dataset执行该操作:: | 将field_names中的field设置为target, 对data_bundle中所有的dataset执行该操作:: | ||||
data_bundle.set_target('target', 'seq_len') # 将words和target这两个field的input属性设置为True | data_bundle.set_target('target', 'seq_len') # 将words和target这两个field的input属性设置为True | ||||
@@ -167,7 +167,7 @@ class DataBundle: | |||||
return self | return self | ||||
def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True): | def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True): | ||||
""" | |||||
r""" | |||||
将DataBundle中所有的DataSet中名为field_name的Field的padding值设置为pad_val. | 将DataBundle中所有的DataSet中名为field_name的Field的padding值设置为pad_val. | ||||
:param str field_name: | :param str field_name: | ||||
@@ -184,7 +184,7 @@ class DataBundle: | |||||
return self | return self | ||||
def set_ignore_type(self, *field_names, flag=True, ignore_miss_dataset=True): | def set_ignore_type(self, *field_names, flag=True, ignore_miss_dataset=True): | ||||
""" | |||||
r""" | |||||
将DataBundle中所有的DataSet中名为*field_names的Field的ignore_type设置为flag状态 | 将DataBundle中所有的DataSet中名为*field_names的Field的ignore_type设置为flag状态 | ||||
:param str field_names: | :param str field_names: | ||||
@@ -202,7 +202,7 @@ class DataBundle: | |||||
return self | return self | ||||
def copy_field(self, field_name, new_field_name, ignore_miss_dataset=True): | def copy_field(self, field_name, new_field_name, ignore_miss_dataset=True): | ||||
""" | |||||
r""" | |||||
将DataBundle中所有的DataSet中名为field_name的Field复制一份并命名为叫new_field_name. | 将DataBundle中所有的DataSet中名为field_name的Field复制一份并命名为叫new_field_name. | ||||
:param str field_name: | :param str field_name: | ||||
@@ -219,7 +219,7 @@ class DataBundle: | |||||
return self | return self | ||||
def rename_field(self, field_name, new_field_name, ignore_miss_dataset=True, rename_vocab=True): | def rename_field(self, field_name, new_field_name, ignore_miss_dataset=True, rename_vocab=True): | ||||
""" | |||||
r""" | |||||
将DataBundle中所有DataSet中名为field_name的field重命名为new_field_name. | 将DataBundle中所有DataSet中名为field_name的field重命名为new_field_name. | ||||
:param str field_name: | :param str field_name: | ||||
@@ -241,7 +241,7 @@ class DataBundle: | |||||
return self | return self | ||||
def delete_field(self, field_name, ignore_miss_dataset=True, delete_vocab=True): | def delete_field(self, field_name, ignore_miss_dataset=True, delete_vocab=True): | ||||
""" | |||||
r""" | |||||
将DataBundle中所有DataSet中名为field_name的field删除掉. | 将DataBundle中所有DataSet中名为field_name的field删除掉. | ||||
:param str field_name: | :param str field_name: | ||||
@@ -261,7 +261,7 @@ class DataBundle: | |||||
return self | return self | ||||
def iter_datasets(self) -> Union[str, DataSet]: | def iter_datasets(self) -> Union[str, DataSet]: | ||||
""" | |||||
r""" | |||||
迭代data_bundle中的DataSet | 迭代data_bundle中的DataSet | ||||
Example:: | Example:: | ||||
@@ -275,7 +275,7 @@ class DataBundle: | |||||
yield name, dataset | yield name, dataset | ||||
def get_dataset_names(self) -> List[str]: | def get_dataset_names(self) -> List[str]: | ||||
""" | |||||
r""" | |||||
返回DataBundle中DataSet的名称 | 返回DataBundle中DataSet的名称 | ||||
:return: | :return: | ||||
@@ -283,7 +283,7 @@ class DataBundle: | |||||
return list(self.datasets.keys()) | return list(self.datasets.keys()) | ||||
def get_vocab_names(self)->List[str]: | def get_vocab_names(self)->List[str]: | ||||
""" | |||||
r""" | |||||
返回DataBundle中Vocabulary的名称 | 返回DataBundle中Vocabulary的名称 | ||||
:return: | :return: | ||||
@@ -291,7 +291,7 @@ class DataBundle: | |||||
return list(self.vocabs.keys()) | return list(self.vocabs.keys()) | ||||
def iter_vocabs(self) -> Union[str, Vocabulary]: | def iter_vocabs(self) -> Union[str, Vocabulary]: | ||||
""" | |||||
r""" | |||||
迭代data_bundle中的DataSet | 迭代data_bundle中的DataSet | ||||
Example: | Example: | ||||
@@ -305,7 +305,7 @@ class DataBundle: | |||||
yield field_name, vocab | yield field_name, vocab | ||||
def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs): | def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs): | ||||
""" | |||||
r""" | |||||
对DataBundle中所有的dataset使用apply_field方法 | 对DataBundle中所有的dataset使用apply_field方法 | ||||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | :param callable func: input是instance中名为 `field_name` 的field的内容。 | ||||
@@ -330,7 +330,7 @@ class DataBundle: | |||||
return self | return self | ||||
def apply(self, func, new_field_name:str, **kwargs): | def apply(self, func, new_field_name:str, **kwargs): | ||||
""" | |||||
r""" | |||||
对DataBundle中所有的dataset使用apply方法 | 对DataBundle中所有的dataset使用apply方法 | ||||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | :param callable func: input是instance中名为 `field_name` 的field的内容。 | ||||
@@ -349,7 +349,7 @@ class DataBundle: | |||||
return self | return self | ||||
def add_collect_fn(self, fn, name=None): | def add_collect_fn(self, fn, name=None): | ||||
""" | |||||
r""" | |||||
向所有DataSet增加collect_fn, collect_fn详见 :class:`~fastNLP.DataSet` 中相关说明. | 向所有DataSet增加collect_fn, collect_fn详见 :class:`~fastNLP.DataSet` 中相关说明. | ||||
:param callable fn: | :param callable fn: | ||||
@@ -360,7 +360,7 @@ class DataBundle: | |||||
dataset.add_collect_fn(fn=fn, name=name) | dataset.add_collect_fn(fn=fn, name=name) | ||||
def delete_collect_fn(self, name=None): | def delete_collect_fn(self, name=None): | ||||
""" | |||||
r""" | |||||
删除DataSet中的collect_fn | 删除DataSet中的collect_fn | ||||
:param name: | :param name: | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -32,7 +32,7 @@ class EmbeddingOption(Option): | |||||
class EmbedLoader: | class EmbedLoader: | ||||
""" | |||||
r""" | |||||
用于读取预训练的embedding, 读取结果可直接载入为模型参数。 | 用于读取预训练的embedding, 读取结果可直接载入为模型参数。 | ||||
""" | """ | ||||
@@ -42,7 +42,7 @@ class EmbedLoader: | |||||
@staticmethod | @staticmethod | ||||
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | def load_with_vocab(embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | ||||
error='ignore', init_method=None): | error='ignore', init_method=None): | ||||
""" | |||||
r""" | |||||
从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 | 从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 | ||||
word2vec(第一行只有两个元素)还是glove格式的数据。 | word2vec(第一行只有两个元素)还是glove格式的数据。 | ||||
@@ -114,7 +114,7 @@ class EmbedLoader: | |||||
@staticmethod | @staticmethod | ||||
def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | ||||
error='ignore'): | error='ignore'): | ||||
""" | |||||
r""" | |||||
从embed_filepath中读取预训练的word vector。根据预训练的词表读取embedding并生成一个对应的Vocabulary。 | 从embed_filepath中读取预训练的word vector。根据预训练的词表读取embedding并生成一个对应的Vocabulary。 | ||||
:param str embed_filepath: 预训练的embedding的路径。 | :param str embed_filepath: 预训练的embedding的路径。 | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented | |||||
r"""undocumented | |||||
此模块用于给其它模块提供读取文件的函数,没有为用户提供 API | 此模块用于给其它模块提供读取文件的函数,没有为用户提供 API | ||||
""" | """ | ||||
@@ -11,7 +11,7 @@ from ..core import logger | |||||
def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | ||||
""" | |||||
r""" | |||||
Construct a generator to read csv items. | Construct a generator to read csv items. | ||||
:param path: file path | :param path: file path | ||||
@@ -51,7 +51,7 @@ def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | |||||
def _read_json(path, encoding='utf-8', fields=None, dropna=True): | def _read_json(path, encoding='utf-8', fields=None, dropna=True): | ||||
""" | |||||
r""" | |||||
Construct a generator to read json items. | Construct a generator to read json items. | ||||
:param path: file path | :param path: file path | ||||
@@ -82,7 +82,7 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True): | |||||
def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | ||||
""" | |||||
r""" | |||||
Construct a generator to read conll items. | Construct a generator to read conll items. | ||||
:param path: file path | :param path: file path | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -138,7 +138,7 @@ FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt', | |||||
def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path: | def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path: | ||||
""" | |||||
r""" | |||||
给定一个url,尝试通过url中的解析出来的文件名字filename到{cache_dir}/{name}/{filename}下寻找这个文件, | 给定一个url,尝试通过url中的解析出来的文件名字filename到{cache_dir}/{name}/{filename}下寻找这个文件, | ||||
1. 如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir | 1. 如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir | ||||
@@ -183,7 +183,7 @@ def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path: | |||||
def get_filepath(filepath): | def get_filepath(filepath): | ||||
""" | |||||
r""" | |||||
如果filepath为文件夹, | 如果filepath为文件夹, | ||||
如果内含多个文件, 返回filepath | 如果内含多个文件, 返回filepath | ||||
@@ -210,7 +210,7 @@ def get_filepath(filepath): | |||||
def get_cache_path(): | def get_cache_path(): | ||||
""" | |||||
r""" | |||||
获取fastNLP默认cache的存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 | 获取fastNLP默认cache的存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 | ||||
:return str: 存放路径 | :return str: 存放路径 | ||||
@@ -226,7 +226,7 @@ def get_cache_path(): | |||||
def _get_base_url(name): | def _get_base_url(name): | ||||
""" | |||||
r""" | |||||
根据name返回下载的url地址。 | 根据name返回下载的url地址。 | ||||
:param str name: 支持dataset和embedding两种 | :param str name: 支持dataset和embedding两种 | ||||
@@ -252,7 +252,7 @@ def _get_base_url(name): | |||||
def _get_embedding_url(embed_type, name): | def _get_embedding_url(embed_type, name): | ||||
""" | |||||
r""" | |||||
给定embedding类似和名称,返回下载url | 给定embedding类似和名称,返回下载url | ||||
:param str embed_type: 支持static, bert, elmo。即embedding的类型 | :param str embed_type: 支持static, bert, elmo。即embedding的类型 | ||||
@@ -276,7 +276,7 @@ def _get_embedding_url(embed_type, name): | |||||
raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static") | raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static") | ||||
def _read_extend_url_file(filename, name)->str: | def _read_extend_url_file(filename, name)->str: | ||||
""" | |||||
r""" | |||||
filename中的内容使用制表符隔开,第一列是名称,第二列是下载的url地址 | filename中的内容使用制表符隔开,第一列是名称,第二列是下载的url地址 | ||||
:param str filename: 在默认的路径下寻找file这个文件 | :param str filename: 在默认的路径下寻找file这个文件 | ||||
@@ -297,7 +297,7 @@ def _read_extend_url_file(filename, name)->str: | |||||
return None | return None | ||||
def _get_dataset_url(name): | def _get_dataset_url(name): | ||||
""" | |||||
r""" | |||||
给定dataset的名称,返回下载url | 给定dataset的名称,返回下载url | ||||
:param str name: 给定dataset的名称,比如imdb, sst-2等 | :param str name: 给定dataset的名称,比如imdb, sst-2等 | ||||
@@ -317,7 +317,7 @@ def _get_dataset_url(name): | |||||
def split_filename_suffix(filepath): | def split_filename_suffix(filepath): | ||||
""" | |||||
r""" | |||||
给定filepath 返回对应的name和suffix. 如果后缀是多个点,仅支持.tar.gz类型 | 给定filepath 返回对应的name和suffix. 如果后缀是多个点,仅支持.tar.gz类型 | ||||
:param filepath: 文件路径 | :param filepath: 文件路径 | ||||
@@ -330,7 +330,7 @@ def split_filename_suffix(filepath): | |||||
def get_from_cache(url: str, cache_dir: Path = None) -> Path: | def get_from_cache(url: str, cache_dir: Path = None) -> Path: | ||||
""" | |||||
r""" | |||||
尝试在cache_dir中寻找url定义的资源; 如果没有找到; 则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的 | 尝试在cache_dir中寻找url定义的资源; 如果没有找到; 则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的 | ||||
文件解压,将解压后的文件全部放在cache_dir文件夹中。 | 文件解压,将解压后的文件全部放在cache_dir文件夹中。 | ||||
@@ -469,7 +469,7 @@ def ungzip_file(file: str, to: str, filename:str): | |||||
def match_file(dir_name: str, cache_dir: Path) -> str: | def match_file(dir_name: str, cache_dir: Path) -> str: | ||||
""" | |||||
r""" | |||||
匹配的原则是: 在cache_dir下的文件与dir_name完全一致, 或除了后缀以外和dir_name完全一致。 | 匹配的原则是: 在cache_dir下的文件与dir_name完全一致, 或除了后缀以外和dir_name完全一致。 | ||||
如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串 | 如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串 | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的 | Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的 | ||||
三个方法: ``__init__`` , ``_load`` , ``loads`` . 其中 ``__init__(...)`` 用于申明读取参数,以及说明该Loader支持的数据格式, | 三个方法: ``__init__`` , ``_load`` , ``loads`` . 其中 ``__init__(...)`` 用于申明读取参数,以及说明该Loader支持的数据格式, | ||||
读取后 :class:`~fastNLP.DataSet` 中的 `field` ; ``_load(path)`` 方法传入文件路径读取单个文件,并返回 :class:`~fastNLP.DataSet` ; | 读取后 :class:`~fastNLP.DataSet` 中的 `field` ; ``_load(path)`` 方法传入文件路径读取单个文件,并返回 :class:`~fastNLP.DataSet` ; | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"CLSBaseLoader", | "CLSBaseLoader", | ||||
@@ -29,7 +29,7 @@ from ...core._logger import logger | |||||
class CLSBaseLoader(Loader): | class CLSBaseLoader(Loader): | ||||
""" | |||||
r""" | |||||
文本分类Loader的一个基类 | 文本分类Loader的一个基类 | ||||
原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 | 原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 | ||||
@@ -118,7 +118,7 @@ def _split_dev(dataset_name, data_dir, dev_ratio=0.0, re_download=False, suffix= | |||||
class AGsNewsLoader(CLSBaseLoader): | class AGsNewsLoader(CLSBaseLoader): | ||||
def download(self): | def download(self): | ||||
""" | |||||
r""" | |||||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | ||||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | ||||
@@ -131,7 +131,7 @@ class AGsNewsLoader(CLSBaseLoader): | |||||
class DBPediaLoader(CLSBaseLoader): | class DBPediaLoader(CLSBaseLoader): | ||||
def download(self, dev_ratio: float = 0.0, re_download: bool = False): | def download(self, dev_ratio: float = 0.0, re_download: bool = False): | ||||
""" | |||||
r""" | |||||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | ||||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | ||||
@@ -155,7 +155,7 @@ class DBPediaLoader(CLSBaseLoader): | |||||
class IMDBLoader(CLSBaseLoader): | class IMDBLoader(CLSBaseLoader): | ||||
""" | |||||
r""" | |||||
原始数据中内容应该为, 每一行为一个sample,制表符之前为target,制表符之后为文本内容。 | 原始数据中内容应该为, 每一行为一个sample,制表符之前为target,制表符之后为文本内容。 | ||||
Example:: | Example:: | ||||
@@ -178,7 +178,7 @@ class IMDBLoader(CLSBaseLoader): | |||||
super().__init__(sep='\t') | super().__init__(sep='\t') | ||||
def download(self, dev_ratio: float = 0.0, re_download=False): | def download(self, dev_ratio: float = 0.0, re_download=False): | ||||
""" | |||||
r""" | |||||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | ||||
http://www.aclweb.org/anthology/P11-1015 | http://www.aclweb.org/anthology/P11-1015 | ||||
@@ -200,7 +200,7 @@ class IMDBLoader(CLSBaseLoader): | |||||
class SSTLoader(Loader): | class SSTLoader(Loader): | ||||
""" | |||||
r""" | |||||
原始数据中内容应该为: | 原始数据中内容应该为: | ||||
Example:: | Example:: | ||||
@@ -225,7 +225,7 @@ class SSTLoader(Loader): | |||||
super().__init__() | super().__init__() | ||||
def _load(self, path: str): | def _load(self, path: str): | ||||
""" | |||||
r""" | |||||
从path读取SST文件 | 从path读取SST文件 | ||||
:param str path: 文件路径 | :param str path: 文件路径 | ||||
@@ -240,7 +240,7 @@ class SSTLoader(Loader): | |||||
return ds | return ds | ||||
def download(self): | def download(self): | ||||
""" | |||||
r""" | |||||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | ||||
https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf | https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf | ||||
@@ -253,7 +253,7 @@ class SSTLoader(Loader): | |||||
class YelpFullLoader(CLSBaseLoader): | class YelpFullLoader(CLSBaseLoader): | ||||
def download(self, dev_ratio: float = 0.0, re_download: bool = False): | def download(self, dev_ratio: float = 0.0, re_download: bool = False): | ||||
""" | |||||
r""" | |||||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | ||||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | ||||
@@ -278,7 +278,7 @@ class YelpFullLoader(CLSBaseLoader): | |||||
class YelpPolarityLoader(CLSBaseLoader): | class YelpPolarityLoader(CLSBaseLoader): | ||||
def download(self, dev_ratio: float = 0.0, re_download: bool = False): | def download(self, dev_ratio: float = 0.0, re_download: bool = False): | ||||
""" | |||||
r""" | |||||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | ||||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | ||||
@@ -302,7 +302,7 @@ class YelpPolarityLoader(CLSBaseLoader): | |||||
class SST2Loader(Loader): | class SST2Loader(Loader): | ||||
""" | |||||
r""" | |||||
原始数据中内容为:第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是句子,第一个制表符之后认为是label | 原始数据中内容为:第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是句子,第一个制表符之后认为是label | ||||
Example:: | Example:: | ||||
@@ -327,7 +327,7 @@ class SST2Loader(Loader): | |||||
super().__init__() | super().__init__() | ||||
def _load(self, path: str): | def _load(self, path: str): | ||||
"""从path读取SST2文件 | |||||
r"""从path读取SST2文件 | |||||
:param str path: 数据路径 | :param str path: 数据路径 | ||||
:return: DataSet | :return: DataSet | ||||
@@ -357,7 +357,7 @@ class SST2Loader(Loader): | |||||
return ds | return ds | ||||
def download(self): | def download(self): | ||||
""" | |||||
r""" | |||||
自动下载数据集,如果你使用了该数据集,请引用以下的文章 | 自动下载数据集,如果你使用了该数据集,请引用以下的文章 | ||||
https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf | https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf | ||||
:return: | :return: | ||||
@@ -367,7 +367,7 @@ class SST2Loader(Loader): | |||||
class ChnSentiCorpLoader(Loader): | class ChnSentiCorpLoader(Loader): | ||||
""" | |||||
r""" | |||||
支持读取的数据的格式为,第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是label,第 | 支持读取的数据的格式为,第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是label,第 | ||||
一个制表符之后认为是句子 | 一个制表符之后认为是句子 | ||||
@@ -391,7 +391,7 @@ class ChnSentiCorpLoader(Loader): | |||||
super().__init__() | super().__init__() | ||||
def _load(self, path: str): | def _load(self, path: str): | ||||
""" | |||||
r""" | |||||
从path中读取数据 | 从path中读取数据 | ||||
:param path: | :param path: | ||||
@@ -411,7 +411,7 @@ class ChnSentiCorpLoader(Loader): | |||||
return ds | return ds | ||||
def download(self) -> str: | def download(self) -> str: | ||||
""" | |||||
r""" | |||||
自动下载数据,该数据取自https://github.com/pengming617/bert_classification/tree/master/data,在 | 自动下载数据,该数据取自https://github.com/pengming617/bert_classification/tree/master/data,在 | ||||
https://arxiv.org/pdf/1904.09223.pdf与https://arxiv.org/pdf/1906.08101.pdf有使用 | https://arxiv.org/pdf/1904.09223.pdf与https://arxiv.org/pdf/1906.08101.pdf有使用 | ||||
@@ -422,7 +422,7 @@ class ChnSentiCorpLoader(Loader): | |||||
class THUCNewsLoader(Loader): | class THUCNewsLoader(Loader): | ||||
""" | |||||
r""" | |||||
数据集简介:document-level分类任务,新闻10分类 | 数据集简介:document-level分类任务,新闻10分类 | ||||
原始数据内容为:每行一个sample,第一个'\t'之前为target,第一个'\t'之后为raw_words | 原始数据内容为:每行一个sample,第一个'\t'之前为target,第一个'\t'之后为raw_words | ||||
@@ -456,7 +456,7 @@ class THUCNewsLoader(Loader): | |||||
return ds | return ds | ||||
def download(self) -> str: | def download(self) -> str: | ||||
""" | |||||
r""" | |||||
自动下载数据,该数据取自 | 自动下载数据,该数据取自 | ||||
http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews | http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews | ||||
@@ -468,7 +468,7 @@ class THUCNewsLoader(Loader): | |||||
class WeiboSenti100kLoader(Loader): | class WeiboSenti100kLoader(Loader): | ||||
""" | |||||
r""" | |||||
别名: | 别名: | ||||
数据集简介:微博sentiment classification,二分类 | 数据集简介:微博sentiment classification,二分类 | ||||
@@ -505,7 +505,7 @@ class WeiboSenti100kLoader(Loader): | |||||
return ds | return ds | ||||
def download(self) -> str: | def download(self) -> str: | ||||
""" | |||||
r""" | |||||
自动下载数据,该数据取自 https://github.com/SophonPlus/ChineseNlpCorpus/ | 自动下载数据,该数据取自 https://github.com/SophonPlus/ChineseNlpCorpus/ | ||||
在 https://arxiv.org/abs/1906.08101 有使用 | 在 https://arxiv.org/abs/1906.08101 有使用 | ||||
:return: | :return: | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"ConllLoader", | "ConllLoader", | ||||
@@ -26,7 +26,7 @@ from ...core.instance import Instance | |||||
class ConllLoader(Loader): | class ConllLoader(Loader): | ||||
""" | |||||
r""" | |||||
ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: | ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: | ||||
Example:: | Example:: | ||||
@@ -56,7 +56,7 @@ class ConllLoader(Loader): | |||||
""" | """ | ||||
def __init__(self, headers, indexes=None, dropna=True): | def __init__(self, headers, indexes=None, dropna=True): | ||||
""" | |||||
r""" | |||||
:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | ||||
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | ||||
@@ -76,7 +76,7 @@ class ConllLoader(Loader): | |||||
self.indexes = indexes | self.indexes = indexes | ||||
def _load(self, path): | def _load(self, path): | ||||
""" | |||||
r""" | |||||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | ||||
:param str path: 文件的路径 | :param str path: 文件的路径 | ||||
@@ -90,7 +90,7 @@ class ConllLoader(Loader): | |||||
class Conll2003Loader(ConllLoader): | class Conll2003Loader(ConllLoader): | ||||
""" | |||||
r""" | |||||
用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。 | 用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。 | ||||
Example:: | Example:: | ||||
@@ -123,7 +123,7 @@ class Conll2003Loader(ConllLoader): | |||||
super(Conll2003Loader, self).__init__(headers=headers) | super(Conll2003Loader, self).__init__(headers=headers) | ||||
def _load(self, path): | def _load(self, path): | ||||
""" | |||||
r""" | |||||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | ||||
:param str path: 文件的路径 | :param str path: 文件的路径 | ||||
@@ -148,7 +148,7 @@ class Conll2003Loader(ConllLoader): | |||||
class Conll2003NERLoader(ConllLoader): | class Conll2003NERLoader(ConllLoader): | ||||
""" | |||||
r""" | |||||
用于读取conll2003任务的NER数据。每一行有4列内容,空行意味着隔开两个句子 | 用于读取conll2003任务的NER数据。每一行有4列内容,空行意味着隔开两个句子 | ||||
支持读取的内容如下 | 支持读取的内容如下 | ||||
@@ -182,7 +182,7 @@ class Conll2003NERLoader(ConllLoader): | |||||
super().__init__(headers=headers, indexes=[0, 3]) | super().__init__(headers=headers, indexes=[0, 3]) | ||||
def _load(self, path): | def _load(self, path): | ||||
""" | |||||
r""" | |||||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | ||||
:param str path: 文件的路径 | :param str path: 文件的路径 | ||||
@@ -209,7 +209,7 @@ class Conll2003NERLoader(ConllLoader): | |||||
class OntoNotesNERLoader(ConllLoader): | class OntoNotesNERLoader(ConllLoader): | ||||
""" | |||||
r""" | |||||
用以读取OntoNotes的NER数据,同时也是Conll2012的NER任务数据。将OntoNote数据处理为conll格式的过程可以参考 | 用以读取OntoNotes的NER数据,同时也是Conll2012的NER任务数据。将OntoNote数据处理为conll格式的过程可以参考 | ||||
https://github.com/yhcc/OntoNotes-5.0-NER。OntoNoteNERLoader将取第4列和第11列的内容。 | https://github.com/yhcc/OntoNotes-5.0-NER。OntoNoteNERLoader将取第4列和第11列的内容。 | ||||
@@ -287,7 +287,7 @@ class OntoNotesNERLoader(ConllLoader): | |||||
class CTBLoader(Loader): | class CTBLoader(Loader): | ||||
""" | |||||
r""" | |||||
支持加载的数据应该具备以下格式, 其中第二列为词语,第四列为pos tag,第七列为依赖树的head,第八列为依赖树的label | 支持加载的数据应该具备以下格式, 其中第二列为词语,第四列为pos tag,第七列为依赖树的head,第八列为依赖树的label | ||||
Example:: | Example:: | ||||
@@ -328,7 +328,7 @@ class CTBLoader(Loader): | |||||
return dataset | return dataset | ||||
def download(self): | def download(self): | ||||
""" | |||||
r""" | |||||
由于版权限制,不能提供自动下载功能。可参考 | 由于版权限制,不能提供自动下载功能。可参考 | ||||
https://catalog.ldc.upenn.edu/LDC2013T21 | https://catalog.ldc.upenn.edu/LDC2013T21 | ||||
@@ -340,7 +340,7 @@ class CTBLoader(Loader): | |||||
class CNNERLoader(Loader): | class CNNERLoader(Loader): | ||||
def _load(self, path: str): | def _load(self, path: str): | ||||
""" | |||||
r""" | |||||
支持加载形如以下格式的内容,一行两列,以空格隔开两个sample | 支持加载形如以下格式的内容,一行两列,以空格隔开两个sample | ||||
Example:: | Example:: | ||||
@@ -378,7 +378,7 @@ class CNNERLoader(Loader): | |||||
class MsraNERLoader(CNNERLoader): | class MsraNERLoader(CNNERLoader): | ||||
""" | |||||
r""" | |||||
读取MSRA-NER数据,数据中的格式应该类似与下列的内容 | 读取MSRA-NER数据,数据中的格式应该类似与下列的内容 | ||||
Example:: | Example:: | ||||
@@ -416,7 +416,7 @@ class MsraNERLoader(CNNERLoader): | |||||
super().__init__() | super().__init__() | ||||
def download(self, dev_ratio: float = 0.1, re_download: bool = False) -> str: | def download(self, dev_ratio: float = 0.1, re_download: bool = False) -> str: | ||||
""" | |||||
r""" | |||||
自动下载MSAR-NER的数据,如果你使用该数据,请引用 Gina-Anne Levow, 2006, The Third International Chinese Language | 自动下载MSAR-NER的数据,如果你使用该数据,请引用 Gina-Anne Levow, 2006, The Third International Chinese Language | ||||
Processing Bakeoff: Word Segmentation and Named Entity Recognition. | Processing Bakeoff: Word Segmentation and Named Entity Recognition. | ||||
@@ -466,7 +466,7 @@ class MsraNERLoader(CNNERLoader): | |||||
class WeiboNERLoader(CNNERLoader): | class WeiboNERLoader(CNNERLoader): | ||||
""" | |||||
r""" | |||||
读取WeiboNER数据,数据中的格式应该类似与下列的内容 | 读取WeiboNER数据,数据中的格式应该类似与下列的内容 | ||||
Example:: | Example:: | ||||
@@ -494,7 +494,7 @@ class WeiboNERLoader(CNNERLoader): | |||||
super().__init__() | super().__init__() | ||||
def download(self) -> str: | def download(self) -> str: | ||||
""" | |||||
r""" | |||||
自动下载Weibo-NER的数据,如果你使用了该数据,请引用 Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for | 自动下载Weibo-NER的数据,如果你使用了该数据,请引用 Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for | ||||
Chinese Social Media with Jointly Trained Embeddings. | Chinese Social Media with Jointly Trained Embeddings. | ||||
@@ -507,7 +507,7 @@ class WeiboNERLoader(CNNERLoader): | |||||
class PeopleDailyNERLoader(CNNERLoader): | class PeopleDailyNERLoader(CNNERLoader): | ||||
""" | |||||
r""" | |||||
支持加载的数据格式如下 | 支持加载的数据格式如下 | ||||
Example:: | Example:: | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"CoReferenceLoader", | "CoReferenceLoader", | ||||
@@ -12,7 +12,7 @@ from .json import JsonLoader | |||||
class CoReferenceLoader(JsonLoader): | class CoReferenceLoader(JsonLoader): | ||||
""" | |||||
r""" | |||||
原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 | 原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 | ||||
Example:: | Example:: | ||||
@@ -38,7 +38,7 @@ class CoReferenceLoader(JsonLoader): | |||||
"sentences": Const.RAW_WORDS(3)} | "sentences": Const.RAW_WORDS(3)} | ||||
def _load(self, path): | def _load(self, path): | ||||
""" | |||||
r""" | |||||
加载数据 | 加载数据 | ||||
:param path: 数据文件路径,文件为json | :param path: 数据文件路径,文件为json | ||||
@@ -54,7 +54,7 @@ class CoReferenceLoader(JsonLoader): | |||||
return dataset | return dataset | ||||
def download(self): | def download(self): | ||||
""" | |||||
r""" | |||||
由于版权限制,不能提供自动下载功能。可参考 | 由于版权限制,不能提供自动下载功能。可参考 | ||||
https://www.aclweb.org/anthology/W12-4501 | https://www.aclweb.org/anthology/W12-4501 | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"CSVLoader", | "CSVLoader", | ||||
@@ -11,13 +11,13 @@ from ...core.instance import Instance | |||||
class CSVLoader(Loader): | class CSVLoader(Loader): | ||||
""" | |||||
r""" | |||||
读取CSV格式的数据集, 返回 ``DataSet`` 。 | 读取CSV格式的数据集, 返回 ``DataSet`` 。 | ||||
""" | """ | ||||
def __init__(self, headers=None, sep=",", dropna=False): | def __init__(self, headers=None, sep=",", dropna=False): | ||||
""" | |||||
r""" | |||||
:param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 | :param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 | ||||
若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` | 若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"CWSLoader" | "CWSLoader" | ||||
@@ -16,7 +16,7 @@ from ...core.instance import Instance | |||||
class CWSLoader(Loader): | class CWSLoader(Loader): | ||||
""" | |||||
r""" | |||||
CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: | CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: | ||||
Example:: | Example:: | ||||
@@ -36,7 +36,7 @@ class CWSLoader(Loader): | |||||
""" | """ | ||||
def __init__(self, dataset_name:str=None): | def __init__(self, dataset_name:str=None): | ||||
""" | |||||
r""" | |||||
:param str dataset_name: data的名称,支持pku, msra, cityu(繁体), as(繁体), None | :param str dataset_name: data的名称,支持pku, msra, cityu(繁体), as(繁体), None | ||||
""" | """ | ||||
@@ -57,7 +57,7 @@ class CWSLoader(Loader): | |||||
return ds | return ds | ||||
def download(self, dev_ratio=0.1, re_download=False)->str: | def download(self, dev_ratio=0.1, re_download=False)->str: | ||||
""" | |||||
r""" | |||||
如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, | 如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, | ||||
2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 | 2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"JsonLoader" | "JsonLoader" | ||||
@@ -11,7 +11,7 @@ from ...core.instance import Instance | |||||
class JsonLoader(Loader): | class JsonLoader(Loader): | ||||
""" | |||||
r""" | |||||
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader` | 别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader` | ||||
读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 | 读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"Loader" | "Loader" | ||||
@@ -13,7 +13,7 @@ from ...core.dataset import DataSet | |||||
class Loader: | class Loader: | ||||
""" | |||||
r""" | |||||
各种数据 Loader 的基类,提供了 API 的参考. | 各种数据 Loader 的基类,提供了 API 的参考. | ||||
Loader支持以下的三个函数 | Loader支持以下的三个函数 | ||||
@@ -27,7 +27,7 @@ class Loader: | |||||
pass | pass | ||||
def _load(self, path: str) -> DataSet: | def _load(self, path: str) -> DataSet: | ||||
""" | |||||
r""" | |||||
给定一个路径,返回读取的DataSet。 | 给定一个路径,返回读取的DataSet。 | ||||
:param str path: 路径 | :param str path: 路径 | ||||
@@ -71,7 +71,7 @@ class Loader: | |||||
return data_bundle | return data_bundle | ||||
def download(self) -> str: | def download(self) -> str: | ||||
""" | |||||
r""" | |||||
自动下载该数据集 | 自动下载该数据集 | ||||
:return: 下载后解压目录 | :return: 下载后解压目录 | ||||
@@ -80,7 +80,7 @@ class Loader: | |||||
@staticmethod | @staticmethod | ||||
def _get_dataset_path(dataset_name): | def _get_dataset_path(dataset_name): | ||||
""" | |||||
r""" | |||||
传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存(如果支持的话) | 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存(如果支持的话) | ||||
:param str dataset_name: 数据集的名称 | :param str dataset_name: 数据集的名称 | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"MNLILoader", | "MNLILoader", | ||||
@@ -26,7 +26,7 @@ from ...core.instance import Instance | |||||
class MNLILoader(Loader): | class MNLILoader(Loader): | ||||
""" | |||||
r""" | |||||
读取的数据格式为: | 读取的数据格式为: | ||||
Example:: | Example:: | ||||
@@ -80,7 +80,7 @@ class MNLILoader(Loader): | |||||
return ds | return ds | ||||
def load(self, paths: str = None): | def load(self, paths: str = None): | ||||
""" | |||||
r""" | |||||
:param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv, | :param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv, | ||||
test_mismatched.tsv, train.tsv文件夹 | test_mismatched.tsv, train.tsv文件夹 | ||||
@@ -112,7 +112,7 @@ class MNLILoader(Loader): | |||||
return data_bundle | return data_bundle | ||||
def download(self): | def download(self): | ||||
""" | |||||
r""" | |||||
如果你使用了这个数据,请引用 | 如果你使用了这个数据,请引用 | ||||
https://www.nyu.edu/projects/bowman/multinli/paper.pdf | https://www.nyu.edu/projects/bowman/multinli/paper.pdf | ||||
@@ -123,7 +123,7 @@ class MNLILoader(Loader): | |||||
class SNLILoader(JsonLoader): | class SNLILoader(JsonLoader): | ||||
""" | |||||
r""" | |||||
文件每一行是一个sample,每一行都为一个json对象,其数据格式为: | 文件每一行是一个sample,每一行都为一个json对象,其数据格式为: | ||||
Example:: | Example:: | ||||
@@ -157,7 +157,7 @@ class SNLILoader(JsonLoader): | |||||
}) | }) | ||||
def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: | def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: | ||||
""" | |||||
r""" | |||||
从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | ||||
读取的field根据Loader初始化时传入的field决定。 | 读取的field根据Loader初始化时传入的field决定。 | ||||
@@ -187,7 +187,7 @@ class SNLILoader(JsonLoader): | |||||
return data_bundle | return data_bundle | ||||
def download(self): | def download(self): | ||||
""" | |||||
r""" | |||||
如果您的文章使用了这份数据,请引用 | 如果您的文章使用了这份数据,请引用 | ||||
http://nlp.stanford.edu/pubs/snli_paper.pdf | http://nlp.stanford.edu/pubs/snli_paper.pdf | ||||
@@ -198,7 +198,7 @@ class SNLILoader(JsonLoader): | |||||
class QNLILoader(JsonLoader): | class QNLILoader(JsonLoader): | ||||
""" | |||||
r""" | |||||
第一行为标题(具体内容会被忽略),之后每一行是一个sample,由index、问题、句子和标签构成(以制表符分割),数据结构如下: | 第一行为标题(具体内容会被忽略),之后每一行是一个sample,由index、问题、句子和标签构成(以制表符分割),数据结构如下: | ||||
Example:: | Example:: | ||||
@@ -250,7 +250,7 @@ class QNLILoader(JsonLoader): | |||||
return ds | return ds | ||||
def download(self): | def download(self): | ||||
""" | |||||
r""" | |||||
如果您的实验使用到了该数据,请引用 | 如果您的实验使用到了该数据,请引用 | ||||
https://arxiv.org/pdf/1809.05053.pdf | https://arxiv.org/pdf/1809.05053.pdf | ||||
@@ -261,7 +261,7 @@ class QNLILoader(JsonLoader): | |||||
class RTELoader(Loader): | class RTELoader(Loader): | ||||
""" | |||||
r""" | |||||
第一行为标题(具体内容会被忽略),之后每一行是一个sample,由index、句子1、句子2和标签构成(以制表符分割),数据结构如下: | 第一行为标题(具体内容会被忽略),之后每一行是一个sample,由index、句子1、句子2和标签构成(以制表符分割),数据结构如下: | ||||
Example:: | Example:: | ||||
@@ -312,7 +312,7 @@ class RTELoader(Loader): | |||||
return ds | return ds | ||||
def download(self): | def download(self): | ||||
""" | |||||
r""" | |||||
如果您的实验使用到了该数据,请引用GLUE Benchmark | 如果您的实验使用到了该数据,请引用GLUE Benchmark | ||||
https://openreview.net/pdf?id=rJ4km2R5t7 | https://openreview.net/pdf?id=rJ4km2R5t7 | ||||
@@ -323,7 +323,7 @@ class RTELoader(Loader): | |||||
class QuoraLoader(Loader): | class QuoraLoader(Loader): | ||||
""" | |||||
r""" | |||||
Quora matching任务的数据集Loader | Quora matching任务的数据集Loader | ||||
支持读取的文件中的内容,应该有以下的形式, 以制表符分隔,且前三列的内容必须是:第一列是label,第二列和第三列是句子 | 支持读取的文件中的内容,应该有以下的形式, 以制表符分隔,且前三列的内容必须是:第一列是label,第二列和第三列是句子 | ||||
@@ -364,7 +364,7 @@ class QuoraLoader(Loader): | |||||
return ds | return ds | ||||
def download(self): | def download(self): | ||||
""" | |||||
r""" | |||||
由于版权限制,不能提供自动下载功能。可参考 | 由于版权限制,不能提供自动下载功能。可参考 | ||||
https://www.kaggle.com/c/quora-question-pairs/data | https://www.kaggle.com/c/quora-question-pairs/data | ||||
@@ -375,7 +375,7 @@ class QuoraLoader(Loader): | |||||
class CNXNLILoader(Loader): | class CNXNLILoader(Loader): | ||||
""" | |||||
r""" | |||||
数据集简介:中文句对NLI(本为multi-lingual的数据集,但是这里只取了中文的数据集)。原句子已被MOSES tokenizer处理,这里我们将其还原并重新按字tokenize | 数据集简介:中文句对NLI(本为multi-lingual的数据集,但是这里只取了中文的数据集)。原句子已被MOSES tokenizer处理,这里我们将其还原并重新按字tokenize | ||||
原始数据数据为: | 原始数据数据为: | ||||
@@ -459,7 +459,7 @@ class CNXNLILoader(Loader): | |||||
return data_bundle | return data_bundle | ||||
def download(self) -> str: | def download(self) -> str: | ||||
""" | |||||
r""" | |||||
自动下载数据,该数据取自 https://arxiv.org/abs/1809.05053 | 自动下载数据,该数据取自 https://arxiv.org/abs/1809.05053 | ||||
在 https://arxiv.org/pdf/1905.05526.pdf https://arxiv.org/pdf/1901.10125.pdf | 在 https://arxiv.org/pdf/1905.05526.pdf https://arxiv.org/pdf/1901.10125.pdf | ||||
https://arxiv.org/pdf/1809.05053.pdf 有使用 | https://arxiv.org/pdf/1809.05053.pdf 有使用 | ||||
@@ -470,7 +470,7 @@ class CNXNLILoader(Loader): | |||||
class BQCorpusLoader(Loader): | class BQCorpusLoader(Loader): | ||||
""" | |||||
r""" | |||||
别名: | 别名: | ||||
数据集简介:句子对二分类任务(判断是否具有相同的语义) | 数据集简介:句子对二分类任务(判断是否具有相同的语义) | ||||
原始数据结构为: | 原始数据结构为: | ||||
@@ -511,7 +511,7 @@ class BQCorpusLoader(Loader): | |||||
return ds | return ds | ||||
def download(self): | def download(self): | ||||
""" | |||||
r""" | |||||
由于版权限制,不能提供自动下载功能。可参考 | 由于版权限制,不能提供自动下载功能。可参考 | ||||
https://github.com/ymcui/Chinese-BERT-wwm | https://github.com/ymcui/Chinese-BERT-wwm | ||||
@@ -566,7 +566,7 @@ class LCQMCLoader(Loader): | |||||
return ds | return ds | ||||
def download(self): | def download(self): | ||||
""" | |||||
r""" | |||||
由于版权限制,不能提供自动下载功能。可参考 | 由于版权限制,不能提供自动下载功能。可参考 | ||||
https://github.com/ymcui/Chinese-BERT-wwm | https://github.com/ymcui/Chinese-BERT-wwm | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
该文件中的Loader主要用于读取问答式任务的数据 | 该文件中的Loader主要用于读取问答式任务的数据 | ||||
""" | """ | ||||
@@ -12,7 +12,7 @@ __all__ = ['CMRC2018Loader'] | |||||
class CMRC2018Loader(Loader): | class CMRC2018Loader(Loader): | ||||
""" | |||||
r""" | |||||
请直接使用从fastNLP下载的数据进行处理。该数据集未提供测试集,测试需要通过上传到对应的系统进行评测 | 请直接使用从fastNLP下载的数据进行处理。该数据集未提供测试集,测试需要通过上传到对应的系统进行评测 | ||||
读取之后训练集DataSet将具备以下的内容,每个问题的答案只有一个 | 读取之后训练集DataSet将具备以下的内容,每个问题的答案只有一个 | ||||
@@ -64,7 +64,7 @@ class CMRC2018Loader(Loader): | |||||
return ds | return ds | ||||
def download(self) -> str: | def download(self) -> str: | ||||
""" | |||||
r""" | |||||
如果您使用了本数据,请引用A Span-Extraction Dataset for Chinese Machine Reading Comprehension. Yiming Cui, Ting Liu, etc. | 如果您使用了本数据,请引用A Span-Extraction Dataset for Chinese Machine Reading Comprehension. Yiming Cui, Ting Liu, etc. | ||||
:return: | :return: | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"ExtCNNDMLoader" | "ExtCNNDMLoader" | ||||
@@ -13,7 +13,7 @@ from .json import JsonLoader | |||||
class ExtCNNDMLoader(JsonLoader): | class ExtCNNDMLoader(JsonLoader): | ||||
""" | |||||
r""" | |||||
读取之后的DataSet中的field情况为 | 读取之后的DataSet中的field情况为 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -30,7 +30,7 @@ class ExtCNNDMLoader(JsonLoader): | |||||
super(ExtCNNDMLoader, self).__init__(fields=fields) | super(ExtCNNDMLoader, self).__init__(fields=fields) | ||||
def load(self, paths: Union[str, Dict[str, str]] = None): | def load(self, paths: Union[str, Dict[str, str]] = None): | ||||
""" | |||||
r""" | |||||
从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | ||||
读取的field根据ExtCNNDMLoader初始化时传入的headers决定。 | 读取的field根据ExtCNNDMLoader初始化时传入的headers决定。 | ||||
@@ -53,7 +53,7 @@ class ExtCNNDMLoader(JsonLoader): | |||||
return data_bundle | return data_bundle | ||||
def download(self): | def download(self): | ||||
""" | |||||
r""" | |||||
如果你使用了这个数据,请引用 | 如果你使用了这个数据,请引用 | ||||
https://arxiv.org/pdf/1506.03340.pdf | https://arxiv.org/pdf/1506.03340.pdf | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
用于载入和保存模型 | 用于载入和保存模型 | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
@@ -10,7 +10,7 @@ import torch | |||||
class ModelLoader: | class ModelLoader: | ||||
""" | |||||
r""" | |||||
用于读取模型 | 用于读取模型 | ||||
""" | """ | ||||
@@ -19,7 +19,7 @@ class ModelLoader: | |||||
@staticmethod | @staticmethod | ||||
def load_pytorch(empty_model, model_path): | def load_pytorch(empty_model, model_path): | ||||
""" | |||||
r""" | |||||
从 ".pkl" 文件读取 PyTorch 模型 | 从 ".pkl" 文件读取 PyTorch 模型 | ||||
:param empty_model: 初始化参数的 PyTorch 模型 | :param empty_model: 初始化参数的 PyTorch 模型 | ||||
@@ -29,7 +29,7 @@ class ModelLoader: | |||||
@staticmethod | @staticmethod | ||||
def load_pytorch_model(model_path): | def load_pytorch_model(model_path): | ||||
""" | |||||
r""" | |||||
读取整个模型 | 读取整个模型 | ||||
:param str model_path: 模型保存的路径 | :param str model_path: 模型保存的路径 | ||||
@@ -38,7 +38,7 @@ class ModelLoader: | |||||
class ModelSaver(object): | class ModelSaver(object): | ||||
""" | |||||
r""" | |||||
用于保存模型 | 用于保存模型 | ||||
Example:: | Example:: | ||||
@@ -49,14 +49,14 @@ class ModelSaver(object): | |||||
""" | """ | ||||
def __init__(self, save_path): | def __init__(self, save_path): | ||||
""" | |||||
r""" | |||||
:param save_path: 模型保存的路径 | :param save_path: 模型保存的路径 | ||||
""" | """ | ||||
self.save_path = save_path | self.save_path = save_path | ||||
def save_pytorch(self, model, param_only=True): | def save_pytorch(self, model, param_only=True): | ||||
""" | |||||
r""" | |||||
把 PyTorch 模型存入 ".pkl" 文件 | 把 PyTorch 模型存入 ".pkl" 文件 | ||||
:param model: PyTorch 模型 | :param model: PyTorch 模型 | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
Pipe用于处理通过 Loader 读取的数据,所有的 Pipe 都包含 ``process`` 和 ``process_from_file`` 两种方法。 | Pipe用于处理通过 Loader 读取的数据,所有的 Pipe 都包含 ``process`` 和 ``process_from_file`` 两种方法。 | ||||
``process(data_bundle)`` 传入一个 :class:`~fastNLP.io.DataBundle` 类型的对象, 在传入的 `data_bundle` 上进行原位修改,并将其返回; | ``process(data_bundle)`` 传入一个 :class:`~fastNLP.io.DataBundle` 类型的对象, 在传入的 `data_bundle` 上进行原位修改,并将其返回; | ||||
``process_from_file(paths)`` 传入的文件路径,返回一个 :class:`~fastNLP.io.DataBundle` 类型的对象。 | ``process_from_file(paths)`` 传入的文件路径,返回一个 :class:`~fastNLP.io.DataBundle` 类型的对象。 | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"CLSBasePipe", | "CLSBasePipe", | ||||
@@ -39,7 +39,7 @@ class CLSBasePipe(Pipe): | |||||
self.tokenizer = get_tokenizer(tokenizer, lang=lang) | self.tokenizer = get_tokenizer(tokenizer, lang=lang) | ||||
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): | def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): | ||||
""" | |||||
r""" | |||||
将DataBundle中的数据进行tokenize | 将DataBundle中的数据进行tokenize | ||||
:param DataBundle data_bundle: | :param DataBundle data_bundle: | ||||
@@ -54,7 +54,7 @@ class CLSBasePipe(Pipe): | |||||
return data_bundle | return data_bundle | ||||
def process(self, data_bundle: DataBundle): | def process(self, data_bundle: DataBundle): | ||||
""" | |||||
r""" | |||||
传入的DataSet应该具备如下的结构 | 传入的DataSet应该具备如下的结构 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -83,7 +83,7 @@ class CLSBasePipe(Pipe): | |||||
return data_bundle | return data_bundle | ||||
def process_from_file(self, paths) -> DataBundle: | def process_from_file(self, paths) -> DataBundle: | ||||
""" | |||||
r""" | |||||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | ||||
:param paths: | :param paths: | ||||
@@ -93,7 +93,7 @@ class CLSBasePipe(Pipe): | |||||
class YelpFullPipe(CLSBasePipe): | class YelpFullPipe(CLSBasePipe): | ||||
""" | |||||
r""" | |||||
处理YelpFull的数据, 处理之后DataSet中的内容如下 | 处理YelpFull的数据, 处理之后DataSet中的内容如下 | ||||
.. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field | .. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field | ||||
@@ -117,7 +117,7 @@ class YelpFullPipe(CLSBasePipe): | |||||
""" | """ | ||||
def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'): | def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'): | ||||
""" | |||||
r""" | |||||
:param bool lower: 是否对输入进行小写化。 | :param bool lower: 是否对输入进行小写化。 | ||||
:param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将1、2归为1类,4、5归为一类,丢掉2;若为3, 则有3分类问题,将 | :param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将1、2归为1类,4、5归为一类,丢掉2;若为3, 则有3分类问题,将 | ||||
@@ -136,7 +136,7 @@ class YelpFullPipe(CLSBasePipe): | |||||
self.tag_map = None | self.tag_map = None | ||||
def process(self, data_bundle): | def process(self, data_bundle): | ||||
""" | |||||
r""" | |||||
传入的DataSet应该具备如下的结构 | 传入的DataSet应该具备如下的结构 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -157,7 +157,7 @@ class YelpFullPipe(CLSBasePipe): | |||||
return data_bundle | return data_bundle | ||||
def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
""" | |||||
r""" | |||||
:param paths: | :param paths: | ||||
:return: DataBundle | :return: DataBundle | ||||
@@ -167,7 +167,7 @@ class YelpFullPipe(CLSBasePipe): | |||||
class YelpPolarityPipe(CLSBasePipe): | class YelpPolarityPipe(CLSBasePipe): | ||||
""" | |||||
r""" | |||||
处理YelpPolarity的数据, 处理之后DataSet中的内容如下 | 处理YelpPolarity的数据, 处理之后DataSet中的内容如下 | ||||
.. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field | .. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field | ||||
@@ -191,7 +191,7 @@ class YelpPolarityPipe(CLSBasePipe): | |||||
""" | """ | ||||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | ||||
""" | |||||
r""" | |||||
:param bool lower: 是否对输入进行小写化。 | :param bool lower: 是否对输入进行小写化。 | ||||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | ||||
@@ -199,7 +199,7 @@ class YelpPolarityPipe(CLSBasePipe): | |||||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | ||||
def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
""" | |||||
r""" | |||||
:param str paths: | :param str paths: | ||||
:return: DataBundle | :return: DataBundle | ||||
@@ -209,7 +209,7 @@ class YelpPolarityPipe(CLSBasePipe): | |||||
class AGsNewsPipe(CLSBasePipe): | class AGsNewsPipe(CLSBasePipe): | ||||
""" | |||||
r""" | |||||
处理AG's News的数据, 处理之后DataSet中的内容如下 | 处理AG's News的数据, 处理之后DataSet中的内容如下 | ||||
.. csv-table:: 下面是使用AGsNewsPipe处理后的DataSet所具备的field | .. csv-table:: 下面是使用AGsNewsPipe处理后的DataSet所具备的field | ||||
@@ -233,7 +233,7 @@ class AGsNewsPipe(CLSBasePipe): | |||||
""" | """ | ||||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | ||||
""" | |||||
r""" | |||||
:param bool lower: 是否对输入进行小写化。 | :param bool lower: 是否对输入进行小写化。 | ||||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | ||||
@@ -241,7 +241,7 @@ class AGsNewsPipe(CLSBasePipe): | |||||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | ||||
def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
""" | |||||
r""" | |||||
:param str paths: | :param str paths: | ||||
:return: DataBundle | :return: DataBundle | ||||
""" | """ | ||||
@@ -250,7 +250,7 @@ class AGsNewsPipe(CLSBasePipe): | |||||
class DBPediaPipe(CLSBasePipe): | class DBPediaPipe(CLSBasePipe): | ||||
""" | |||||
r""" | |||||
处理DBPedia的数据, 处理之后DataSet中的内容如下 | 处理DBPedia的数据, 处理之后DataSet中的内容如下 | ||||
.. csv-table:: 下面是使用DBPediaPipe处理后的DataSet所具备的field | .. csv-table:: 下面是使用DBPediaPipe处理后的DataSet所具备的field | ||||
@@ -274,7 +274,7 @@ class DBPediaPipe(CLSBasePipe): | |||||
""" | """ | ||||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | ||||
""" | |||||
r""" | |||||
:param bool lower: 是否对输入进行小写化。 | :param bool lower: 是否对输入进行小写化。 | ||||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | ||||
@@ -282,7 +282,7 @@ class DBPediaPipe(CLSBasePipe): | |||||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | ||||
def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
""" | |||||
r""" | |||||
:param str paths: | :param str paths: | ||||
:return: DataBundle | :return: DataBundle | ||||
""" | """ | ||||
@@ -291,7 +291,7 @@ class DBPediaPipe(CLSBasePipe): | |||||
class SSTPipe(CLSBasePipe): | class SSTPipe(CLSBasePipe): | ||||
""" | |||||
r""" | |||||
经过该Pipe之后,DataSet中具备的field如下所示 | 经过该Pipe之后,DataSet中具备的field如下所示 | ||||
.. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field | .. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field | ||||
@@ -315,7 +315,7 @@ class SSTPipe(CLSBasePipe): | |||||
""" | """ | ||||
def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'): | def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'): | ||||
""" | |||||
r""" | |||||
:param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False`` | :param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False`` | ||||
:param bool train_subtree: 是否将train集通过子树扩展数据。 | :param bool train_subtree: 是否将train集通过子树扩展数据。 | ||||
@@ -339,7 +339,7 @@ class SSTPipe(CLSBasePipe): | |||||
self.tag_map = None | self.tag_map = None | ||||
def process(self, data_bundle: DataBundle): | def process(self, data_bundle: DataBundle): | ||||
""" | |||||
r""" | |||||
对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与 | 对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与 | ||||
.. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field | .. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field | ||||
@@ -383,7 +383,7 @@ class SSTPipe(CLSBasePipe): | |||||
class SST2Pipe(CLSBasePipe): | class SST2Pipe(CLSBasePipe): | ||||
""" | |||||
r""" | |||||
加载SST2的数据, 处理完成之后DataSet将拥有以下的field | 加载SST2的数据, 处理完成之后DataSet将拥有以下的field | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -407,7 +407,7 @@ class SST2Pipe(CLSBasePipe): | |||||
""" | """ | ||||
def __init__(self, lower=False, tokenizer='spacy'): | def __init__(self, lower=False, tokenizer='spacy'): | ||||
""" | |||||
r""" | |||||
:param bool lower: 是否对输入进行小写化。 | :param bool lower: 是否对输入进行小写化。 | ||||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | ||||
@@ -415,7 +415,7 @@ class SST2Pipe(CLSBasePipe): | |||||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | ||||
def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
""" | |||||
r""" | |||||
:param str paths: 如果为None,则自动下载并缓存到fastNLP的缓存地址。 | :param str paths: 如果为None,则自动下载并缓存到fastNLP的缓存地址。 | ||||
:return: DataBundle | :return: DataBundle | ||||
@@ -425,7 +425,7 @@ class SST2Pipe(CLSBasePipe): | |||||
class IMDBPipe(CLSBasePipe): | class IMDBPipe(CLSBasePipe): | ||||
""" | |||||
r""" | |||||
经过本Pipe处理后DataSet将如下 | 经过本Pipe处理后DataSet将如下 | ||||
.. csv-table:: 输出DataSet的field | .. csv-table:: 输出DataSet的field | ||||
@@ -452,7 +452,7 @@ class IMDBPipe(CLSBasePipe): | |||||
""" | """ | ||||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | ||||
""" | |||||
r""" | |||||
:param bool lower: 是否将words列的数据小写。 | :param bool lower: 是否将words列的数据小写。 | ||||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | ||||
@@ -461,7 +461,7 @@ class IMDBPipe(CLSBasePipe): | |||||
self.lower = lower | self.lower = lower | ||||
def process(self, data_bundle: DataBundle): | def process(self, data_bundle: DataBundle): | ||||
""" | |||||
r""" | |||||
期待的DataBunlde中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型 | 期待的DataBunlde中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型 | ||||
.. csv-table:: 输入DataSet的field | .. csv-table:: 输入DataSet的field | ||||
@@ -489,7 +489,7 @@ class IMDBPipe(CLSBasePipe): | |||||
return data_bundle | return data_bundle | ||||
def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
""" | |||||
r""" | |||||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | ||||
:return: DataBundle | :return: DataBundle | ||||
@@ -502,7 +502,7 @@ class IMDBPipe(CLSBasePipe): | |||||
class ChnSentiCorpPipe(Pipe): | class ChnSentiCorpPipe(Pipe): | ||||
""" | |||||
r""" | |||||
处理之后的DataSet有以下的结构 | 处理之后的DataSet有以下的结构 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -526,7 +526,7 @@ class ChnSentiCorpPipe(Pipe): | |||||
""" | """ | ||||
def __init__(self, bigrams=False, trigrams=False): | def __init__(self, bigrams=False, trigrams=False): | ||||
""" | |||||
r""" | |||||
:param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 | :param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 | ||||
设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 | 设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 | ||||
@@ -541,7 +541,7 @@ class ChnSentiCorpPipe(Pipe): | |||||
self.trigrams = trigrams | self.trigrams = trigrams | ||||
def _tokenize(self, data_bundle): | def _tokenize(self, data_bundle): | ||||
""" | |||||
r""" | |||||
将DataSet中的"复旦大学"拆分为["复", "旦", "大", "学"]. 未来可以通过扩展这个函数实现分词。 | 将DataSet中的"复旦大学"拆分为["复", "旦", "大", "学"]. 未来可以通过扩展这个函数实现分词。 | ||||
:param data_bundle: | :param data_bundle: | ||||
@@ -551,7 +551,7 @@ class ChnSentiCorpPipe(Pipe): | |||||
return data_bundle | return data_bundle | ||||
def process(self, data_bundle:DataBundle): | def process(self, data_bundle:DataBundle): | ||||
""" | |||||
r""" | |||||
可以处理的DataSet应该具备以下的field | 可以处理的DataSet应该具备以下的field | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -596,7 +596,7 @@ class ChnSentiCorpPipe(Pipe): | |||||
return data_bundle | return data_bundle | ||||
def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
""" | |||||
r""" | |||||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | ||||
:return: DataBundle | :return: DataBundle | ||||
@@ -609,7 +609,7 @@ class ChnSentiCorpPipe(Pipe): | |||||
class THUCNewsPipe(CLSBasePipe): | class THUCNewsPipe(CLSBasePipe): | ||||
""" | |||||
r""" | |||||
处理之后的DataSet有以下的结构 | 处理之后的DataSet有以下的结构 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -658,7 +658,7 @@ class THUCNewsPipe(CLSBasePipe): | |||||
return data_bundle | return data_bundle | ||||
def process(self, data_bundle: DataBundle): | def process(self, data_bundle: DataBundle): | ||||
""" | |||||
r""" | |||||
可处理的DataSet应具备如下的field | 可处理的DataSet应具备如下的field | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -710,7 +710,7 @@ class THUCNewsPipe(CLSBasePipe): | |||||
return data_bundle | return data_bundle | ||||
def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
""" | |||||
r""" | |||||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | ||||
:return: DataBundle | :return: DataBundle | ||||
""" | """ | ||||
@@ -721,7 +721,7 @@ class THUCNewsPipe(CLSBasePipe): | |||||
class WeiboSenti100kPipe(CLSBasePipe): | class WeiboSenti100kPipe(CLSBasePipe): | ||||
""" | |||||
r""" | |||||
处理之后的DataSet有以下的结构 | 处理之后的DataSet有以下的结构 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -766,7 +766,7 @@ class WeiboSenti100kPipe(CLSBasePipe): | |||||
return data_bundle | return data_bundle | ||||
def process(self, data_bundle: DataBundle): | def process(self, data_bundle: DataBundle): | ||||
""" | |||||
r""" | |||||
可处理的DataSet应具备以下的field | 可处理的DataSet应具备以下的field | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -814,7 +814,7 @@ class WeiboSenti100kPipe(CLSBasePipe): | |||||
return data_bundle | return data_bundle | ||||
def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
""" | |||||
r""" | |||||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | ||||
:return: DataBundle | :return: DataBundle | ||||
""" | """ | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"Conll2003NERPipe", | "Conll2003NERPipe", | ||||
@@ -21,7 +21,7 @@ from ...core.vocabulary import Vocabulary | |||||
class _NERPipe(Pipe): | class _NERPipe(Pipe): | ||||
""" | |||||
r""" | |||||
NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 | NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 | ||||
(创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 | (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 | ||||
Vocabulary转换为index。 | Vocabulary转换为index。 | ||||
@@ -31,7 +31,7 @@ class _NERPipe(Pipe): | |||||
""" | """ | ||||
def __init__(self, encoding_type: str = 'bio', lower: bool = False): | def __init__(self, encoding_type: str = 'bio', lower: bool = False): | ||||
""" | |||||
r""" | |||||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | ||||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | ||||
@@ -45,7 +45,7 @@ class _NERPipe(Pipe): | |||||
self.lower = lower | self.lower = lower | ||||
def process(self, data_bundle: DataBundle) -> DataBundle: | def process(self, data_bundle: DataBundle) -> DataBundle: | ||||
""" | |||||
r""" | |||||
支持的DataSet的field为 | 支持的DataSet的field为 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -80,7 +80,7 @@ class _NERPipe(Pipe): | |||||
class Conll2003NERPipe(_NERPipe): | class Conll2003NERPipe(_NERPipe): | ||||
""" | |||||
r""" | |||||
Conll2003的NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 | Conll2003的NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 | ||||
(创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 | (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 | ||||
Vocabulary转换为index。 | Vocabulary转换为index。 | ||||
@@ -110,7 +110,7 @@ class Conll2003NERPipe(_NERPipe): | |||||
""" | """ | ||||
def process_from_file(self, paths) -> DataBundle: | def process_from_file(self, paths) -> DataBundle: | ||||
""" | |||||
r""" | |||||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。 | :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。 | ||||
:return: DataBundle | :return: DataBundle | ||||
@@ -123,7 +123,7 @@ class Conll2003NERPipe(_NERPipe): | |||||
class Conll2003Pipe(Pipe): | class Conll2003Pipe(Pipe): | ||||
""" | |||||
r""" | |||||
经过该Pipe后,DataSet中的内容如下 | 经过该Pipe后,DataSet中的内容如下 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -148,7 +148,7 @@ class Conll2003Pipe(Pipe): | |||||
""" | """ | ||||
def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): | def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): | ||||
""" | |||||
r""" | |||||
:param str chunk_encoding_type: 支持bioes, bio。 | :param str chunk_encoding_type: 支持bioes, bio。 | ||||
:param str ner_encoding_type: 支持bioes, bio。 | :param str ner_encoding_type: 支持bioes, bio。 | ||||
@@ -169,7 +169,7 @@ class Conll2003Pipe(Pipe): | |||||
self.lower = lower | self.lower = lower | ||||
def process(self, data_bundle) -> DataBundle: | def process(self, data_bundle) -> DataBundle: | ||||
""" | |||||
r""" | |||||
输入的DataSet应该类似于如下的形式 | 输入的DataSet应该类似于如下的形式 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -210,7 +210,7 @@ class Conll2003Pipe(Pipe): | |||||
return data_bundle | return data_bundle | ||||
def process_from_file(self, paths): | def process_from_file(self, paths): | ||||
""" | |||||
r""" | |||||
:param paths: | :param paths: | ||||
:return: | :return: | ||||
@@ -220,7 +220,7 @@ class Conll2003Pipe(Pipe): | |||||
class OntoNotesNERPipe(_NERPipe): | class OntoNotesNERPipe(_NERPipe): | ||||
""" | |||||
r""" | |||||
处理OntoNotes的NER数据,处理之后DataSet中的field情况为 | 处理OntoNotes的NER数据,处理之后DataSet中的field情况为 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -252,7 +252,7 @@ class OntoNotesNERPipe(_NERPipe): | |||||
class _CNNERPipe(Pipe): | class _CNNERPipe(Pipe): | ||||
""" | |||||
r""" | |||||
中文NER任务的处理Pipe, 该Pipe会(1)复制raw_chars列,并命名为chars; (2)在chars, target列建立词表 | 中文NER任务的处理Pipe, 该Pipe会(1)复制raw_chars列,并命名为chars; (2)在chars, target列建立词表 | ||||
(创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将chars,target列根据相应的 | (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将chars,target列根据相应的 | ||||
Vocabulary转换为index。 | Vocabulary转换为index。 | ||||
@@ -263,7 +263,7 @@ class _CNNERPipe(Pipe): | |||||
""" | """ | ||||
def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False): | def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False): | ||||
""" | |||||
r""" | |||||
:param str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | :param str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | ||||
:param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 | :param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 | ||||
@@ -284,7 +284,7 @@ class _CNNERPipe(Pipe): | |||||
self.trigrams = trigrams | self.trigrams = trigrams | ||||
def process(self, data_bundle: DataBundle) -> DataBundle: | def process(self, data_bundle: DataBundle) -> DataBundle: | ||||
""" | |||||
r""" | |||||
支持的DataSet的field为 | 支持的DataSet的field为 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -335,7 +335,7 @@ class _CNNERPipe(Pipe): | |||||
class MsraNERPipe(_CNNERPipe): | class MsraNERPipe(_CNNERPipe): | ||||
""" | |||||
r""" | |||||
处理MSRA-NER的数据,处理之后的DataSet的field情况为 | 处理MSRA-NER的数据,处理之后的DataSet的field情况为 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -367,7 +367,7 @@ class MsraNERPipe(_CNNERPipe): | |||||
class PeopleDailyPipe(_CNNERPipe): | class PeopleDailyPipe(_CNNERPipe): | ||||
""" | |||||
r""" | |||||
处理people daily的ner的数据,处理之后的DataSet的field情况为 | 处理people daily的ner的数据,处理之后的DataSet的field情况为 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -399,7 +399,7 @@ class PeopleDailyPipe(_CNNERPipe): | |||||
class WeiboNERPipe(_CNNERPipe): | class WeiboNERPipe(_CNNERPipe): | ||||
""" | |||||
r""" | |||||
处理weibo的ner的数据,处理之后的DataSet的field情况为 | 处理weibo的ner的数据,处理之后的DataSet的field情况为 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"CoReferencePipe" | "CoReferencePipe" | ||||
@@ -16,7 +16,7 @@ from ...core.const import Const | |||||
class CoReferencePipe(Pipe): | class CoReferencePipe(Pipe): | ||||
""" | |||||
r""" | |||||
对Coreference resolution问题进行处理,得到文章种类/说话者/字符级信息/序列长度。 | 对Coreference resolution问题进行处理,得到文章种类/说话者/字符级信息/序列长度。 | ||||
处理完成后数据包含文章类别、speaker信息、句子信息、句子对应的index、char、句子长度、target: | 处理完成后数据包含文章类别、speaker信息、句子信息、句子对应的index、char、句子长度、target: | ||||
@@ -45,7 +45,7 @@ class CoReferencePipe(Pipe): | |||||
self.config = config | self.config = config | ||||
def process(self, data_bundle: DataBundle): | def process(self, data_bundle: DataBundle): | ||||
""" | |||||
r""" | |||||
对load进来的数据进一步处理原始数据包含:raw_key,raw_speaker,raw_words,raw_clusters | 对load进来的数据进一步处理原始数据包含:raw_key,raw_speaker,raw_words,raw_clusters | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"CWSPipe" | "CWSPipe" | ||||
@@ -15,7 +15,7 @@ from ...core.const import Const | |||||
def _word_lens_to_bmes(word_lens): | def _word_lens_to_bmes(word_lens): | ||||
""" | |||||
r""" | |||||
:param list word_lens: List[int], 每个词语的长度 | :param list word_lens: List[int], 每个词语的长度 | ||||
:return: List[str], BMES的序列 | :return: List[str], BMES的序列 | ||||
@@ -32,7 +32,7 @@ def _word_lens_to_bmes(word_lens): | |||||
def _word_lens_to_segapp(word_lens): | def _word_lens_to_segapp(word_lens): | ||||
""" | |||||
r""" | |||||
:param list word_lens: List[int], 每个词语的长度 | :param list word_lens: List[int], 每个词语的长度 | ||||
:return: List[str], BMES的序列 | :return: List[str], BMES的序列 | ||||
@@ -48,7 +48,7 @@ def _word_lens_to_segapp(word_lens): | |||||
def _alpha_span_to_special_tag(span): | def _alpha_span_to_special_tag(span): | ||||
""" | |||||
r""" | |||||
将span替换成特殊的字符 | 将span替换成特殊的字符 | ||||
:param str span: | :param str span: | ||||
@@ -63,7 +63,7 @@ def _alpha_span_to_special_tag(span): | |||||
def _find_and_replace_alpha_spans(line): | def _find_and_replace_alpha_spans(line): | ||||
""" | |||||
r""" | |||||
传入原始句子,替换其中的字母为特殊标记 | 传入原始句子,替换其中的字母为特殊标记 | ||||
:param str line:原始数据 | :param str line:原始数据 | ||||
@@ -82,7 +82,7 @@ def _find_and_replace_alpha_spans(line): | |||||
def _digit_span_to_special_tag(span): | def _digit_span_to_special_tag(span): | ||||
""" | |||||
r""" | |||||
:param str span: 需要替换的str | :param str span: 需要替换的str | ||||
:return: | :return: | ||||
@@ -108,7 +108,7 @@ def _digit_span_to_special_tag(span): | |||||
def _find_and_replace_digit_spans(line): | def _find_and_replace_digit_spans(line): | ||||
""" | |||||
r""" | |||||
only consider words start with number, contains '.', characters. | only consider words start with number, contains '.', characters. | ||||
If ends with space, will be processed | If ends with space, will be processed | ||||
@@ -134,7 +134,7 @@ def _find_and_replace_digit_spans(line): | |||||
class CWSPipe(Pipe): | class CWSPipe(Pipe): | ||||
""" | |||||
r""" | |||||
对CWS数据进行预处理, 处理之后的数据,具备以下的结构 | 对CWS数据进行预处理, 处理之后的数据,具备以下的结构 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -158,7 +158,7 @@ class CWSPipe(Pipe): | |||||
""" | """ | ||||
def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False): | def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False): | ||||
""" | |||||
r""" | |||||
:param str,None dataset_name: 支持'pku', 'msra', 'cityu', 'as', None | :param str,None dataset_name: 支持'pku', 'msra', 'cityu', 'as', None | ||||
:param str encoding_type: 可以选择'bmes', 'segapp'两种。"我 来自 复旦大学...", bmes的tag为[S, B, E, B, M, M, E...]; segapp | :param str encoding_type: 可以选择'bmes', 'segapp'两种。"我 来自 复旦大学...", bmes的tag为[S, B, E, B, M, M, E...]; segapp | ||||
@@ -178,7 +178,7 @@ class CWSPipe(Pipe): | |||||
self.replace_num_alpha = replace_num_alpha | self.replace_num_alpha = replace_num_alpha | ||||
def _tokenize(self, data_bundle): | def _tokenize(self, data_bundle): | ||||
""" | |||||
r""" | |||||
将data_bundle中的'chars'列切分成一个一个的word. | 将data_bundle中的'chars'列切分成一个一个的word. | ||||
例如输入是"共同 创造 美好.."->[[共, 同], [创, 造], [...], ] | 例如输入是"共同 创造 美好.."->[[共, 同], [创, 造], [...], ] | ||||
@@ -216,7 +216,7 @@ class CWSPipe(Pipe): | |||||
return data_bundle | return data_bundle | ||||
def process(self, data_bundle: DataBundle) -> DataBundle: | def process(self, data_bundle: DataBundle) -> DataBundle: | ||||
""" | |||||
r""" | |||||
可以处理的DataSet需要包含raw_words列 | 可以处理的DataSet需要包含raw_words列 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -268,7 +268,7 @@ class CWSPipe(Pipe): | |||||
return data_bundle | return data_bundle | ||||
def process_from_file(self, paths=None) -> DataBundle: | def process_from_file(self, paths=None) -> DataBundle: | ||||
""" | |||||
r""" | |||||
:param str paths: | :param str paths: | ||||
:return: | :return: | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"MatchingBertPipe", | "MatchingBertPipe", | ||||
@@ -37,7 +37,7 @@ from ...core.vocabulary import Vocabulary | |||||
class MatchingBertPipe(Pipe): | class MatchingBertPipe(Pipe): | ||||
""" | |||||
r""" | |||||
Matching任务的Bert pipe,输出的DataSet将包含以下的field | Matching任务的Bert pipe,输出的DataSet将包含以下的field | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -65,7 +65,7 @@ class MatchingBertPipe(Pipe): | |||||
""" | """ | ||||
def __init__(self, lower=False, tokenizer: str = 'raw'): | def __init__(self, lower=False, tokenizer: str = 'raw'): | ||||
""" | |||||
r""" | |||||
:param bool lower: 是否将word小写化。 | :param bool lower: 是否将word小写化。 | ||||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | ||||
@@ -76,7 +76,7 @@ class MatchingBertPipe(Pipe): | |||||
self.tokenizer = get_tokenizer(tokenize_method=tokenizer) | self.tokenizer = get_tokenizer(tokenize_method=tokenizer) | ||||
def _tokenize(self, data_bundle, field_names, new_field_names): | def _tokenize(self, data_bundle, field_names, new_field_names): | ||||
""" | |||||
r""" | |||||
:param DataBundle data_bundle: DataBundle. | :param DataBundle data_bundle: DataBundle. | ||||
:param list field_names: List[str], 需要tokenize的field名称 | :param list field_names: List[str], 需要tokenize的field名称 | ||||
@@ -90,7 +90,7 @@ class MatchingBertPipe(Pipe): | |||||
return data_bundle | return data_bundle | ||||
def process(self, data_bundle): | def process(self, data_bundle): | ||||
""" | |||||
r""" | |||||
输入的data_bundle中的dataset需要具有以下结构: | 输入的data_bundle中的dataset需要具有以下结构: | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -201,7 +201,7 @@ class MNLIBertPipe(MatchingBertPipe): | |||||
class MatchingPipe(Pipe): | class MatchingPipe(Pipe): | ||||
""" | |||||
r""" | |||||
Matching任务的Pipe。输出的DataSet将包含以下的field | Matching任务的Pipe。输出的DataSet将包含以下的field | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -229,7 +229,7 @@ class MatchingPipe(Pipe): | |||||
""" | """ | ||||
def __init__(self, lower=False, tokenizer: str = 'raw'): | def __init__(self, lower=False, tokenizer: str = 'raw'): | ||||
""" | |||||
r""" | |||||
:param bool lower: 是否将所有raw_words转为小写。 | :param bool lower: 是否将所有raw_words转为小写。 | ||||
:param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. spacy是使用spacy切分,raw就是用空格切分。 | :param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. spacy是使用spacy切分,raw就是用空格切分。 | ||||
@@ -240,7 +240,7 @@ class MatchingPipe(Pipe): | |||||
self.tokenizer = get_tokenizer(tokenize_method=tokenizer) | self.tokenizer = get_tokenizer(tokenize_method=tokenizer) | ||||
def _tokenize(self, data_bundle, field_names, new_field_names): | def _tokenize(self, data_bundle, field_names, new_field_names): | ||||
""" | |||||
r""" | |||||
:param ~fastNLP.DataBundle data_bundle: DataBundle. | :param ~fastNLP.DataBundle data_bundle: DataBundle. | ||||
:param list field_names: List[str], 需要tokenize的field名称 | :param list field_names: List[str], 需要tokenize的field名称 | ||||
@@ -254,7 +254,7 @@ class MatchingPipe(Pipe): | |||||
return data_bundle | return data_bundle | ||||
def process(self, data_bundle): | def process(self, data_bundle): | ||||
""" | |||||
r""" | |||||
接受的DataBundle中的DataSet应该具有以下的field, target列可以没有 | 接受的DataBundle中的DataSet应该具有以下的field, target列可以没有 | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -433,7 +433,7 @@ class GranularizePipe(Pipe): | |||||
self.task = task | self.task = task | ||||
def _granularize(self, data_bundle, tag_map): | def _granularize(self, data_bundle, tag_map): | ||||
""" | |||||
r""" | |||||
该函数对data_bundle中'target'列中的内容进行转换。 | 该函数对data_bundle中'target'列中的内容进行转换。 | ||||
:param data_bundle: | :param data_bundle: | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"Pipe", | "Pipe", | ||||
@@ -8,7 +8,7 @@ from .. import DataBundle | |||||
class Pipe: | class Pipe: | ||||
""" | |||||
r""" | |||||
Pipe是fastNLP中用于处理DataBundle的类,但实际是处理DataBundle中的DataSet。所有Pipe都会在其process()函数的文档中指出该Pipe可处理的DataSet应该具备怎样的格式;在Pipe | Pipe是fastNLP中用于处理DataBundle的类,但实际是处理DataBundle中的DataSet。所有Pipe都会在其process()函数的文档中指出该Pipe可处理的DataSet应该具备怎样的格式;在Pipe | ||||
文档中说明该Pipe返回后DataSet的格式以及其field的信息;以及新增的Vocabulary的信息。 | 文档中说明该Pipe返回后DataSet的格式以及其field的信息;以及新增的Vocabulary的信息。 | ||||
@@ -23,7 +23,7 @@ class Pipe: | |||||
""" | """ | ||||
def process(self, data_bundle: DataBundle) -> DataBundle: | def process(self, data_bundle: DataBundle) -> DataBundle: | ||||
""" | |||||
r""" | |||||
对输入的DataBundle进行处理,然后返回该DataBundle。 | 对输入的DataBundle进行处理,然后返回该DataBundle。 | ||||
:param ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象 | :param ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象 | ||||
@@ -32,7 +32,7 @@ class Pipe: | |||||
raise NotImplementedError | raise NotImplementedError | ||||
def process_from_file(self, paths) -> DataBundle: | def process_from_file(self, paths) -> DataBundle: | ||||
""" | |||||
r""" | |||||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | ||||
:param paths: | :param paths: | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
本文件中的Pipe主要用于处理问答任务的数据。 | 本文件中的Pipe主要用于处理问答任务的数据。 | ||||
""" | """ | ||||
@@ -17,7 +17,7 @@ __all__ = ['CMRC2018BertPipe'] | |||||
def _concat_clip(data_bundle, max_len, concat_field_name='raw_chars'): | def _concat_clip(data_bundle, max_len, concat_field_name='raw_chars'): | ||||
""" | |||||
r""" | |||||
处理data_bundle中的DataSet,将context与question按照character进行tokenize,然后使用[SEP]将两者连接起来。 | 处理data_bundle中的DataSet,将context与question按照character进行tokenize,然后使用[SEP]将两者连接起来。 | ||||
会新增field: context_len(int), raw_words(list[str]), target_start(int), target_end(int)其中target_start | 会新增field: context_len(int), raw_words(list[str]), target_start(int), target_end(int)其中target_start | ||||
@@ -78,7 +78,7 @@ def _concat_clip(data_bundle, max_len, concat_field_name='raw_chars'): | |||||
class CMRC2018BertPipe(Pipe): | class CMRC2018BertPipe(Pipe): | ||||
""" | |||||
r""" | |||||
处理之后的DataSet将新增以下的field(传入的field仍然保留) | 处理之后的DataSet将新增以下的field(传入的field仍然保留) | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -111,7 +111,7 @@ class CMRC2018BertPipe(Pipe): | |||||
self.max_len = max_len | self.max_len = max_len | ||||
def process(self, data_bundle: DataBundle) -> DataBundle: | def process(self, data_bundle: DataBundle) -> DataBundle: | ||||
""" | |||||
r""" | |||||
传入的DataSet应该具备以下的field | 传入的DataSet应该具备以下的field | ||||
.. csv-table:: | .. csv-table:: | ||||
@@ -1,197 +1,197 @@ | |||||
"""undocumented""" | |||||
import os | |||||
import numpy as np | |||||
from .pipe import Pipe | |||||
from .utils import _drop_empty_instance | |||||
from ..loader.summarization import ExtCNNDMLoader | |||||
from ..data_bundle import DataBundle | |||||
from ...core.const import Const | |||||
from ...core.vocabulary import Vocabulary | |||||
from ...core._logger import logger | |||||
WORD_PAD = "[PAD]" | |||||
WORD_UNK = "[UNK]" | |||||
DOMAIN_UNK = "X" | |||||
TAG_UNK = "X" | |||||
class ExtCNNDMPipe(Pipe): | |||||
""" | |||||
对CNN/Daily Mail数据进行适用于extractive summarization task的预处理,预处理之后的数据,具备以下结构: | |||||
.. csv-table:: | |||||
:header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" | |||||
""" | |||||
def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False): | |||||
""" | |||||
:param vocab_size: int, 词表大小 | |||||
:param sent_max_len: int, 句子最大长度,不足的句子将padding,超出的将截断 | |||||
:param doc_max_timesteps: int, 文章最多句子个数,不足的将padding,超出的将截断 | |||||
:param vocab_path: str, 外部词表路径 | |||||
:param domain: bool, 是否需要建立domain词表 | |||||
""" | |||||
self.vocab_size = vocab_size | |||||
self.vocab_path = vocab_path | |||||
self.sent_max_len = sent_max_len | |||||
self.doc_max_timesteps = doc_max_timesteps | |||||
self.domain = domain | |||||
def process(self, data_bundle: DataBundle): | |||||
""" | |||||
传入的DataSet应该具备如下的结构 | |||||
.. csv-table:: | |||||
:header: "text", "summary", "label", "publication" | |||||
["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm" | |||||
["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm" | |||||
["..."], ["..."], [], "cnndm" | |||||
:param data_bundle: | |||||
:return: 处理得到的数据包括 | |||||
.. csv-table:: | |||||
:header: "text_wd", "words", "seq_len", "target" | |||||
[["I","got",..."."],...,["..."]], [[54,89,...,5],...,[9,43,..,0]], [1,1,...,0], [0,1,...,0] | |||||
[["Don't","waste",...,"."],...,["..."]], [[5234,653,...,5],...,[87,234,..,0]], [1,1,...,0], [1,1,...,0] | |||||
[[""],...,[""]], [[],...,[]], [], [] | |||||
""" | |||||
if self.vocab_path is None: | |||||
error_msg = 'vocab file is not defined!' | |||||
logger.error(error_msg) | |||||
raise RuntimeError(error_msg) | |||||
data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||||
data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||||
data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') | |||||
data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET) | |||||
data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT) | |||||
# db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") | |||||
# pad document | |||||
data_bundle.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT) | |||||
data_bundle.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN) | |||||
data_bundle.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET) | |||||
data_bundle = _drop_empty_instance(data_bundle, "label") | |||||
# set input and target | |||||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
data_bundle.set_target(Const.TARGET, Const.INPUT_LEN) | |||||
# print("[INFO] Load existing vocab from %s!" % self.vocab_path) | |||||
word_list = [] | |||||
with open(self.vocab_path, 'r', encoding='utf8') as vocab_f: | |||||
cnt = 2 # pad and unk | |||||
for line in vocab_f: | |||||
pieces = line.split("\t") | |||||
word_list.append(pieces[0]) | |||||
cnt += 1 | |||||
if cnt > self.vocab_size: | |||||
break | |||||
vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||||
vocabs.add_word_lst(word_list) | |||||
vocabs.build_vocab() | |||||
data_bundle.set_vocab(vocabs, "vocab") | |||||
if self.domain is True: | |||||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||||
domaindict.from_dataset(data_bundle.get_dataset("train"), field_name="publication") | |||||
data_bundle.set_vocab(domaindict, "domain") | |||||
return data_bundle | |||||
def process_from_file(self, paths=None): | |||||
""" | |||||
:param paths: dict or string | |||||
:return: DataBundle | |||||
""" | |||||
loader = ExtCNNDMLoader() | |||||
if self.vocab_path is None: | |||||
if paths is None: | |||||
paths = loader.download() | |||||
if not os.path.isdir(paths): | |||||
error_msg = 'vocab file is not defined!' | |||||
logger.error(error_msg) | |||||
raise RuntimeError(error_msg) | |||||
self.vocab_path = os.path.join(paths, 'vocab') | |||||
db = loader.load(paths=paths) | |||||
db = self.process(db) | |||||
for ds in db.datasets.values(): | |||||
db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||||
return db | |||||
def _lower_text(text_list): | |||||
return [text.lower() for text in text_list] | |||||
def _split_list(text_list): | |||||
return [text.split() for text in text_list] | |||||
def _convert_label(label, sent_len): | |||||
np_label = np.zeros(sent_len, dtype=int) | |||||
if label != []: | |||||
np_label[np.array(label)] = 1 | |||||
return np_label.tolist() | |||||
def _pad_sent(text_wd, sent_max_len): | |||||
pad_text_wd = [] | |||||
for sent_wd in text_wd: | |||||
if len(sent_wd) < sent_max_len: | |||||
pad_num = sent_max_len - len(sent_wd) | |||||
sent_wd.extend([WORD_PAD] * pad_num) | |||||
else: | |||||
sent_wd = sent_wd[:sent_max_len] | |||||
pad_text_wd.append(sent_wd) | |||||
return pad_text_wd | |||||
def _token_mask(text_wd, sent_max_len): | |||||
token_mask_list = [] | |||||
for sent_wd in text_wd: | |||||
token_num = len(sent_wd) | |||||
if token_num < sent_max_len: | |||||
mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||||
else: | |||||
mask = [1] * sent_max_len | |||||
token_mask_list.append(mask) | |||||
return token_mask_list | |||||
def _pad_label(label, doc_max_timesteps): | |||||
text_len = len(label) | |||||
if text_len < doc_max_timesteps: | |||||
pad_label = label + [0] * (doc_max_timesteps - text_len) | |||||
else: | |||||
pad_label = label[:doc_max_timesteps] | |||||
return pad_label | |||||
def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): | |||||
text_len = len(text_wd) | |||||
if text_len < doc_max_timesteps: | |||||
padding = [WORD_PAD] * sent_max_len | |||||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||||
else: | |||||
pad_text = text_wd[:doc_max_timesteps] | |||||
return pad_text | |||||
def _sent_mask(text_wd, doc_max_timesteps): | |||||
text_len = len(text_wd) | |||||
if text_len < doc_max_timesteps: | |||||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||||
else: | |||||
sent_mask = [1] * doc_max_timesteps | |||||
return sent_mask | |||||
r"""undocumented""" | |||||
import os | |||||
import numpy as np | |||||
from .pipe import Pipe | |||||
from .utils import _drop_empty_instance | |||||
from ..loader.summarization import ExtCNNDMLoader | |||||
from ..data_bundle import DataBundle | |||||
from ...core.const import Const | |||||
from ...core.vocabulary import Vocabulary | |||||
from ...core._logger import logger | |||||
WORD_PAD = "[PAD]" | |||||
WORD_UNK = "[UNK]" | |||||
DOMAIN_UNK = "X" | |||||
TAG_UNK = "X" | |||||
class ExtCNNDMPipe(Pipe): | |||||
r""" | |||||
对CNN/Daily Mail数据进行适用于extractive summarization task的预处理,预处理之后的数据,具备以下结构: | |||||
.. csv-table:: | |||||
:header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" | |||||
""" | |||||
def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False): | |||||
r""" | |||||
:param vocab_size: int, 词表大小 | |||||
:param sent_max_len: int, 句子最大长度,不足的句子将padding,超出的将截断 | |||||
:param doc_max_timesteps: int, 文章最多句子个数,不足的将padding,超出的将截断 | |||||
:param vocab_path: str, 外部词表路径 | |||||
:param domain: bool, 是否需要建立domain词表 | |||||
""" | |||||
self.vocab_size = vocab_size | |||||
self.vocab_path = vocab_path | |||||
self.sent_max_len = sent_max_len | |||||
self.doc_max_timesteps = doc_max_timesteps | |||||
self.domain = domain | |||||
def process(self, data_bundle: DataBundle): | |||||
r""" | |||||
传入的DataSet应该具备如下的结构 | |||||
.. csv-table:: | |||||
:header: "text", "summary", "label", "publication" | |||||
["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm" | |||||
["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm" | |||||
["..."], ["..."], [], "cnndm" | |||||
:param data_bundle: | |||||
:return: 处理得到的数据包括 | |||||
.. csv-table:: | |||||
:header: "text_wd", "words", "seq_len", "target" | |||||
[["I","got",..."."],...,["..."]], [[54,89,...,5],...,[9,43,..,0]], [1,1,...,0], [0,1,...,0] | |||||
[["Don't","waste",...,"."],...,["..."]], [[5234,653,...,5],...,[87,234,..,0]], [1,1,...,0], [1,1,...,0] | |||||
[[""],...,[""]], [[],...,[]], [], [] | |||||
""" | |||||
if self.vocab_path is None: | |||||
error_msg = 'vocab file is not defined!' | |||||
logger.error(error_msg) | |||||
raise RuntimeError(error_msg) | |||||
data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||||
data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||||
data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') | |||||
data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name=Const.TARGET) | |||||
data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name=Const.INPUT) | |||||
# db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") | |||||
# pad document | |||||
data_bundle.apply(lambda x: _pad_doc(x[Const.INPUT], self.sent_max_len, self.doc_max_timesteps), new_field_name=Const.INPUT) | |||||
data_bundle.apply(lambda x: _sent_mask(x[Const.INPUT], self.doc_max_timesteps), new_field_name=Const.INPUT_LEN) | |||||
data_bundle.apply(lambda x: _pad_label(x[Const.TARGET], self.doc_max_timesteps), new_field_name=Const.TARGET) | |||||
data_bundle = _drop_empty_instance(data_bundle, "label") | |||||
# set input and target | |||||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
data_bundle.set_target(Const.TARGET, Const.INPUT_LEN) | |||||
# print("[INFO] Load existing vocab from %s!" % self.vocab_path) | |||||
word_list = [] | |||||
with open(self.vocab_path, 'r', encoding='utf8') as vocab_f: | |||||
cnt = 2 # pad and unk | |||||
for line in vocab_f: | |||||
pieces = line.split("\t") | |||||
word_list.append(pieces[0]) | |||||
cnt += 1 | |||||
if cnt > self.vocab_size: | |||||
break | |||||
vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||||
vocabs.add_word_lst(word_list) | |||||
vocabs.build_vocab() | |||||
data_bundle.set_vocab(vocabs, "vocab") | |||||
if self.domain is True: | |||||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||||
domaindict.from_dataset(data_bundle.get_dataset("train"), field_name="publication") | |||||
data_bundle.set_vocab(domaindict, "domain") | |||||
return data_bundle | |||||
def process_from_file(self, paths=None): | |||||
r""" | |||||
:param paths: dict or string | |||||
:return: DataBundle | |||||
""" | |||||
loader = ExtCNNDMLoader() | |||||
if self.vocab_path is None: | |||||
if paths is None: | |||||
paths = loader.download() | |||||
if not os.path.isdir(paths): | |||||
error_msg = 'vocab file is not defined!' | |||||
logger.error(error_msg) | |||||
raise RuntimeError(error_msg) | |||||
self.vocab_path = os.path.join(paths, 'vocab') | |||||
db = loader.load(paths=paths) | |||||
db = self.process(db) | |||||
for ds in db.datasets.values(): | |||||
db.get_vocab("vocab").index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||||
return db | |||||
def _lower_text(text_list): | |||||
return [text.lower() for text in text_list] | |||||
def _split_list(text_list): | |||||
return [text.split() for text in text_list] | |||||
def _convert_label(label, sent_len): | |||||
np_label = np.zeros(sent_len, dtype=int) | |||||
if label != []: | |||||
np_label[np.array(label)] = 1 | |||||
return np_label.tolist() | |||||
def _pad_sent(text_wd, sent_max_len): | |||||
pad_text_wd = [] | |||||
for sent_wd in text_wd: | |||||
if len(sent_wd) < sent_max_len: | |||||
pad_num = sent_max_len - len(sent_wd) | |||||
sent_wd.extend([WORD_PAD] * pad_num) | |||||
else: | |||||
sent_wd = sent_wd[:sent_max_len] | |||||
pad_text_wd.append(sent_wd) | |||||
return pad_text_wd | |||||
def _token_mask(text_wd, sent_max_len): | |||||
token_mask_list = [] | |||||
for sent_wd in text_wd: | |||||
token_num = len(sent_wd) | |||||
if token_num < sent_max_len: | |||||
mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||||
else: | |||||
mask = [1] * sent_max_len | |||||
token_mask_list.append(mask) | |||||
return token_mask_list | |||||
def _pad_label(label, doc_max_timesteps): | |||||
text_len = len(label) | |||||
if text_len < doc_max_timesteps: | |||||
pad_label = label + [0] * (doc_max_timesteps - text_len) | |||||
else: | |||||
pad_label = label[:doc_max_timesteps] | |||||
return pad_label | |||||
def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): | |||||
text_len = len(text_wd) | |||||
if text_len < doc_max_timesteps: | |||||
padding = [WORD_PAD] * sent_max_len | |||||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||||
else: | |||||
pad_text = text_wd[:doc_max_timesteps] | |||||
return pad_text | |||||
def _sent_mask(text_wd, doc_max_timesteps): | |||||
text_len = len(text_wd) | |||||
if text_len < doc_max_timesteps: | |||||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||||
else: | |||||
sent_mask = [1] * doc_max_timesteps | |||||
return sent_mask | |||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"iob2", | "iob2", | ||||
@@ -15,7 +15,7 @@ from ...core._logger import logger | |||||
def iob2(tags: List[str]) -> List[str]: | def iob2(tags: List[str]) -> List[str]: | ||||
""" | |||||
r""" | |||||
检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见 | 检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见 | ||||
https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format | https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format | ||||
@@ -39,7 +39,7 @@ def iob2(tags: List[str]) -> List[str]: | |||||
def iob2bioes(tags: List[str]) -> List[str]: | def iob2bioes(tags: List[str]) -> List[str]: | ||||
""" | |||||
r""" | |||||
将iob的tag转换为bioes编码 | 将iob的tag转换为bioes编码 | ||||
:param tags: | :param tags: | ||||
:return: | :return: | ||||
@@ -66,7 +66,7 @@ def iob2bioes(tags: List[str]) -> List[str]: | |||||
def get_tokenizer(tokenize_method: str, lang='en'): | def get_tokenizer(tokenize_method: str, lang='en'): | ||||
""" | |||||
r""" | |||||
:param str tokenize_method: 获取tokenzier方法 | :param str tokenize_method: 获取tokenzier方法 | ||||
:param str lang: 语言,当前仅支持en | :param str lang: 语言,当前仅支持en | ||||
@@ -100,7 +100,7 @@ def _raw_split(sent): | |||||
def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Const.TARGET): | def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Const.TARGET): | ||||
""" | |||||
r""" | |||||
在dataset中的field_name列建立词表,Const.TARGET列建立词表,并把词表加入到data_bundle中。 | 在dataset中的field_name列建立词表,Const.TARGET列建立词表,并把词表加入到data_bundle中。 | ||||
:param ~fastNLP.DataBundle data_bundle: | :param ~fastNLP.DataBundle data_bundle: | ||||
@@ -143,7 +143,7 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con | |||||
def _add_words_field(data_bundle, lower=False): | def _add_words_field(data_bundle, lower=False): | ||||
""" | |||||
r""" | |||||
给data_bundle中的dataset中复制一列words. 并根据lower参数判断是否需要小写化 | 给data_bundle中的dataset中复制一列words. 并根据lower参数判断是否需要小写化 | ||||
:param data_bundle: | :param data_bundle: | ||||
@@ -159,7 +159,7 @@ def _add_words_field(data_bundle, lower=False): | |||||
def _add_chars_field(data_bundle, lower=False): | def _add_chars_field(data_bundle, lower=False): | ||||
""" | |||||
r""" | |||||
给data_bundle中的dataset中复制一列chars. 并根据lower参数判断是否需要小写化 | 给data_bundle中的dataset中复制一列chars. 并根据lower参数判断是否需要小写化 | ||||
:param data_bundle: | :param data_bundle: | ||||
@@ -175,7 +175,7 @@ def _add_chars_field(data_bundle, lower=False): | |||||
def _drop_empty_instance(data_bundle, field_name): | def _drop_empty_instance(data_bundle, field_name): | ||||
""" | |||||
r""" | |||||
删除data_bundle的DataSet中存在的某个field为空的情况 | 删除data_bundle的DataSet中存在的某个field为空的情况 | ||||
:param ~fastNLP.DataBundle data_bundle: | :param ~fastNLP.DataBundle data_bundle: | ||||
@@ -201,7 +201,7 @@ def _drop_empty_instance(data_bundle, field_name): | |||||
def _granularize(data_bundle, tag_map): | def _granularize(data_bundle, tag_map): | ||||
""" | |||||
r""" | |||||
该函数对data_bundle中'target'列中的内容进行转换。 | 该函数对data_bundle中'target'列中的内容进行转换。 | ||||
:param data_bundle: | :param data_bundle: | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -15,7 +15,7 @@ from ..core import logger | |||||
def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: | def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: | ||||
""" | |||||
r""" | |||||
检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: | 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: | ||||
{ | { | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models.CNNText` 、 | fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models.CNNText` 、 | ||||
:class:`~fastNLP.models.SeqLabeling` 等完整的模型,以供用户直接使用。 | :class:`~fastNLP.models.SeqLabeling` 等完整的模型,以供用户直接使用。 | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [] | __all__ = [] | ||||
@@ -8,7 +8,7 @@ from ..modules.decoder.mlp import MLP | |||||
class BaseModel(torch.nn.Module): | class BaseModel(torch.nn.Module): | ||||
"""Base PyTorch model for all models. | |||||
r"""Base PyTorch model for all models. | |||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
@@ -22,7 +22,7 @@ class BaseModel(torch.nn.Module): | |||||
class NaiveClassifier(BaseModel): | class NaiveClassifier(BaseModel): | ||||
""" | |||||
r""" | |||||
一个简单的分类器例子,可用于各种测试 | 一个简单的分类器例子,可用于各种测试 | ||||
""" | """ | ||||
def __init__(self, in_feature_dim, out_feature_dim): | def __init__(self, in_feature_dim, out_feature_dim): | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
fastNLP提供了BERT应用到五个下游任务的模型代码,可以直接调用。这五个任务分别为 | fastNLP提供了BERT应用到五个下游任务的模型代码,可以直接调用。这五个任务分别为 | ||||
- 文本分类任务: :class:`~fastNLP.models.BertForSequenceClassification` | - 文本分类任务: :class:`~fastNLP.models.BertForSequenceClassification` | ||||
@@ -43,12 +43,12 @@ from ..embeddings import BertEmbedding | |||||
class BertForSequenceClassification(BaseModel): | class BertForSequenceClassification(BaseModel): | ||||
""" | |||||
r""" | |||||
BERT model for classification. | BERT model for classification. | ||||
""" | """ | ||||
def __init__(self, embed: BertEmbedding, num_labels: int=2, dropout=0.1): | def __init__(self, embed: BertEmbedding, num_labels: int=2, dropout=0.1): | ||||
""" | |||||
r""" | |||||
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | :param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | ||||
:param int num_labels: 文本分类类别数目,默认值为2. | :param int num_labels: 文本分类类别数目,默认值为2. | ||||
@@ -69,7 +69,7 @@ class BertForSequenceClassification(BaseModel): | |||||
warnings.warn(warn_msg) | warnings.warn(warn_msg) | ||||
def forward(self, words): | def forward(self, words): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, seq_len] | :param torch.LongTensor words: [batch_size, seq_len] | ||||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels] | :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels] | ||||
""" | """ | ||||
@@ -80,7 +80,7 @@ class BertForSequenceClassification(BaseModel): | |||||
return {Const.OUTPUT: logits} | return {Const.OUTPUT: logits} | ||||
def predict(self, words): | def predict(self, words): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, seq_len] | :param torch.LongTensor words: [batch_size, seq_len] | ||||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size] | :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size] | ||||
""" | """ | ||||
@@ -89,12 +89,12 @@ class BertForSequenceClassification(BaseModel): | |||||
class BertForSentenceMatching(BaseModel): | class BertForSentenceMatching(BaseModel): | ||||
""" | |||||
r""" | |||||
BERT model for sentence matching. | BERT model for sentence matching. | ||||
""" | """ | ||||
def __init__(self, embed: BertEmbedding, num_labels: int=2, dropout=0.1): | def __init__(self, embed: BertEmbedding, num_labels: int=2, dropout=0.1): | ||||
""" | |||||
r""" | |||||
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | :param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | ||||
:param int num_labels: Matching任务类别数目,默认值为2. | :param int num_labels: Matching任务类别数目,默认值为2. | ||||
@@ -114,7 +114,7 @@ class BertForSentenceMatching(BaseModel): | |||||
warnings.warn(warn_msg) | warnings.warn(warn_msg) | ||||
def forward(self, words): | def forward(self, words): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, seq_len] | :param torch.LongTensor words: [batch_size, seq_len] | ||||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels] | :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels] | ||||
""" | """ | ||||
@@ -125,7 +125,7 @@ class BertForSentenceMatching(BaseModel): | |||||
return {Const.OUTPUT: logits} | return {Const.OUTPUT: logits} | ||||
def predict(self, words): | def predict(self, words): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, seq_len] | :param torch.LongTensor words: [batch_size, seq_len] | ||||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size] | :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size] | ||||
""" | """ | ||||
@@ -134,12 +134,12 @@ class BertForSentenceMatching(BaseModel): | |||||
class BertForMultipleChoice(BaseModel): | class BertForMultipleChoice(BaseModel): | ||||
""" | |||||
r""" | |||||
BERT model for multiple choice. | BERT model for multiple choice. | ||||
""" | """ | ||||
def __init__(self, embed: BertEmbedding, num_choices=2, dropout=0.1): | def __init__(self, embed: BertEmbedding, num_choices=2, dropout=0.1): | ||||
""" | |||||
r""" | |||||
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | :param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | ||||
:param int num_choices: 多选任务选项数目,默认值为2. | :param int num_choices: 多选任务选项数目,默认值为2. | ||||
@@ -160,7 +160,7 @@ class BertForMultipleChoice(BaseModel): | |||||
warnings.warn(warn_msg) | warnings.warn(warn_msg) | ||||
def forward(self, words): | def forward(self, words): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, num_choices, seq_len] | :param torch.LongTensor words: [batch_size, num_choices, seq_len] | ||||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size, num_choices] | :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size, num_choices] | ||||
""" | """ | ||||
@@ -175,7 +175,7 @@ class BertForMultipleChoice(BaseModel): | |||||
return {Const.OUTPUT: reshaped_logits} | return {Const.OUTPUT: reshaped_logits} | ||||
def predict(self, words): | def predict(self, words): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, num_choices, seq_len] | :param torch.LongTensor words: [batch_size, num_choices, seq_len] | ||||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size] | :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size] | ||||
""" | """ | ||||
@@ -184,12 +184,12 @@ class BertForMultipleChoice(BaseModel): | |||||
class BertForTokenClassification(BaseModel): | class BertForTokenClassification(BaseModel): | ||||
""" | |||||
r""" | |||||
BERT model for token classification. | BERT model for token classification. | ||||
""" | """ | ||||
def __init__(self, embed: BertEmbedding, num_labels, dropout=0.1): | def __init__(self, embed: BertEmbedding, num_labels, dropout=0.1): | ||||
""" | |||||
r""" | |||||
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | :param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | ||||
:param int num_labels: 序列标注标签数目,无默认值. | :param int num_labels: 序列标注标签数目,无默认值. | ||||
@@ -210,7 +210,7 @@ class BertForTokenClassification(BaseModel): | |||||
warnings.warn(warn_msg) | warnings.warn(warn_msg) | ||||
def forward(self, words): | def forward(self, words): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, seq_len] | :param torch.LongTensor words: [batch_size, seq_len] | ||||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, seq_len, num_labels] | :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, seq_len, num_labels] | ||||
""" | """ | ||||
@@ -221,7 +221,7 @@ class BertForTokenClassification(BaseModel): | |||||
return {Const.OUTPUT: logits} | return {Const.OUTPUT: logits} | ||||
def predict(self, words): | def predict(self, words): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, seq_len] | :param torch.LongTensor words: [batch_size, seq_len] | ||||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size, seq_len] | :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size, seq_len] | ||||
""" | """ | ||||
@@ -230,12 +230,12 @@ class BertForTokenClassification(BaseModel): | |||||
class BertForQuestionAnswering(BaseModel): | class BertForQuestionAnswering(BaseModel): | ||||
""" | |||||
r""" | |||||
用于做Q&A的Bert模型,如果是Squad2.0请将BertEmbedding的include_cls_sep设置为True,Squad1.0或CMRC则设置为False | 用于做Q&A的Bert模型,如果是Squad2.0请将BertEmbedding的include_cls_sep设置为True,Squad1.0或CMRC则设置为False | ||||
""" | """ | ||||
def __init__(self, embed: BertEmbedding): | def __init__(self, embed: BertEmbedding): | ||||
""" | |||||
r""" | |||||
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | :param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder). | ||||
:param int num_labels: 抽取式QA列数,默认值为2(即第一列为start_span, 第二列为end_span). | :param int num_labels: 抽取式QA列数,默认值为2(即第一列为start_span, 第二列为end_span). | ||||
@@ -246,7 +246,7 @@ class BertForQuestionAnswering(BaseModel): | |||||
self.qa_outputs = nn.Linear(self.bert.embedding_dim, 2) | self.qa_outputs = nn.Linear(self.bert.embedding_dim, 2) | ||||
def forward(self, words): | def forward(self, words): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, seq_len] | :param torch.LongTensor words: [batch_size, seq_len] | ||||
:return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len + 2] | :return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len + 2] | ||||
""" | """ | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
Biaffine Dependency Parser 的 Pytorch 实现. | Biaffine Dependency Parser 的 Pytorch 实现. | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
@@ -26,7 +26,7 @@ from ..modules.utils import initial_parameter | |||||
def _mst(scores): | def _mst(scores): | ||||
""" | |||||
r""" | |||||
with some modification to support parser output for MST decoding | with some modification to support parser output for MST decoding | ||||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | ||||
""" | """ | ||||
@@ -85,7 +85,7 @@ def _mst(scores): | |||||
def _find_cycle(vertices, edges): | def _find_cycle(vertices, edges): | ||||
""" | |||||
r""" | |||||
https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm | https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm | ||||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py | https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py | ||||
""" | """ | ||||
@@ -129,7 +129,7 @@ def _find_cycle(vertices, edges): | |||||
class GraphParser(BaseModel): | class GraphParser(BaseModel): | ||||
""" | |||||
r""" | |||||
基于图的parser base class, 支持贪婪解码和最大生成树解码 | 基于图的parser base class, 支持贪婪解码和最大生成树解码 | ||||
""" | """ | ||||
@@ -138,7 +138,7 @@ class GraphParser(BaseModel): | |||||
@staticmethod | @staticmethod | ||||
def greedy_decoder(arc_matrix, mask=None): | def greedy_decoder(arc_matrix, mask=None): | ||||
""" | |||||
r""" | |||||
贪心解码方式, 输入图, 输出贪心解码的parsing结果, 不保证合法的构成树 | 贪心解码方式, 输入图, 输出贪心解码的parsing结果, 不保证合法的构成树 | ||||
:param arc_matrix: [batch, seq_len, seq_len] 输入图矩阵 | :param arc_matrix: [batch, seq_len, seq_len] 输入图矩阵 | ||||
@@ -157,7 +157,7 @@ class GraphParser(BaseModel): | |||||
@staticmethod | @staticmethod | ||||
def mst_decoder(arc_matrix, mask=None): | def mst_decoder(arc_matrix, mask=None): | ||||
""" | |||||
r""" | |||||
用最大生成树算法, 计算parsing结果, 保证输出合法的树结构 | 用最大生成树算法, 计算parsing结果, 保证输出合法的树结构 | ||||
:param arc_matrix: [batch, seq_len, seq_len] 输入图矩阵 | :param arc_matrix: [batch, seq_len, seq_len] 输入图矩阵 | ||||
@@ -178,13 +178,13 @@ class GraphParser(BaseModel): | |||||
class ArcBiaffine(nn.Module): | class ArcBiaffine(nn.Module): | ||||
""" | |||||
r""" | |||||
Biaffine Dependency Parser 的子模块, 用于构建预测边的图 | Biaffine Dependency Parser 的子模块, 用于构建预测边的图 | ||||
""" | """ | ||||
def __init__(self, hidden_size, bias=True): | def __init__(self, hidden_size, bias=True): | ||||
""" | |||||
r""" | |||||
:param hidden_size: 输入的特征维度 | :param hidden_size: 输入的特征维度 | ||||
:param bias: 是否使用bias. Default: ``True`` | :param bias: 是否使用bias. Default: ``True`` | ||||
@@ -199,7 +199,7 @@ class ArcBiaffine(nn.Module): | |||||
initial_parameter(self) | initial_parameter(self) | ||||
def forward(self, head, dep): | def forward(self, head, dep): | ||||
""" | |||||
r""" | |||||
:param head: arc-head tensor [batch, length, hidden] | :param head: arc-head tensor [batch, length, hidden] | ||||
:param dep: arc-dependent tensor [batch, length, hidden] | :param dep: arc-dependent tensor [batch, length, hidden] | ||||
@@ -213,13 +213,13 @@ class ArcBiaffine(nn.Module): | |||||
class LabelBilinear(nn.Module): | class LabelBilinear(nn.Module): | ||||
""" | |||||
r""" | |||||
Biaffine Dependency Parser 的子模块, 用于构建预测边类别的图 | Biaffine Dependency Parser 的子模块, 用于构建预测边类别的图 | ||||
""" | """ | ||||
def __init__(self, in1_features, in2_features, num_label, bias=True): | def __init__(self, in1_features, in2_features, num_label, bias=True): | ||||
""" | |||||
r""" | |||||
:param in1_features: 输入的特征1维度 | :param in1_features: 输入的特征1维度 | ||||
:param in2_features: 输入的特征2维度 | :param in2_features: 输入的特征2维度 | ||||
@@ -231,7 +231,7 @@ class LabelBilinear(nn.Module): | |||||
self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) | self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) | ||||
def forward(self, x1, x2): | def forward(self, x1, x2): | ||||
""" | |||||
r""" | |||||
:param x1: [batch, seq_len, hidden] 输入特征1, 即label-head | :param x1: [batch, seq_len, hidden] 输入特征1, 即label-head | ||||
:param x2: [batch, seq_len, hidden] 输入特征2, 即label-dep | :param x2: [batch, seq_len, hidden] 输入特征2, 即label-dep | ||||
@@ -243,7 +243,7 @@ class LabelBilinear(nn.Module): | |||||
class BiaffineParser(GraphParser): | class BiaffineParser(GraphParser): | ||||
""" | |||||
r""" | |||||
Biaffine Dependency Parser 实现. | Biaffine Dependency Parser 实现. | ||||
论文参考 `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ . | 论文参考 `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ . | ||||
@@ -261,7 +261,7 @@ class BiaffineParser(GraphParser): | |||||
dropout=0.3, | dropout=0.3, | ||||
encoder='lstm', | encoder='lstm', | ||||
use_greedy_infer=False): | use_greedy_infer=False): | ||||
""" | |||||
r""" | |||||
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | :param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | ||||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, | embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, | ||||
@@ -347,7 +347,7 @@ class BiaffineParser(GraphParser): | |||||
nn.init.normal_(p, 0, 0.1) | nn.init.normal_(p, 0, 0.1) | ||||
def forward(self, words1, words2, seq_len, target1=None): | def forward(self, words1, words2, seq_len, target1=None): | ||||
"""模型forward阶段 | |||||
r"""模型forward阶段 | |||||
:param words1: [batch_size, seq_len] 输入word序列 | :param words1: [batch_size, seq_len] 输入word序列 | ||||
:param words2: [batch_size, seq_len] 输入pos序列 | :param words2: [batch_size, seq_len] 输入pos序列 | ||||
@@ -428,7 +428,7 @@ class BiaffineParser(GraphParser): | |||||
@staticmethod | @staticmethod | ||||
def loss(pred1, pred2, target1, target2, seq_len): | def loss(pred1, pred2, target1, target2, seq_len): | ||||
""" | |||||
r""" | |||||
计算parser的loss | 计算parser的loss | ||||
:param pred1: [batch_size, seq_len, seq_len] 边预测logits | :param pred1: [batch_size, seq_len, seq_len] 边预测logits | ||||
@@ -458,7 +458,7 @@ class BiaffineParser(GraphParser): | |||||
return arc_nll + label_nll | return arc_nll + label_nll | ||||
def predict(self, words1, words2, seq_len): | def predict(self, words1, words2, seq_len): | ||||
"""模型预测API | |||||
r"""模型预测API | |||||
:param words1: [batch_size, seq_len] 输入word序列 | :param words1: [batch_size, seq_len] 输入word序列 | ||||
:param words2: [batch_size, seq_len] 输入pos序列 | :param words2: [batch_size, seq_len] 输入pos序列 | ||||
@@ -479,7 +479,7 @@ class BiaffineParser(GraphParser): | |||||
class ParserLoss(LossFunc): | class ParserLoss(LossFunc): | ||||
""" | |||||
r""" | |||||
计算parser的loss | 计算parser的loss | ||||
""" | """ | ||||
@@ -487,7 +487,7 @@ class ParserLoss(LossFunc): | |||||
def __init__(self, pred1=None, pred2=None, | def __init__(self, pred1=None, pred2=None, | ||||
target1=None, target2=None, | target1=None, target2=None, | ||||
seq_len=None): | seq_len=None): | ||||
""" | |||||
r""" | |||||
:param pred1: [batch_size, seq_len, seq_len] 边预测logits | :param pred1: [batch_size, seq_len, seq_len] 边预测logits | ||||
:param pred2: [batch_size, seq_len, num_label] label预测logits | :param pred2: [batch_size, seq_len, num_label] label预测logits | ||||
@@ -505,14 +505,14 @@ class ParserLoss(LossFunc): | |||||
class ParserMetric(MetricBase): | class ParserMetric(MetricBase): | ||||
""" | |||||
r""" | |||||
评估parser的性能 | 评估parser的性能 | ||||
""" | """ | ||||
def __init__(self, pred1=None, pred2=None, | def __init__(self, pred1=None, pred2=None, | ||||
target1=None, target2=None, seq_len=None): | target1=None, target2=None, seq_len=None): | ||||
""" | |||||
r""" | |||||
:param pred1: 边预测logits | :param pred1: 边预测logits | ||||
:param pred2: label预测logits | :param pred2: label预测logits | ||||
@@ -539,7 +539,7 @@ class ParserMetric(MetricBase): | |||||
return res | return res | ||||
def evaluate(self, pred1, pred2, target1, target2, seq_len=None): | def evaluate(self, pred1, pred2, target1, target2, seq_len=None): | ||||
"""Evaluate the performance of prediction. | |||||
r"""Evaluate the performance of prediction. | |||||
""" | """ | ||||
if seq_len is None: | if seq_len is None: | ||||
seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long) | seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long) | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -17,7 +17,7 @@ from ..modules import encoder | |||||
class CNNText(torch.nn.Module): | class CNNText(torch.nn.Module): | ||||
""" | |||||
r""" | |||||
使用CNN进行文本分类的模型 | 使用CNN进行文本分类的模型 | ||||
'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.' | 'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.' | ||||
@@ -28,7 +28,7 @@ class CNNText(torch.nn.Module): | |||||
kernel_nums=(30, 40, 50), | kernel_nums=(30, 40, 50), | ||||
kernel_sizes=(1, 3, 5), | kernel_sizes=(1, 3, 5), | ||||
dropout=0.5): | dropout=0.5): | ||||
""" | |||||
r""" | |||||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), | :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), | ||||
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding | 第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding | ||||
@@ -48,7 +48,7 @@ class CNNText(torch.nn.Module): | |||||
self.fc = nn.Linear(sum(kernel_nums), num_classes) | self.fc = nn.Linear(sum(kernel_nums), num_classes) | ||||
def forward(self, words, seq_len=None): | def forward(self, words, seq_len=None): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, seq_len],句子中word的index | :param torch.LongTensor words: [batch_size, seq_len],句子中word的index | ||||
:param torch.LongTensor seq_len: [batch,] 每个句子的长度 | :param torch.LongTensor seq_len: [batch,] 每个句子的长度 | ||||
@@ -65,7 +65,7 @@ class CNNText(torch.nn.Module): | |||||
return {C.OUTPUT: x} | return {C.OUTPUT: x} | ||||
def predict(self, words, seq_len=None): | def predict(self, words, seq_len=None): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, seq_len],句子中word的index | :param torch.LongTensor words: [batch_size, seq_len],句子中word的index | ||||
:param torch.LongTensor seq_len: [batch,] 每个句子的长度 | :param torch.LongTensor seq_len: [batch,] 每个句子的长度 | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
本模块实现了几种序列标注模型 | 本模块实现了几种序列标注模型 | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
@@ -22,13 +22,13 @@ from ..modules.decoder.crf import allowed_transitions | |||||
class BiLSTMCRF(BaseModel): | class BiLSTMCRF(BaseModel): | ||||
""" | |||||
r""" | |||||
结构为embedding + BiLSTM + FC + Dropout + CRF. | 结构为embedding + BiLSTM + FC + Dropout + CRF. | ||||
""" | """ | ||||
def __init__(self, embed, num_classes, num_layers=1, hidden_size=100, dropout=0.5, | def __init__(self, embed, num_classes, num_layers=1, hidden_size=100, dropout=0.5, | ||||
target_vocab=None): | target_vocab=None): | ||||
""" | |||||
r""" | |||||
:param embed: 支持(1)fastNLP的各种Embedding, (2) tuple, 指明num_embedding, dimension, 如(1000, 100) | :param embed: 支持(1)fastNLP的各种Embedding, (2) tuple, 指明num_embedding, dimension, 如(1000, 100) | ||||
:param num_classes: 一共多少个类 | :param num_classes: 一共多少个类 | ||||
@@ -79,14 +79,14 @@ class BiLSTMCRF(BaseModel): | |||||
class SeqLabeling(BaseModel): | class SeqLabeling(BaseModel): | ||||
""" | |||||
r""" | |||||
一个基础的Sequence labeling的模型。 | 一个基础的Sequence labeling的模型。 | ||||
用于做sequence labeling的基础类。结构包含一层Embedding,一层LSTM(单向,一层),一层FC,以及一层CRF。 | 用于做sequence labeling的基础类。结构包含一层Embedding,一层LSTM(单向,一层),一层FC,以及一层CRF。 | ||||
""" | """ | ||||
def __init__(self, embed, hidden_size, num_classes): | def __init__(self, embed, hidden_size, num_classes): | ||||
""" | |||||
r""" | |||||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), | :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), | ||||
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, embedding, ndarray等则直接使用该值初始化Embedding | 第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, embedding, ndarray等则直接使用该值初始化Embedding | ||||
@@ -101,7 +101,7 @@ class SeqLabeling(BaseModel): | |||||
self.crf = decoder.ConditionalRandomField(num_classes) | self.crf = decoder.ConditionalRandomField(num_classes) | ||||
def forward(self, words, seq_len, target): | def forward(self, words, seq_len, target): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, max_len],序列的index | :param torch.LongTensor words: [batch_size, max_len],序列的index | ||||
:param torch.LongTensor seq_len: [batch_size,], 这个序列的长度 | :param torch.LongTensor seq_len: [batch_size,], 这个序列的长度 | ||||
:param torch.LongTensor target: [batch_size, max_len], 序列的目标值 | :param torch.LongTensor target: [batch_size, max_len], 序列的目标值 | ||||
@@ -118,7 +118,7 @@ class SeqLabeling(BaseModel): | |||||
return {C.LOSS: self._internal_loss(x, target, mask)} | return {C.LOSS: self._internal_loss(x, target, mask)} | ||||
def predict(self, words, seq_len): | def predict(self, words, seq_len): | ||||
""" | |||||
r""" | |||||
用于在预测时使用 | 用于在预测时使用 | ||||
:param torch.LongTensor words: [batch_size, max_len] | :param torch.LongTensor words: [batch_size, max_len] | ||||
@@ -137,7 +137,7 @@ class SeqLabeling(BaseModel): | |||||
return {C.OUTPUT: pred} | return {C.OUTPUT: pred} | ||||
def _internal_loss(self, x, y, mask): | def _internal_loss(self, x, y, mask): | ||||
""" | |||||
r""" | |||||
Negative log likelihood loss. | Negative log likelihood loss. | ||||
:param x: Tensor, [batch_size, max_len, tag_size] | :param x: Tensor, [batch_size, max_len, tag_size] | ||||
:param y: Tensor, [batch_size, max_len] | :param y: Tensor, [batch_size, max_len] | ||||
@@ -150,7 +150,7 @@ class SeqLabeling(BaseModel): | |||||
return torch.mean(total_loss) | return torch.mean(total_loss) | ||||
def _decode(self, x, mask): | def _decode(self, x, mask): | ||||
""" | |||||
r""" | |||||
:param torch.FloatTensor x: [batch_size, max_len, tag_size] | :param torch.FloatTensor x: [batch_size, max_len, tag_size] | ||||
:return prediction: [batch_size, max_len] | :return prediction: [batch_size, max_len] | ||||
""" | """ | ||||
@@ -159,12 +159,12 @@ class SeqLabeling(BaseModel): | |||||
class AdvSeqLabel(nn.Module): | class AdvSeqLabel(nn.Module): | ||||
""" | |||||
r""" | |||||
更复杂的Sequence Labelling模型。结构为Embedding, LayerNorm, 双向LSTM(两层),FC,LayerNorm,DropOut,FC,CRF。 | 更复杂的Sequence Labelling模型。结构为Embedding, LayerNorm, 双向LSTM(两层),FC,LayerNorm,DropOut,FC,CRF。 | ||||
""" | """ | ||||
def __init__(self, embed, hidden_size, num_classes, dropout=0.3, id2words=None, encoding_type='bmes'): | def __init__(self, embed, hidden_size, num_classes, dropout=0.3, id2words=None, encoding_type='bmes'): | ||||
""" | |||||
r""" | |||||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), | :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), | ||||
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding | 第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding | ||||
@@ -197,7 +197,7 @@ class AdvSeqLabel(nn.Module): | |||||
encoding_type=encoding_type)) | encoding_type=encoding_type)) | ||||
def _decode(self, x, mask): | def _decode(self, x, mask): | ||||
""" | |||||
r""" | |||||
:param torch.FloatTensor x: [batch_size, max_len, tag_size] | :param torch.FloatTensor x: [batch_size, max_len, tag_size] | ||||
:param torch.ByteTensor mask: [batch_size, max_len] | :param torch.ByteTensor mask: [batch_size, max_len] | ||||
:return torch.LongTensor, [batch_size, max_len] | :return torch.LongTensor, [batch_size, max_len] | ||||
@@ -206,7 +206,7 @@ class AdvSeqLabel(nn.Module): | |||||
return tag_seq | return tag_seq | ||||
def _internal_loss(self, x, y, mask): | def _internal_loss(self, x, y, mask): | ||||
""" | |||||
r""" | |||||
Negative log likelihood loss. | Negative log likelihood loss. | ||||
:param x: Tensor, [batch_size, max_len, tag_size] | :param x: Tensor, [batch_size, max_len, tag_size] | ||||
:param y: Tensor, [batch_size, max_len] | :param y: Tensor, [batch_size, max_len] | ||||
@@ -220,7 +220,7 @@ class AdvSeqLabel(nn.Module): | |||||
return torch.mean(total_loss) | return torch.mean(total_loss) | ||||
def _forward(self, words, seq_len, target=None): | def _forward(self, words, seq_len, target=None): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, mex_len] | :param torch.LongTensor words: [batch_size, mex_len] | ||||
:param torch.LongTensor seq_len:[batch_size, ] | :param torch.LongTensor seq_len:[batch_size, ] | ||||
:param torch.LongTensor target: [batch_size, max_len] | :param torch.LongTensor target: [batch_size, max_len] | ||||
@@ -254,7 +254,7 @@ class AdvSeqLabel(nn.Module): | |||||
return {"pred": self._decode(x, mask)} | return {"pred": self._decode(x, mask)} | ||||
def forward(self, words, seq_len, target): | def forward(self, words, seq_len, target): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, mex_len] | :param torch.LongTensor words: [batch_size, mex_len] | ||||
:param torch.LongTensor seq_len: [batch_size, ] | :param torch.LongTensor seq_len: [batch_size, ] | ||||
@@ -264,7 +264,7 @@ class AdvSeqLabel(nn.Module): | |||||
return self._forward(words, seq_len, target) | return self._forward(words, seq_len, target) | ||||
def predict(self, words, seq_len): | def predict(self, words, seq_len): | ||||
""" | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, mex_len] | :param torch.LongTensor words: [batch_size, mex_len] | ||||
:param torch.LongTensor seq_len: [batch_size, ] | :param torch.LongTensor seq_len: [batch_size, ] | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -19,7 +19,7 @@ from ..modules.encoder import BiAttention | |||||
class ESIM(BaseModel): | class ESIM(BaseModel): | ||||
""" | |||||
r""" | |||||
ESIM model的一个PyTorch实现 | ESIM model的一个PyTorch实现 | ||||
论文参见: https://arxiv.org/pdf/1609.06038.pdf | 论文参见: https://arxiv.org/pdf/1609.06038.pdf | ||||
@@ -27,7 +27,7 @@ class ESIM(BaseModel): | |||||
def __init__(self, embed, hidden_size=None, num_labels=3, dropout_rate=0.3, | def __init__(self, embed, hidden_size=None, num_labels=3, dropout_rate=0.3, | ||||
dropout_embed=0.1): | dropout_embed=0.1): | ||||
""" | |||||
r""" | |||||
:param embed: 初始化的Embedding | :param embed: 初始化的Embedding | ||||
:param int hidden_size: 隐藏层大小,默认值为Embedding的维度 | :param int hidden_size: 隐藏层大小,默认值为Embedding的维度 | ||||
@@ -68,7 +68,7 @@ class ESIM(BaseModel): | |||||
nn.init.xavier_uniform_(self.classifier[4].weight.data) | nn.init.xavier_uniform_(self.classifier[4].weight.data) | ||||
def forward(self, words1, words2, seq_len1, seq_len2, target=None): | def forward(self, words1, words2, seq_len1, seq_len2, target=None): | ||||
""" | |||||
r""" | |||||
:param words1: [batch, seq_len] | :param words1: [batch, seq_len] | ||||
:param words2: [batch, seq_len] | :param words2: [batch, seq_len] | ||||
:param seq_len1: [batch] | :param seq_len1: [batch] | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
Star-Transformer 的 Pytorch 实现。 | Star-Transformer 的 Pytorch 实现。 | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
@@ -18,7 +18,7 @@ from ..modules.encoder.star_transformer import StarTransformer | |||||
class StarTransEnc(nn.Module): | class StarTransEnc(nn.Module): | ||||
""" | |||||
r""" | |||||
带word embedding的Star-Transformer Encoder | 带word embedding的Star-Transformer Encoder | ||||
""" | """ | ||||
@@ -31,7 +31,7 @@ class StarTransEnc(nn.Module): | |||||
max_len, | max_len, | ||||
emb_dropout, | emb_dropout, | ||||
dropout): | dropout): | ||||
""" | |||||
r""" | |||||
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | :param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | ||||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,此时就以传入的对象作为embedding | embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,此时就以传入的对象作为embedding | ||||
@@ -56,7 +56,7 @@ class StarTransEnc(nn.Module): | |||||
max_len=max_len) | max_len=max_len) | ||||
def forward(self, x, mask): | def forward(self, x, mask): | ||||
""" | |||||
r""" | |||||
:param FloatTensor x: [batch, length, hidden] 输入的序列 | :param FloatTensor x: [batch, length, hidden] 输入的序列 | ||||
:param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, | :param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, | ||||
否则为 1 | 否则为 1 | ||||
@@ -103,7 +103,7 @@ class _NLICls(nn.Module): | |||||
class STSeqLabel(nn.Module): | class STSeqLabel(nn.Module): | ||||
""" | |||||
r""" | |||||
用于序列标注的Star-Transformer模型 | 用于序列标注的Star-Transformer模型 | ||||
""" | """ | ||||
@@ -117,7 +117,7 @@ class STSeqLabel(nn.Module): | |||||
cls_hidden_size=600, | cls_hidden_size=600, | ||||
emb_dropout=0.1, | emb_dropout=0.1, | ||||
dropout=0.1, ): | dropout=0.1, ): | ||||
""" | |||||
r""" | |||||
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | :param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | ||||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding | embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding | ||||
@@ -143,7 +143,7 @@ class STSeqLabel(nn.Module): | |||||
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size) | self.cls = _Cls(hidden_size, num_cls, cls_hidden_size) | ||||
def forward(self, words, seq_len): | def forward(self, words, seq_len): | ||||
""" | |||||
r""" | |||||
:param words: [batch, seq_len] 输入序列 | :param words: [batch, seq_len] 输入序列 | ||||
:param seq_len: [batch,] 输入序列的长度 | :param seq_len: [batch,] 输入序列的长度 | ||||
@@ -156,7 +156,7 @@ class STSeqLabel(nn.Module): | |||||
return {Const.OUTPUT: output} # [bsz, n_cls, seq_len] | return {Const.OUTPUT: output} # [bsz, n_cls, seq_len] | ||||
def predict(self, words, seq_len): | def predict(self, words, seq_len): | ||||
""" | |||||
r""" | |||||
:param words: [batch, seq_len] 输入序列 | :param words: [batch, seq_len] 输入序列 | ||||
:param seq_len: [batch,] 输入序列的长度 | :param seq_len: [batch,] 输入序列的长度 | ||||
@@ -168,7 +168,7 @@ class STSeqLabel(nn.Module): | |||||
class STSeqCls(nn.Module): | class STSeqCls(nn.Module): | ||||
""" | |||||
r""" | |||||
用于分类任务的Star-Transformer | 用于分类任务的Star-Transformer | ||||
""" | """ | ||||
@@ -182,7 +182,7 @@ class STSeqCls(nn.Module): | |||||
cls_hidden_size=600, | cls_hidden_size=600, | ||||
emb_dropout=0.1, | emb_dropout=0.1, | ||||
dropout=0.1, ): | dropout=0.1, ): | ||||
""" | |||||
r""" | |||||
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | :param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | ||||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding | embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding | ||||
@@ -208,7 +208,7 @@ class STSeqCls(nn.Module): | |||||
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout) | self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout) | ||||
def forward(self, words, seq_len): | def forward(self, words, seq_len): | ||||
""" | |||||
r""" | |||||
:param words: [batch, seq_len] 输入序列 | :param words: [batch, seq_len] 输入序列 | ||||
:param seq_len: [batch,] 输入序列的长度 | :param seq_len: [batch,] 输入序列的长度 | ||||
@@ -221,7 +221,7 @@ class STSeqCls(nn.Module): | |||||
return {Const.OUTPUT: output} | return {Const.OUTPUT: output} | ||||
def predict(self, words, seq_len): | def predict(self, words, seq_len): | ||||
""" | |||||
r""" | |||||
:param words: [batch, seq_len] 输入序列 | :param words: [batch, seq_len] 输入序列 | ||||
:param seq_len: [batch,] 输入序列的长度 | :param seq_len: [batch,] 输入序列的长度 | ||||
@@ -233,7 +233,7 @@ class STSeqCls(nn.Module): | |||||
class STNLICls(nn.Module): | class STNLICls(nn.Module): | ||||
""" | |||||
r""" | |||||
用于自然语言推断(NLI)的Star-Transformer | 用于自然语言推断(NLI)的Star-Transformer | ||||
""" | """ | ||||
@@ -247,7 +247,7 @@ class STNLICls(nn.Module): | |||||
cls_hidden_size=600, | cls_hidden_size=600, | ||||
emb_dropout=0.1, | emb_dropout=0.1, | ||||
dropout=0.1, ): | dropout=0.1, ): | ||||
""" | |||||
r""" | |||||
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | :param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | ||||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding | embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding | ||||
@@ -273,7 +273,7 @@ class STNLICls(nn.Module): | |||||
self.cls = _NLICls(hidden_size, num_cls, cls_hidden_size) | self.cls = _NLICls(hidden_size, num_cls, cls_hidden_size) | ||||
def forward(self, words1, words2, seq_len1, seq_len2): | def forward(self, words1, words2, seq_len1, seq_len2): | ||||
""" | |||||
r""" | |||||
:param words1: [batch, seq_len] 输入序列1 | :param words1: [batch, seq_len] 输入序列1 | ||||
:param words2: [batch, seq_len] 输入序列2 | :param words2: [batch, seq_len] 输入序列2 | ||||
@@ -294,7 +294,7 @@ class STNLICls(nn.Module): | |||||
return {Const.OUTPUT: output} | return {Const.OUTPUT: output} | ||||
def predict(self, words1, words2, seq_len1, seq_len2): | def predict(self, words1, words2, seq_len1, seq_len2): | ||||
""" | |||||
r""" | |||||
:param words1: [batch, seq_len] 输入序列1 | :param words1: [batch, seq_len] 输入序列1 | ||||
:param words2: [batch, seq_len] 输入序列2 | :param words2: [batch, seq_len] 输入序列2 | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. image:: figures/text_classification.png | .. image:: figures/text_classification.png | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"ConditionalRandomField", | "ConditionalRandomField", | ||||
@@ -16,7 +16,7 @@ from ...core.vocabulary import Vocabulary | |||||
def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type=None, include_start_end=False): | def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type=None, include_start_end=False): | ||||
""" | |||||
r""" | |||||
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | 给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | ||||
:param ~fastNLP.Vocabulary,dict tag_vocab: 支持类型为tag或tag-label。只有tag的,比如"B", "M"; 也可以是"B-NN", "M-NN", | :param ~fastNLP.Vocabulary,dict tag_vocab: 支持类型为tag或tag-label。只有tag的,比如"B", "M"; 也可以是"B-NN", "M-NN", | ||||
@@ -73,7 +73,7 @@ def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type=None, i | |||||
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | ||||
""" | |||||
r""" | |||||
:param str encoding_type: 支持"BIO", "BMES", "BEMSO", 'bioes'。 | :param str encoding_type: 支持"BIO", "BMES", "BEMSO", 'bioes'。 | ||||
:param str from_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | :param str from_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | ||||
@@ -86,7 +86,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label | |||||
return False | return False | ||||
encoding_type = encoding_type.lower() | encoding_type = encoding_type.lower() | ||||
if encoding_type == 'bio': | if encoding_type == 'bio': | ||||
""" | |||||
r""" | |||||
第一行是to_tag, 第一列是from_tag. y任意条件下可转,-只有在label相同时可转,n不可转 | 第一行是to_tag, 第一列是from_tag. y任意条件下可转,-只有在label相同时可转,n不可转 | ||||
+-------+---+---+---+-------+-----+ | +-------+---+---+---+-------+-----+ | ||||
| | B | I | O | start | end | | | | B | I | O | start | end | | ||||
@@ -112,7 +112,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label | |||||
raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag)) | raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag)) | ||||
elif encoding_type == 'bmes': | elif encoding_type == 'bmes': | ||||
""" | |||||
r""" | |||||
第一行是to_tag, 第一列是from_tag,y任意条件下可转,-只有在label相同时可转,n不可转 | 第一行是to_tag, 第一列是from_tag,y任意条件下可转,-只有在label相同时可转,n不可转 | ||||
+-------+---+---+---+---+-------+-----+ | +-------+---+---+---+---+-------+-----+ | ||||
| | B | M | E | S | start | end | | | | B | M | E | S | start | end | | ||||
@@ -167,14 +167,14 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label | |||||
class ConditionalRandomField(nn.Module): | class ConditionalRandomField(nn.Module): | ||||
""" | |||||
r""" | |||||
条件随机场。提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 | 条件随机场。提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 | ||||
""" | """ | ||||
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, | def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, | ||||
initial_method=None): | initial_method=None): | ||||
""" | |||||
r""" | |||||
:param int num_tags: 标签的数量 | :param int num_tags: 标签的数量 | ||||
:param bool include_start_end_trans: 是否考虑各个tag作为开始以及结尾的分数。 | :param bool include_start_end_trans: 是否考虑各个tag作为开始以及结尾的分数。 | ||||
@@ -205,7 +205,7 @@ class ConditionalRandomField(nn.Module): | |||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def _normalizer_likelihood(self, logits, mask): | def _normalizer_likelihood(self, logits, mask): | ||||
"""Computes the (batch_size,) denominator term for the log-likelihood, which is the | |||||
r"""Computes the (batch_size,) denominator term for the log-likelihood, which is the | |||||
sum of the likelihoods across all possible state sequences. | sum of the likelihoods across all possible state sequences. | ||||
:param logits:FloatTensor, max_len x batch_size x num_tags | :param logits:FloatTensor, max_len x batch_size x num_tags | ||||
@@ -232,7 +232,7 @@ class ConditionalRandomField(nn.Module): | |||||
return torch.logsumexp(alpha, 1) | return torch.logsumexp(alpha, 1) | ||||
def _gold_score(self, logits, tags, mask): | def _gold_score(self, logits, tags, mask): | ||||
""" | |||||
r""" | |||||
Compute the score for the gold path. | Compute the score for the gold path. | ||||
:param logits: FloatTensor, max_len x batch_size x num_tags | :param logits: FloatTensor, max_len x batch_size x num_tags | ||||
:param tags: LongTensor, max_len x batch_size | :param tags: LongTensor, max_len x batch_size | ||||
@@ -261,7 +261,7 @@ class ConditionalRandomField(nn.Module): | |||||
return score | return score | ||||
def forward(self, feats, tags, mask): | def forward(self, feats, tags, mask): | ||||
""" | |||||
r""" | |||||
用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 | 用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 | ||||
:param torch.FloatTensor feats: batch_size x max_len x num_tags,特征矩阵。 | :param torch.FloatTensor feats: batch_size x max_len x num_tags,特征矩阵。 | ||||
@@ -278,7 +278,7 @@ class ConditionalRandomField(nn.Module): | |||||
return all_path_score - gold_path_score | return all_path_score - gold_path_score | ||||
def viterbi_decode(self, logits, mask, unpad=False): | def viterbi_decode(self, logits, mask, unpad=False): | ||||
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||||
r"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||||
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | :param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | ||||
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | :param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"MLP" | "MLP" | ||||
@@ -11,7 +11,7 @@ from ..utils import initial_parameter | |||||
class MLP(nn.Module): | class MLP(nn.Module): | ||||
""" | |||||
r""" | |||||
多层感知器 | 多层感知器 | ||||
@@ -36,7 +36,7 @@ class MLP(nn.Module): | |||||
""" | """ | ||||
def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0): | def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0): | ||||
""" | |||||
r""" | |||||
:param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层是hidden数目。MLP的层数为 len(size_layer) - 1 | :param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层是hidden数目。MLP的层数为 len(size_layer) - 1 | ||||
:param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和 | :param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和 | ||||
@@ -87,7 +87,7 @@ class MLP(nn.Module): | |||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def forward(self, x): | def forward(self, x): | ||||
""" | |||||
r""" | |||||
:param torch.Tensor x: MLP接受的输入 | :param torch.Tensor x: MLP接受的输入 | ||||
:return: torch.Tensor : MLP的输出结果 | :return: torch.Tensor : MLP的输出结果 | ||||
""" | """ | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"viterbi_decode" | "viterbi_decode" | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"TimestepDropout" | "TimestepDropout" | ||||
@@ -8,7 +8,7 @@ import torch | |||||
class TimestepDropout(torch.nn.Dropout): | class TimestepDropout(torch.nn.Dropout): | ||||
""" | |||||
r""" | |||||
传入参数的shape为 ``(batch_size, num_timesteps, embedding_dim)`` | 传入参数的shape为 ``(batch_size, num_timesteps, embedding_dim)`` | ||||
使用同一个shape为 ``(batch_size, embedding_dim)`` 的mask在每个timestamp上做dropout。 | 使用同一个shape为 ``(batch_size, embedding_dim)`` 的mask在每个timestamp上做dropout。 | ||||
""" | """ | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented | |||||
r"""undocumented | |||||
这个页面的代码大量参考了 allenNLP | 这个页面的代码大量参考了 allenNLP | ||||
""" | """ | ||||
@@ -15,7 +15,7 @@ from ..utils import get_dropout_mask | |||||
class LstmCellWithProjection(torch.nn.Module): | class LstmCellWithProjection(torch.nn.Module): | ||||
""" | |||||
r""" | |||||
An LSTM with Recurrent Dropout and a projected and clipped hidden state and | An LSTM with Recurrent Dropout and a projected and clipped hidden state and | ||||
memory. Note: this implementation is slower than the native Pytorch LSTM because | memory. Note: this implementation is slower than the native Pytorch LSTM because | ||||
it cannot make use of CUDNN optimizations for stacked RNNs due to and | it cannot make use of CUDNN optimizations for stacked RNNs due to and | ||||
@@ -96,7 +96,7 @@ class LstmCellWithProjection(torch.nn.Module): | |||||
inputs: torch.FloatTensor, | inputs: torch.FloatTensor, | ||||
batch_lengths: List[int], | batch_lengths: List[int], | ||||
initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None): | initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None): | ||||
""" | |||||
r""" | |||||
Parameters | Parameters | ||||
---------- | ---------- | ||||
inputs : ``torch.FloatTensor``, required. | inputs : ``torch.FloatTensor``, required. | ||||
@@ -307,7 +307,7 @@ class ElmobiLm(torch.nn.Module): | |||||
self.backward_layers = backward_layers | self.backward_layers = backward_layers | ||||
def forward(self, inputs, seq_len): | def forward(self, inputs, seq_len): | ||||
""" | |||||
r""" | |||||
:param inputs: batch_size x max_len x embed_size | :param inputs: batch_size x max_len x embed_size | ||||
:param seq_len: batch_size | :param seq_len: batch_size | ||||
@@ -326,7 +326,7 @@ class ElmobiLm(torch.nn.Module): | |||||
inputs: PackedSequence, | inputs: PackedSequence, | ||||
initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> \ | initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> \ | ||||
Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | ||||
""" | |||||
r""" | |||||
Parameters | Parameters | ||||
---------- | ---------- | ||||
inputs : ``PackedSequence``, required. | inputs : ``PackedSequence``, required. | ||||
@@ -451,7 +451,7 @@ class ConvTokenEmbedder(nn.Module): | |||||
self._projection = torch.nn.Linear(n_filters, self.output_dim, bias=True) | self._projection = torch.nn.Linear(n_filters, self.output_dim, bias=True) | ||||
def forward(self, words, chars): | def forward(self, words, chars): | ||||
""" | |||||
r""" | |||||
:param words: | :param words: | ||||
:param chars: Tensor Shape ``(batch_size, sequence_length, 50)``: | :param chars: Tensor Shape ``(batch_size, sequence_length, 50)``: | ||||
:return Tensor Shape ``(batch_size, sequence_length + 2, embedding_dim)`` : | :return Tensor Shape ``(batch_size, sequence_length + 2, embedding_dim)`` : | ||||
@@ -491,7 +491,7 @@ class ConvTokenEmbedder(nn.Module): | |||||
class Highway(torch.nn.Module): | class Highway(torch.nn.Module): | ||||
""" | |||||
r""" | |||||
A `Highway layer <https://arxiv.org/abs/1505.00387>`_ does a gated combination of a linear | A `Highway layer <https://arxiv.org/abs/1505.00387>`_ does a gated combination of a linear | ||||
transformation and a non-linear transformation of its input. :math:`y = g * x + (1 - g) * | transformation and a non-linear transformation of its input. :math:`y = g * x + (1 - g) * | ||||
f(A(x))`, where :math:`A` is a linear transformation, :math:`f` is an element-wise | f(A(x))`, where :math:`A` is a linear transformation, :math:`f` is an element-wise | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"MultiHeadAttention", | "MultiHeadAttention", | ||||
@@ -16,7 +16,7 @@ from fastNLP.modules.utils import initial_parameter | |||||
class DotAttention(nn.Module): | class DotAttention(nn.Module): | ||||
""" | |||||
r""" | |||||
Transformer当中的DotAttention | Transformer当中的DotAttention | ||||
""" | """ | ||||
@@ -29,7 +29,7 @@ class DotAttention(nn.Module): | |||||
self.softmax = nn.Softmax(dim=-1) | self.softmax = nn.Softmax(dim=-1) | ||||
def forward(self, Q, K, V, mask_out=None): | def forward(self, Q, K, V, mask_out=None): | ||||
""" | |||||
r""" | |||||
:param Q: [..., seq_len_q, key_size] | :param Q: [..., seq_len_q, key_size] | ||||
:param K: [..., seq_len_k, key_size] | :param K: [..., seq_len_k, key_size] | ||||
@@ -45,12 +45,12 @@ class DotAttention(nn.Module): | |||||
class MultiHeadAttention(nn.Module): | class MultiHeadAttention(nn.Module): | ||||
""" | |||||
r""" | |||||
Transformer当中的MultiHeadAttention | Transformer当中的MultiHeadAttention | ||||
""" | """ | ||||
def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1): | def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1): | ||||
""" | |||||
r""" | |||||
:param input_size: int, 输入维度的大小。同时也是输出维度的大小。 | :param input_size: int, 输入维度的大小。同时也是输出维度的大小。 | ||||
:param key_size: int, 每个head的维度大小。 | :param key_size: int, 每个head的维度大小。 | ||||
@@ -80,7 +80,7 @@ class MultiHeadAttention(nn.Module): | |||||
nn.init.normal_(self.out.weight, mean=0, std=sqrt(1.0 / self.input_size)) | nn.init.normal_(self.out.weight, mean=0, std=sqrt(1.0 / self.input_size)) | ||||
def forward(self, Q, K, V, atte_mask_out=None): | def forward(self, Q, K, V, atte_mask_out=None): | ||||
""" | |||||
r""" | |||||
:param Q: [batch, seq_len_q, model_size] | :param Q: [batch, seq_len_q, model_size] | ||||
:param K: [batch, seq_len_k, model_size] | :param K: [batch, seq_len_k, model_size] | ||||
@@ -147,7 +147,7 @@ class BiAttention(nn.Module): | |||||
""" | """ | ||||
def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask): | def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask): | ||||
""" | |||||
r""" | |||||
:param torch.Tensor premise_batch: [batch_size, a_seq_len, hidden_size] | :param torch.Tensor premise_batch: [batch_size, a_seq_len, hidden_size] | ||||
:param torch.Tensor premise_mask: [batch_size, a_seq_len] | :param torch.Tensor premise_mask: [batch_size, a_seq_len] | ||||
:param torch.Tensor hypothesis_batch: [batch_size, b_seq_len, hidden_size] | :param torch.Tensor hypothesis_batch: [batch_size, b_seq_len, hidden_size] | ||||
@@ -173,13 +173,13 @@ class BiAttention(nn.Module): | |||||
class SelfAttention(nn.Module): | class SelfAttention(nn.Module): | ||||
""" | |||||
r""" | |||||
这是一个基于论文 `A structured self-attentive sentence embedding <https://arxiv.org/pdf/1703.03130.pdf>`_ | 这是一个基于论文 `A structured self-attentive sentence embedding <https://arxiv.org/pdf/1703.03130.pdf>`_ | ||||
的Self Attention Module. | 的Self Attention Module. | ||||
""" | """ | ||||
def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None, ): | def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None, ): | ||||
""" | |||||
r""" | |||||
:param int input_size: 输入tensor的hidden维度 | :param int input_size: 输入tensor的hidden维度 | ||||
:param int attention_unit: 输出tensor的hidden维度 | :param int attention_unit: 输出tensor的hidden维度 | ||||
@@ -199,7 +199,7 @@ class SelfAttention(nn.Module): | |||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def _penalization(self, attention): | def _penalization(self, attention): | ||||
""" | |||||
r""" | |||||
compute the penalization term for attention module | compute the penalization term for attention module | ||||
""" | """ | ||||
baz = attention.size(0) | baz = attention.size(0) | ||||
@@ -213,7 +213,7 @@ class SelfAttention(nn.Module): | |||||
return torch.sum(ret) / size[0] | return torch.sum(ret) / size[0] | ||||
def forward(self, input, input_origin): | def forward(self, input, input_origin): | ||||
""" | |||||
r""" | |||||
:param torch.Tensor input: [batch_size, seq_len, hidden_size] 要做attention的矩阵 | :param torch.Tensor input: [batch_size, seq_len, hidden_size] 要做attention的矩阵 | ||||
:param torch.Tensor input_origin: [batch_size, seq_len] 原始token的index组成的矩阵,含有pad部分内容 | :param torch.Tensor input_origin: [batch_size, seq_len] 原始token的index组成的矩阵,含有pad部分内容 | ||||
:return torch.Tensor output1: [batch_size, multi-head, hidden_size] 经过attention操作后输入矩阵的结果 | :return torch.Tensor output1: [batch_size, multi-head, hidden_size] 经过attention操作后输入矩阵的结果 | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented | |||||
r"""undocumented | |||||
这个页面的代码很大程度上参考(复制粘贴)了https://github.com/huggingface/pytorch-pretrained-BERT的代码, 如果你发现该代码对你 | 这个页面的代码很大程度上参考(复制粘贴)了https://github.com/huggingface/pytorch-pretrained-BERT的代码, 如果你发现该代码对你 | ||||
有用,也请引用一下他们。 | 有用,也请引用一下他们。 | ||||
""" | """ | ||||
@@ -45,7 +45,7 @@ BERT_KEY_RENAME_MAP_2 = { | |||||
class BertConfig(object): | class BertConfig(object): | ||||
"""Configuration class to store the configuration of a `BertModel`. | |||||
r"""Configuration class to store the configuration of a `BertModel`. | |||||
""" | """ | ||||
def __init__(self, | def __init__(self, | ||||
@@ -61,7 +61,7 @@ class BertConfig(object): | |||||
type_vocab_size=2, | type_vocab_size=2, | ||||
initializer_range=0.02, | initializer_range=0.02, | ||||
layer_norm_eps=1e-12): | layer_norm_eps=1e-12): | ||||
"""Constructs BertConfig. | |||||
r"""Constructs BertConfig. | |||||
Args: | Args: | ||||
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. | ||||
@@ -110,7 +110,7 @@ class BertConfig(object): | |||||
@classmethod | @classmethod | ||||
def from_dict(cls, json_object): | def from_dict(cls, json_object): | ||||
"""Constructs a `BertConfig` from a Python dictionary of parameters.""" | |||||
r"""Constructs a `BertConfig` from a Python dictionary of parameters.""" | |||||
config = BertConfig(vocab_size_or_config_json_file=-1) | config = BertConfig(vocab_size_or_config_json_file=-1) | ||||
for key, value in json_object.items(): | for key, value in json_object.items(): | ||||
config.__dict__[key] = value | config.__dict__[key] = value | ||||
@@ -118,7 +118,7 @@ class BertConfig(object): | |||||
@classmethod | @classmethod | ||||
def from_json_file(cls, json_file): | def from_json_file(cls, json_file): | ||||
"""Constructs a `BertConfig` from a json file of parameters.""" | |||||
r"""Constructs a `BertConfig` from a json file of parameters.""" | |||||
with open(json_file, "r", encoding='utf-8') as reader: | with open(json_file, "r", encoding='utf-8') as reader: | ||||
text = reader.read() | text = reader.read() | ||||
return cls.from_dict(json.loads(text)) | return cls.from_dict(json.loads(text)) | ||||
@@ -127,16 +127,16 @@ class BertConfig(object): | |||||
return str(self.to_json_string()) | return str(self.to_json_string()) | ||||
def to_dict(self): | def to_dict(self): | ||||
"""Serializes this instance to a Python dictionary.""" | |||||
r"""Serializes this instance to a Python dictionary.""" | |||||
output = copy.deepcopy(self.__dict__) | output = copy.deepcopy(self.__dict__) | ||||
return output | return output | ||||
def to_json_string(self): | def to_json_string(self): | ||||
"""Serializes this instance to a JSON string.""" | |||||
r"""Serializes this instance to a JSON string.""" | |||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" | ||||
def to_json_file(self, json_file_path): | def to_json_file(self, json_file_path): | ||||
""" Save this instance to a json file.""" | |||||
r""" Save this instance to a json file.""" | |||||
with open(json_file_path, "w", encoding='utf-8') as writer: | with open(json_file_path, "w", encoding='utf-8') as writer: | ||||
writer.write(self.to_json_string()) | writer.write(self.to_json_string()) | ||||
@@ -167,7 +167,7 @@ def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'): | |||||
class BertLayerNorm(nn.Module): | class BertLayerNorm(nn.Module): | ||||
def __init__(self, hidden_size, eps=1e-12): | def __init__(self, hidden_size, eps=1e-12): | ||||
"""Construct a layernorm module in the TF style (epsilon inside the square root). | |||||
r"""Construct a layernorm module in the TF style (epsilon inside the square root). | |||||
""" | """ | ||||
super(BertLayerNorm, self).__init__() | super(BertLayerNorm, self).__init__() | ||||
self.weight = nn.Parameter(torch.ones(hidden_size)) | self.weight = nn.Parameter(torch.ones(hidden_size)) | ||||
@@ -206,7 +206,7 @@ class DistilBertEmbeddings(nn.Module): | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
def forward(self, input_ids, token_type_ids): | def forward(self, input_ids, token_type_ids): | ||||
""" | |||||
r""" | |||||
Parameters | Parameters | ||||
---------- | ---------- | ||||
input_ids: torch.tensor(bs, max_seq_length) | input_ids: torch.tensor(bs, max_seq_length) | ||||
@@ -231,7 +231,7 @@ class DistilBertEmbeddings(nn.Module): | |||||
class BertEmbeddings(nn.Module): | class BertEmbeddings(nn.Module): | ||||
"""Construct the embeddings from word, position and token_type embeddings. | |||||
r"""Construct the embeddings from word, position and token_type embeddings. | |||||
""" | """ | ||||
def __init__(self, config): | def __init__(self, config): | ||||
@@ -415,7 +415,7 @@ class BertPooler(nn.Module): | |||||
class BertModel(nn.Module): | class BertModel(nn.Module): | ||||
""" | |||||
r""" | |||||
BERT(Bidirectional Embedding Representations from Transformers). | BERT(Bidirectional Embedding Representations from Transformers). | ||||
用预训练权重矩阵来建立BERT模型:: | 用预训练权重矩阵来建立BERT模型:: | ||||
@@ -470,7 +470,7 @@ class BertModel(nn.Module): | |||||
self.apply(self.init_bert_weights) | self.apply(self.init_bert_weights) | ||||
def init_bert_weights(self, module): | def init_bert_weights(self, module): | ||||
""" Initialize the weights. | |||||
r""" Initialize the weights. | |||||
""" | """ | ||||
if isinstance(module, (nn.Linear, nn.Embedding)): | if isinstance(module, (nn.Linear, nn.Embedding)): | ||||
# Slightly different from the TF version which uses truncated_normal for initialization | # Slightly different from the TF version which uses truncated_normal for initialization | ||||
@@ -613,7 +613,7 @@ class BertModel(nn.Module): | |||||
def whitespace_tokenize(text): | def whitespace_tokenize(text): | ||||
"""Runs basic whitespace cleaning and splitting on a piece of text.""" | |||||
r"""Runs basic whitespace cleaning and splitting on a piece of text.""" | |||||
text = text.strip() | text = text.strip() | ||||
if not text: | if not text: | ||||
return [] | return [] | ||||
@@ -622,7 +622,7 @@ def whitespace_tokenize(text): | |||||
class WordpieceTokenizer(object): | class WordpieceTokenizer(object): | ||||
"""Runs WordPiece tokenization.""" | |||||
r"""Runs WordPiece tokenization.""" | |||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): | ||||
self.vocab = vocab | self.vocab = vocab | ||||
@@ -630,7 +630,7 @@ class WordpieceTokenizer(object): | |||||
self.max_input_chars_per_word = max_input_chars_per_word | self.max_input_chars_per_word = max_input_chars_per_word | ||||
def tokenize(self, text): | def tokenize(self, text): | ||||
"""Tokenizes a piece of text into its word pieces. | |||||
r"""Tokenizes a piece of text into its word pieces. | |||||
This uses a greedy longest-match-first algorithm to perform tokenization | This uses a greedy longest-match-first algorithm to perform tokenization | ||||
using the given vocabulary. | using the given vocabulary. | ||||
@@ -684,7 +684,7 @@ class WordpieceTokenizer(object): | |||||
def load_vocab(vocab_file): | def load_vocab(vocab_file): | ||||
"""Loads a vocabulary file into a dictionary.""" | |||||
r"""Loads a vocabulary file into a dictionary.""" | |||||
vocab = collections.OrderedDict() | vocab = collections.OrderedDict() | ||||
index = 0 | index = 0 | ||||
with open(vocab_file, "r", encoding="utf-8") as reader: | with open(vocab_file, "r", encoding="utf-8") as reader: | ||||
@@ -699,12 +699,12 @@ def load_vocab(vocab_file): | |||||
class BasicTokenizer(object): | class BasicTokenizer(object): | ||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | |||||
r"""Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | |||||
def __init__(self, | def __init__(self, | ||||
do_lower_case=True, | do_lower_case=True, | ||||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | ||||
"""Constructs a BasicTokenizer. | |||||
r"""Constructs a BasicTokenizer. | |||||
Args: | Args: | ||||
do_lower_case: Whether to lower case the input. | do_lower_case: Whether to lower case the input. | ||||
@@ -713,7 +713,7 @@ class BasicTokenizer(object): | |||||
self.never_split = never_split | self.never_split = never_split | ||||
def tokenize(self, text): | def tokenize(self, text): | ||||
"""Tokenizes a piece of text.""" | |||||
r"""Tokenizes a piece of text.""" | |||||
text = self._clean_text(text) | text = self._clean_text(text) | ||||
# This was added on November 1st, 2018 for the multilingual and Chinese | # This was added on November 1st, 2018 for the multilingual and Chinese | ||||
# models. This is also applied to the English models now, but it doesn't | # models. This is also applied to the English models now, but it doesn't | ||||
@@ -734,7 +734,7 @@ class BasicTokenizer(object): | |||||
return output_tokens | return output_tokens | ||||
def _run_strip_accents(self, text): | def _run_strip_accents(self, text): | ||||
"""Strips accents from a piece of text.""" | |||||
r"""Strips accents from a piece of text.""" | |||||
text = unicodedata.normalize("NFD", text) | text = unicodedata.normalize("NFD", text) | ||||
output = [] | output = [] | ||||
for char in text: | for char in text: | ||||
@@ -745,7 +745,7 @@ class BasicTokenizer(object): | |||||
return "".join(output) | return "".join(output) | ||||
def _run_split_on_punc(self, text): | def _run_split_on_punc(self, text): | ||||
"""Splits punctuation on a piece of text.""" | |||||
r"""Splits punctuation on a piece of text.""" | |||||
if text in self.never_split: | if text in self.never_split: | ||||
return [text] | return [text] | ||||
chars = list(text) | chars = list(text) | ||||
@@ -767,7 +767,7 @@ class BasicTokenizer(object): | |||||
return ["".join(x) for x in output] | return ["".join(x) for x in output] | ||||
def _tokenize_chinese_chars(self, text): | def _tokenize_chinese_chars(self, text): | ||||
"""Adds whitespace around any CJK character.""" | |||||
r"""Adds whitespace around any CJK character.""" | |||||
output = [] | output = [] | ||||
for char in text: | for char in text: | ||||
cp = ord(char) | cp = ord(char) | ||||
@@ -780,7 +780,7 @@ class BasicTokenizer(object): | |||||
return "".join(output) | return "".join(output) | ||||
def _is_chinese_char(self, cp): | def _is_chinese_char(self, cp): | ||||
"""Checks whether CP is the codepoint of a CJK character.""" | |||||
r"""Checks whether CP is the codepoint of a CJK character.""" | |||||
# This defines a "chinese character" as anything in the CJK Unicode block: | # This defines a "chinese character" as anything in the CJK Unicode block: | ||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | ||||
# | # | ||||
@@ -802,7 +802,7 @@ class BasicTokenizer(object): | |||||
return False | return False | ||||
def _clean_text(self, text): | def _clean_text(self, text): | ||||
"""Performs invalid character removal and whitespace cleanup on text.""" | |||||
r"""Performs invalid character removal and whitespace cleanup on text.""" | |||||
output = [] | output = [] | ||||
for char in text: | for char in text: | ||||
cp = ord(char) | cp = ord(char) | ||||
@@ -816,7 +816,7 @@ class BasicTokenizer(object): | |||||
def _is_whitespace(char): | def _is_whitespace(char): | ||||
"""Checks whether `chars` is a whitespace character.""" | |||||
r"""Checks whether `chars` is a whitespace character.""" | |||||
# \t, \n, and \r are technically contorl characters but we treat them | # \t, \n, and \r are technically contorl characters but we treat them | ||||
# as whitespace since they are generally considered as such. | # as whitespace since they are generally considered as such. | ||||
if char == " " or char == "\t" or char == "\n" or char == "\r": | if char == " " or char == "\t" or char == "\n" or char == "\r": | ||||
@@ -828,7 +828,7 @@ def _is_whitespace(char): | |||||
def _is_control(char): | def _is_control(char): | ||||
"""Checks whether `chars` is a control character.""" | |||||
r"""Checks whether `chars` is a control character.""" | |||||
# These are technically control characters but we count them as whitespace | # These are technically control characters but we count them as whitespace | ||||
# characters. | # characters. | ||||
if char == "\t" or char == "\n" or char == "\r": | if char == "\t" or char == "\n" or char == "\r": | ||||
@@ -840,7 +840,7 @@ def _is_control(char): | |||||
def _is_punctuation(char): | def _is_punctuation(char): | ||||
"""Checks whether `chars` is a punctuation character.""" | |||||
r"""Checks whether `chars` is a punctuation character.""" | |||||
cp = ord(char) | cp = ord(char) | ||||
# We treat all non-letter/number ASCII as punctuation. | # We treat all non-letter/number ASCII as punctuation. | ||||
# Characters such as "^", "$", and "`" are not in the Unicode | # Characters such as "^", "$", and "`" are not in the Unicode | ||||
@@ -856,11 +856,11 @@ def _is_punctuation(char): | |||||
class BertTokenizer(object): | class BertTokenizer(object): | ||||
"""Runs end-to-end tokenization: punctuation splitting + wordpiece""" | |||||
r"""Runs end-to-end tokenization: punctuation splitting + wordpiece""" | |||||
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, | ||||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | ||||
"""Constructs a BertTokenizer. | |||||
r"""Constructs a BertTokenizer. | |||||
Args: | Args: | ||||
vocab_file: Path to a one-wordpiece-per-line vocabulary file | vocab_file: Path to a one-wordpiece-per-line vocabulary file | ||||
@@ -889,7 +889,7 @@ class BertTokenizer(object): | |||||
self.max_len = max_len if max_len is not None else int(1e12) | self.max_len = max_len if max_len is not None else int(1e12) | ||||
def _reinit_on_new_vocab(self, vocab): | def _reinit_on_new_vocab(self, vocab): | ||||
""" | |||||
r""" | |||||
在load bert之后,可能会对vocab进行重新排列。重新排列之后调用这个函数重新初始化与vocab相关的性质 | 在load bert之后,可能会对vocab进行重新排列。重新排列之后调用这个函数重新初始化与vocab相关的性质 | ||||
:param vocab: | :param vocab: | ||||
@@ -909,7 +909,7 @@ class BertTokenizer(object): | |||||
return split_tokens | return split_tokens | ||||
def convert_tokens_to_ids(self, tokens): | def convert_tokens_to_ids(self, tokens): | ||||
"""Converts a sequence of tokens into ids using the vocab.""" | |||||
r"""Converts a sequence of tokens into ids using the vocab.""" | |||||
ids = [] | ids = [] | ||||
for token in tokens: | for token in tokens: | ||||
ids.append(self.vocab[token]) | ids.append(self.vocab[token]) | ||||
@@ -922,14 +922,14 @@ class BertTokenizer(object): | |||||
return ids | return ids | ||||
def convert_ids_to_tokens(self, ids): | def convert_ids_to_tokens(self, ids): | ||||
"""Converts a sequence of ids in wordpiece tokens using the vocab.""" | |||||
r"""Converts a sequence of ids in wordpiece tokens using the vocab.""" | |||||
tokens = [] | tokens = [] | ||||
for i in ids: | for i in ids: | ||||
tokens.append(self.ids_to_tokens[i]) | tokens.append(self.ids_to_tokens[i]) | ||||
return tokens | return tokens | ||||
def save_vocabulary(self, vocab_path): | def save_vocabulary(self, vocab_path): | ||||
"""Save the tokenizer vocabulary to a directory or file.""" | |||||
r"""Save the tokenizer vocabulary to a directory or file.""" | |||||
index = 0 | index = 0 | ||||
if os.path.isdir(vocab_path): | if os.path.isdir(vocab_path): | ||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME) | vocab_file = os.path.join(vocab_path, VOCAB_NAME) | ||||
@@ -947,7 +947,7 @@ class BertTokenizer(object): | |||||
@classmethod | @classmethod | ||||
def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): | def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): | ||||
""" | |||||
r""" | |||||
给定模型的名字或者路径,直接读取vocab. | 给定模型的名字或者路径,直接读取vocab. | ||||
""" | """ | ||||
model_dir = _get_bert_dir(model_dir_or_name) | model_dir = _get_bert_dir(model_dir_or_name) | ||||
@@ -961,7 +961,7 @@ class BertTokenizer(object): | |||||
class _WordPieceBertModel(nn.Module): | class _WordPieceBertModel(nn.Module): | ||||
""" | |||||
r""" | |||||
这个模块用于直接计算word_piece的结果. | 这个模块用于直接计算word_piece的结果. | ||||
""" | """ | ||||
@@ -989,7 +989,7 @@ class _WordPieceBertModel(nn.Module): | |||||
self.pooled_cls = pooled_cls | self.pooled_cls = pooled_cls | ||||
def index_dataset(self, *datasets, field_name, add_cls_sep=True): | def index_dataset(self, *datasets, field_name, add_cls_sep=True): | ||||
""" | |||||
r""" | |||||
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是 | 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是 | ||||
[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。 | [CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。 | ||||
@@ -1024,7 +1024,7 @@ class _WordPieceBertModel(nn.Module): | |||||
raise e | raise e | ||||
def forward(self, word_pieces, token_type_ids=None): | def forward(self, word_pieces, token_type_ids=None): | ||||
""" | |||||
r""" | |||||
:param word_pieces: torch.LongTensor, batch_size x max_len | :param word_pieces: torch.LongTensor, batch_size x max_len | ||||
:param token_type_ids: torch.LongTensor, batch_size x max_len | :param token_type_ids: torch.LongTensor, batch_size x max_len | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"ConvolutionCharEncoder", | "ConvolutionCharEncoder", | ||||
@@ -12,13 +12,13 @@ from ..utils import initial_parameter | |||||
# from torch.nn.init import xavier_uniform | # from torch.nn.init import xavier_uniform | ||||
class ConvolutionCharEncoder(nn.Module): | class ConvolutionCharEncoder(nn.Module): | ||||
""" | |||||
r""" | |||||
char级别的卷积编码器. | char级别的卷积编码器. | ||||
""" | """ | ||||
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(1, 3, 5), initial_method=None): | def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(1, 3, 5), initial_method=None): | ||||
""" | |||||
r""" | |||||
:param int char_emb_size: char级别embedding的维度. Default: 50 | :param int char_emb_size: char级别embedding的维度. Default: 50 | ||||
:例: 有26个字符, 每一个的embedding是一个50维的向量, 所以输入的向量维度为50. | :例: 有26个字符, 每一个的embedding是一个50维的向量, 所以输入的向量维度为50. | ||||
@@ -35,7 +35,7 @@ class ConvolutionCharEncoder(nn.Module): | |||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def forward(self, x): | def forward(self, x): | ||||
""" | |||||
r""" | |||||
:param torch.Tensor x: ``[batch_size * sent_length, word_length, char_emb_size]`` 输入字符的embedding | :param torch.Tensor x: ``[batch_size * sent_length, word_length, char_emb_size]`` 输入字符的embedding | ||||
:return: torch.Tensor : 卷积计算的结果, 维度为[batch_size * sent_length, sum(feature_maps), 1] | :return: torch.Tensor : 卷积计算的结果, 维度为[batch_size * sent_length, sum(feature_maps), 1] | ||||
""" | """ | ||||
@@ -60,12 +60,12 @@ class ConvolutionCharEncoder(nn.Module): | |||||
class LSTMCharEncoder(nn.Module): | class LSTMCharEncoder(nn.Module): | ||||
""" | |||||
r""" | |||||
char级别基于LSTM的encoder. | char级别基于LSTM的encoder. | ||||
""" | """ | ||||
def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None): | def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None): | ||||
""" | |||||
r""" | |||||
:param int char_emb_size: char级别embedding的维度. Default: 50 | :param int char_emb_size: char级别embedding的维度. Default: 50 | ||||
例: 有26个字符, 每一个的embedding是一个50维的向量, 所以输入的向量维度为50. | 例: 有26个字符, 每一个的embedding是一个50维的向量, 所以输入的向量维度为50. | ||||
:param int hidden_size: LSTM隐层的大小, 默认为char的embedding维度 | :param int hidden_size: LSTM隐层的大小, 默认为char的embedding维度 | ||||
@@ -82,7 +82,7 @@ class LSTMCharEncoder(nn.Module): | |||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def forward(self, x): | def forward(self, x): | ||||
""" | |||||
r""" | |||||
:param torch.Tensor x: ``[ n_batch*n_word, word_length, char_emb_size]`` 输入字符的embedding | :param torch.Tensor x: ``[ n_batch*n_word, word_length, char_emb_size]`` 输入字符的embedding | ||||
:return: torch.Tensor : [ n_batch*n_word, char_emb_size]经过LSTM编码的结果 | :return: torch.Tensor : [ n_batch*n_word, char_emb_size]经过LSTM编码的结果 | ||||
""" | """ | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"ConvMaxpool" | "ConvMaxpool" | ||||
@@ -9,7 +9,7 @@ import torch.nn.functional as F | |||||
class ConvMaxpool(nn.Module): | class ConvMaxpool(nn.Module): | ||||
""" | |||||
r""" | |||||
集合了Convolution和Max-Pooling于一体的层。给定一个batch_size x max_len x input_size的输入,返回batch_size x | 集合了Convolution和Max-Pooling于一体的层。给定一个batch_size x max_len x input_size的输入,返回batch_size x | ||||
sum(output_channels) 大小的matrix。在内部,是先使用CNN给输入做卷积,然后经过activation激活层,在通过在长度(max_len) | sum(output_channels) 大小的matrix。在内部,是先使用CNN给输入做卷积,然后经过activation激活层,在通过在长度(max_len) | ||||
这一维进行max_pooling。最后得到每个sample的一个向量表示。 | 这一维进行max_pooling。最后得到每个sample的一个向量表示。 | ||||
@@ -17,7 +17,7 @@ class ConvMaxpool(nn.Module): | |||||
""" | """ | ||||
def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"): | def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"): | ||||
""" | |||||
r""" | |||||
:param int in_channels: 输入channel的大小,一般是embedding的维度; 或encoder的output维度 | :param int in_channels: 输入channel的大小,一般是embedding的维度; 或encoder的output维度 | ||||
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 | :param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 | ||||
@@ -68,7 +68,7 @@ class ConvMaxpool(nn.Module): | |||||
"Undefined activation function: choose from: relu, tanh, sigmoid") | "Undefined activation function: choose from: relu, tanh, sigmoid") | ||||
def forward(self, x, mask=None): | def forward(self, x, mask=None): | ||||
""" | |||||
r""" | |||||
:param torch.FloatTensor x: batch_size x max_len x input_size, 一般是经过embedding后的值 | :param torch.FloatTensor x: batch_size x max_len x input_size, 一般是经过embedding后的值 | ||||
:param mask: batch_size x max_len, pad的地方为0。不影响卷积运算,max-pool一定不会pool到pad为0的位置 | :param mask: batch_size x max_len, pad的地方为0。不影响卷积运算,max-pool一定不会pool到pad为0的位置 | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented | |||||
r"""undocumented | |||||
轻量封装的 Pytorch LSTM 模块. | 轻量封装的 Pytorch LSTM 模块. | ||||
可在 forward 时传入序列的长度, 自动对padding做合适的处理. | 可在 forward 时传入序列的长度, 自动对padding做合适的处理. | ||||
""" | """ | ||||
@@ -13,7 +13,7 @@ import torch.nn.utils.rnn as rnn | |||||
class LSTM(nn.Module): | class LSTM(nn.Module): | ||||
""" | |||||
r""" | |||||
LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | ||||
为1; 且可以应对DataParallel中LSTM的使用问题。 | 为1; 且可以应对DataParallel中LSTM的使用问题。 | ||||
@@ -21,7 +21,7 @@ class LSTM(nn.Module): | |||||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | ||||
bidirectional=False, bias=True): | bidirectional=False, bias=True): | ||||
""" | |||||
r""" | |||||
:param input_size: 输入 `x` 的特征维度 | :param input_size: 输入 `x` 的特征维度 | ||||
:param hidden_size: 隐状态 `h` 的特征维度. 如果bidirectional为True,则输出的维度会是hidde_size*2 | :param hidden_size: 隐状态 `h` 的特征维度. 如果bidirectional为True,则输出的维度会是hidde_size*2 | ||||
@@ -50,7 +50,7 @@ class LSTM(nn.Module): | |||||
nn.init.xavier_uniform_(param) | nn.init.xavier_uniform_(param) | ||||
def forward(self, x, seq_len=None, h0=None, c0=None): | def forward(self, x, seq_len=None, h0=None, c0=None): | ||||
""" | |||||
r""" | |||||
:param x: [batch, seq_len, input_size] 输入序列 | :param x: [batch, seq_len, input_size] 输入序列 | ||||
:param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` | :param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"MaxPool", | "MaxPool", | ||||
@@ -12,13 +12,13 @@ import torch.nn as nn | |||||
class MaxPool(nn.Module): | class MaxPool(nn.Module): | ||||
""" | |||||
r""" | |||||
Max-pooling模块。 | Max-pooling模块。 | ||||
""" | """ | ||||
def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, ceil_mode=False): | def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, ceil_mode=False): | ||||
""" | |||||
r""" | |||||
:param stride: 窗口移动大小,默认为kernel_size | :param stride: 窗口移动大小,默认为kernel_size | ||||
:param padding: padding的内容,默认为0 | :param padding: padding的内容,默认为0 | ||||
@@ -61,7 +61,7 @@ class MaxPool(nn.Module): | |||||
class MaxPoolWithMask(nn.Module): | class MaxPoolWithMask(nn.Module): | ||||
""" | |||||
r""" | |||||
带mask矩阵的max pooling。在做max-pooling的时候不会考虑mask值为0的位置。 | 带mask矩阵的max pooling。在做max-pooling的时候不会考虑mask值为0的位置。 | ||||
""" | """ | ||||
@@ -70,7 +70,7 @@ class MaxPoolWithMask(nn.Module): | |||||
self.inf = 10e12 | self.inf = 10e12 | ||||
def forward(self, tensor, mask, dim=1): | def forward(self, tensor, mask, dim=1): | ||||
""" | |||||
r""" | |||||
:param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor | :param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor | ||||
:param torch.LongTensor mask: [batch_size, seq_len] 0/1的mask矩阵 | :param torch.LongTensor mask: [batch_size, seq_len] 0/1的mask矩阵 | ||||
:param int dim: 需要进行max pooling的维度 | :param int dim: 需要进行max pooling的维度 | ||||
@@ -82,14 +82,14 @@ class MaxPoolWithMask(nn.Module): | |||||
class KMaxPool(nn.Module): | class KMaxPool(nn.Module): | ||||
"""K max-pooling module.""" | |||||
r"""K max-pooling module.""" | |||||
def __init__(self, k=1): | def __init__(self, k=1): | ||||
super(KMaxPool, self).__init__() | super(KMaxPool, self).__init__() | ||||
self.k = k | self.k = k | ||||
def forward(self, x): | def forward(self, x): | ||||
""" | |||||
r""" | |||||
:param torch.Tensor x: [N, C, L] 初始tensor | :param torch.Tensor x: [N, C, L] 初始tensor | ||||
:return: torch.Tensor x: [N, C*k] k-max pool后的结果 | :return: torch.Tensor x: [N, C*k] k-max pool后的结果 | ||||
""" | """ | ||||
@@ -99,7 +99,7 @@ class KMaxPool(nn.Module): | |||||
class AvgPool(nn.Module): | class AvgPool(nn.Module): | ||||
""" | |||||
r""" | |||||
给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size] | 给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size] | ||||
""" | """ | ||||
@@ -109,7 +109,7 @@ class AvgPool(nn.Module): | |||||
self.padding = padding | self.padding = padding | ||||
def forward(self, x): | def forward(self, x): | ||||
""" | |||||
r""" | |||||
:param torch.Tensor x: [N, C, L] 初始tensor | :param torch.Tensor x: [N, C, L] 初始tensor | ||||
:return: torch.Tensor x: [N, C] avg pool后的结果 | :return: torch.Tensor x: [N, C] avg pool后的结果 | ||||
""" | """ | ||||
@@ -124,7 +124,7 @@ class AvgPool(nn.Module): | |||||
class AvgPoolWithMask(nn.Module): | class AvgPoolWithMask(nn.Module): | ||||
""" | |||||
r""" | |||||
给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size], pooling | 给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size], pooling | ||||
的时候只会考虑mask为1的位置 | 的时候只会考虑mask为1的位置 | ||||
""" | """ | ||||
@@ -134,7 +134,7 @@ class AvgPoolWithMask(nn.Module): | |||||
self.inf = 10e12 | self.inf = 10e12 | ||||
def forward(self, tensor, mask, dim=1): | def forward(self, tensor, mask, dim=1): | ||||
""" | |||||
r""" | |||||
:param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor | :param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor | ||||
:param torch.LongTensor mask: [batch_size, seq_len] 0/1的mask矩阵 | :param torch.LongTensor mask: [batch_size, seq_len] 0/1的mask矩阵 | ||||
:param int dim: 需要进行max pooling的维度 | :param int dim: 需要进行max pooling的维度 | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented | |||||
r"""undocumented | |||||
Star-Transformer 的encoder部分的 Pytorch 实现 | Star-Transformer 的encoder部分的 Pytorch 实现 | ||||
""" | """ | ||||
@@ -13,7 +13,7 @@ from torch.nn import functional as F | |||||
class StarTransformer(nn.Module): | class StarTransformer(nn.Module): | ||||
""" | |||||
r""" | |||||
Star-Transformer 的encoder部分。 输入3d的文本输入, 返回相同长度的文本编码 | Star-Transformer 的encoder部分。 输入3d的文本输入, 返回相同长度的文本编码 | ||||
paper: https://arxiv.org/abs/1902.09113 | paper: https://arxiv.org/abs/1902.09113 | ||||
@@ -21,7 +21,7 @@ class StarTransformer(nn.Module): | |||||
""" | """ | ||||
def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | ||||
""" | |||||
r""" | |||||
:param int hidden_size: 输入维度的大小。同时也是输出维度的大小。 | :param int hidden_size: 输入维度的大小。同时也是输出维度的大小。 | ||||
:param int num_layers: star-transformer的层数 | :param int num_layers: star-transformer的层数 | ||||
@@ -51,7 +51,7 @@ class StarTransformer(nn.Module): | |||||
self.pos_emb = None | self.pos_emb = None | ||||
def forward(self, data, mask): | def forward(self, data, mask): | ||||
""" | |||||
r""" | |||||
:param FloatTensor data: [batch, length, hidden] 输入的序列 | :param FloatTensor data: [batch, length, hidden] 输入的序列 | ||||
:param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, | :param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, | ||||
否则为 1 | 否则为 1 | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented""" | |||||
r"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"TransformerEncoder" | "TransformerEncoder" | ||||
@@ -9,7 +9,7 @@ from .attention import MultiHeadAttention | |||||
class TransformerEncoder(nn.Module): | class TransformerEncoder(nn.Module): | ||||
""" | |||||
r""" | |||||
transformer的encoder模块,不包含embedding层 | transformer的encoder模块,不包含embedding层 | ||||
""" | """ | ||||
@@ -27,7 +27,7 @@ class TransformerEncoder(nn.Module): | |||||
self.dropout = nn.Dropout(dropout) | self.dropout = nn.Dropout(dropout) | ||||
def forward(self, input, seq_mask=None, atte_mask_out=None): | def forward(self, input, seq_mask=None, atte_mask_out=None): | ||||
""" | |||||
r""" | |||||
:param input: [batch, seq_len, model_size] | :param input: [batch, seq_len, model_size] | ||||
:param seq_mask: [batch, seq_len] | :param seq_mask: [batch, seq_len] | ||||
@@ -46,7 +46,7 @@ class TransformerEncoder(nn.Module): | |||||
return input | return input | ||||
def __init__(self, num_layers, **kargs): | def __init__(self, num_layers, **kargs): | ||||
""" | |||||
r""" | |||||
:param int num_layers: transformer的层数 | :param int num_layers: transformer的层数 | ||||
:param int model_size: 输入维度的大小。同时也是输出维度的大小。 | :param int model_size: 输入维度的大小。同时也是输出维度的大小。 | ||||
@@ -61,7 +61,7 @@ class TransformerEncoder(nn.Module): | |||||
self.norm = nn.LayerNorm(kargs['model_size'], eps=1e-6) | self.norm = nn.LayerNorm(kargs['model_size'], eps=1e-6) | ||||
def forward(self, x, seq_mask=None): | def forward(self, x, seq_mask=None): | ||||
""" | |||||
r""" | |||||
:param x: [batch, seq_len, model_size] 输入序列 | :param x: [batch, seq_len, model_size] 输入序列 | ||||
:param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. | :param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. | ||||
Default: ``None`` | Default: ``None`` | ||||
@@ -1,4 +1,4 @@ | |||||
"""undocumented | |||||
r"""undocumented | |||||
Variational RNN 及相关模型的 fastNLP实现,相关论文参考: | Variational RNN 及相关模型的 fastNLP实现,相关论文参考: | ||||
`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | ||||
""" | """ | ||||
@@ -27,7 +27,7 @@ from ..utils import initial_parameter | |||||
class VarRnnCellWrapper(nn.Module): | class VarRnnCellWrapper(nn.Module): | ||||
""" | |||||
r""" | |||||
Wrapper for normal RNN Cells, make it support variational dropout | Wrapper for normal RNN Cells, make it support variational dropout | ||||
""" | """ | ||||
@@ -39,7 +39,7 @@ class VarRnnCellWrapper(nn.Module): | |||||
self.hidden_p = hidden_p | self.hidden_p = hidden_p | ||||
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): | def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): | ||||
""" | |||||
r""" | |||||
:param PackedSequence input_x: [seq_len, batch_size, input_size] | :param PackedSequence input_x: [seq_len, batch_size, input_size] | ||||
:param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | :param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | ||||
for other RNN, h_0, [batch_size, hidden_size] | for other RNN, h_0, [batch_size, hidden_size] | ||||
@@ -101,7 +101,7 @@ class VarRnnCellWrapper(nn.Module): | |||||
class VarRNNBase(nn.Module): | class VarRNNBase(nn.Module): | ||||
""" | |||||
r""" | |||||
Variational Dropout RNN 实现. | Variational Dropout RNN 实现. | ||||
论文参考: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) | 论文参考: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) | ||||
@@ -112,7 +112,7 @@ class VarRNNBase(nn.Module): | |||||
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | ||||
bias=True, batch_first=False, | bias=True, batch_first=False, | ||||
input_dropout=0, hidden_dropout=0, bidirectional=False): | input_dropout=0, hidden_dropout=0, bidirectional=False): | ||||
""" | |||||
r""" | |||||
:param mode: rnn 模式, (lstm or not) | :param mode: rnn 模式, (lstm or not) | ||||
:param Cell: rnn cell 类型, (lstm, gru, etc) | :param Cell: rnn cell 类型, (lstm, gru, etc) | ||||
@@ -157,7 +157,7 @@ class VarRNNBase(nn.Module): | |||||
return output_x, hidden_x | return output_x, hidden_x | ||||
def forward(self, x, hx=None): | def forward(self, x, hx=None): | ||||
""" | |||||
r""" | |||||
:param x: [batch, seq_len, input_size] 输入序列 | :param x: [batch, seq_len, input_size] 输入序列 | ||||
:param hx: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全1向量. Default: ``None`` | :param hx: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全1向量. Default: ``None`` | ||||
@@ -226,14 +226,14 @@ class VarRNNBase(nn.Module): | |||||
class VarLSTM(VarRNNBase): | class VarLSTM(VarRNNBase): | ||||
""" | |||||
r""" | |||||
Variational Dropout LSTM. | Variational Dropout LSTM. | ||||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | 相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | ||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
""" | |||||
r""" | |||||
:param input_size: 输入 `x` 的特征维度 | :param input_size: 输入 `x` 的特征维度 | ||||
:param hidden_size: 隐状态 `h` 的特征维度 | :param hidden_size: 隐状态 `h` 的特征维度 | ||||
@@ -253,14 +253,14 @@ class VarLSTM(VarRNNBase): | |||||
class VarRNN(VarRNNBase): | class VarRNN(VarRNNBase): | ||||
""" | |||||
r""" | |||||
Variational Dropout RNN. | Variational Dropout RNN. | ||||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | 相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | ||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
""" | |||||
r""" | |||||
:param input_size: 输入 `x` 的特征维度 | :param input_size: 输入 `x` 的特征维度 | ||||
:param hidden_size: 隐状态 `h` 的特征维度 | :param hidden_size: 隐状态 `h` 的特征维度 | ||||
@@ -280,14 +280,14 @@ class VarRNN(VarRNNBase): | |||||
class VarGRU(VarRNNBase): | class VarGRU(VarRNNBase): | ||||
""" | |||||
r""" | |||||
Variational Dropout GRU. | Variational Dropout GRU. | ||||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | 相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | ||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
""" | |||||
r""" | |||||
:param input_size: 输入 `x` 的特征维度 | :param input_size: 输入 `x` 的特征维度 | ||||
:param hidden_size: 隐状态 `h` 的特征维度 | :param hidden_size: 隐状态 `h` 的特征维度 | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
r""" | |||||
.. todo:: | .. todo:: | ||||
doc | doc | ||||
""" | """ | ||||
@@ -17,7 +17,7 @@ import torch.nn.init as init | |||||
def initial_parameter(net, initial_method=None): | def initial_parameter(net, initial_method=None): | ||||
"""A method used to initialize the weights of PyTorch models. | |||||
r"""A method used to initialize the weights of PyTorch models. | |||||
:param net: a PyTorch model | :param net: a PyTorch model | ||||
:param str initial_method: one of the following initializations. | :param str initial_method: one of the following initializations. | ||||
@@ -81,7 +81,7 @@ def initial_parameter(net, initial_method=None): | |||||
def summary(model: nn.Module): | def summary(model: nn.Module): | ||||
""" | |||||
r""" | |||||
得到模型的总参数量 | 得到模型的总参数量 | ||||
:params model: Pytorch 模型 | :params model: Pytorch 模型 | ||||
@@ -122,7 +122,7 @@ def summary(model: nn.Module): | |||||
def get_dropout_mask(drop_p: float, tensor: torch.Tensor): | def get_dropout_mask(drop_p: float, tensor: torch.Tensor): | ||||
""" | |||||
r""" | |||||
根据tensor的形状,生成一个mask | 根据tensor的形状,生成一个mask | ||||
:param drop_p: float, 以多大的概率置为0。 | :param drop_p: float, 以多大的概率置为0。 | ||||
@@ -136,7 +136,7 @@ def get_dropout_mask(drop_p: float, tensor: torch.Tensor): | |||||
def _get_file_name_base_on_postfix(dir_path, postfix): | def _get_file_name_base_on_postfix(dir_path, postfix): | ||||
""" | |||||
r""" | |||||
在dir_path中寻找后缀为postfix的文件. | 在dir_path中寻找后缀为postfix的文件. | ||||
:param dir_path: str, 文件夹 | :param dir_path: str, 文件夹 | ||||
:param postfix: 形如".bin", ".json"等 | :param postfix: 形如".bin", ".json"等 | ||||