@@ -70,10 +70,11 @@ __all__ = [ | |||||
] | ] | ||||
import os | import os | ||||
import sys | |||||
from copy import deepcopy | |||||
import torch | import torch | ||||
from copy import deepcopy | |||||
import sys | |||||
from .utils import _save_model | from .utils import _save_model | ||||
try: | try: | ||||
@@ -928,13 +929,15 @@ class WarmupCallback(Callback): | |||||
class SaveModelCallback(Callback): | class SaveModelCallback(Callback): | ||||
""" | """ | ||||
由于Trainer在训练过程中只会保存最佳的模型, 该callback可实现多种方式的结果存储。 | 由于Trainer在训练过程中只会保存最佳的模型, 该callback可实现多种方式的结果存储。 | ||||
会根据训练开始的时间戳在save_dir下建立文件夹,再在文件夹下存放多个模型 | |||||
-save_dir | |||||
-2019-07-03-15-06-36 | |||||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_performance是性能 | |||||
-epoch:1_step:40_{metric_key}:{evaluate_performance}.pt | |||||
-2019-07-03-15-10-00 | |||||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能 | |||||
会根据训练开始的时间戳在save_dir下建立文件夹,再在文件夹下存放多个模型:: | |||||
-save_dir | |||||
-2019-07-03-15-06-36 | |||||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_performance是性能 | |||||
-epoch:1_step:40_{metric_key}:{evaluate_performance}.pt | |||||
-2019-07-03-15-10-00 | |||||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能 | |||||
:param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型 | :param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型 | ||||
:param int top: 保存dev表现top多少模型。-1为保存所有模型。 | :param int top: 保存dev表现top多少模型。-1为保存所有模型。 | ||||
:param bool only_param: 是否只保存模型d饿权重。 | :param bool only_param: 是否只保存模型d饿权重。 | ||||
@@ -204,7 +204,7 @@ class DataBundle: | |||||
行的数据进行类型和维度推断本列的数据的类型和维度。 | 行的数据进行类型和维度推断本列的数据的类型和维度。 | ||||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | ||||
如果为False,则报错 | 如果为False,则报错 | ||||
:return self | |||||
:return: self | |||||
""" | """ | ||||
for field_name in field_names: | for field_name in field_names: | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
@@ -229,7 +229,7 @@ class DataBundle: | |||||
行的数据进行类型和维度推断本列的数据的类型和维度。 | 行的数据进行类型和维度推断本列的数据的类型和维度。 | ||||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | ||||
如果为False,则报错 | 如果为False,则报错 | ||||
:return self | |||||
:return: self | |||||
""" | """ | ||||
for field_name in field_names: | for field_name in field_names: | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
@@ -51,7 +51,7 @@ class _NERPipe(Pipe): | |||||
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | ||||
"[...]", "[...]" | "[...]", "[...]" | ||||
:param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 | |||||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 | |||||
在传入DataBundle基础上原位修改。 | 在传入DataBundle基础上原位修改。 | ||||
:return: DataBundle | :return: DataBundle | ||||
""" | """ | ||||
@@ -244,7 +244,7 @@ class _CNNERPipe(Pipe): | |||||
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | ||||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | ||||
:param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 | |||||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 | |||||
在传入DataBundle基础上原位修改。 | 在传入DataBundle基础上原位修改。 | ||||
:return: DataBundle | :return: DataBundle | ||||
""" | """ | ||||
@@ -177,7 +177,7 @@ class MatchingPipe(Pipe): | |||||
def _tokenize(self, data_bundle, field_names, new_field_names): | def _tokenize(self, data_bundle, field_names, new_field_names): | ||||
""" | """ | ||||
:param DataBundle data_bundle: DataBundle. | |||||
:param ~fastNLP.DataBundle data_bundle: DataBundle. | |||||
:param list field_names: List[str], 需要tokenize的field名称 | :param list field_names: List[str], 需要tokenize的field名称 | ||||
:param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。 | :param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。 | ||||
:return: 输入的DataBundle对象 | :return: 输入的DataBundle对象 | ||||
@@ -199,7 +199,7 @@ class MatchingPipe(Pipe): | |||||
"This site includes a...", "The Government Executive...", "not_entailment" | "This site includes a...", "The Government Executive...", "not_entailment" | ||||
"...", "..." | "...", "..." | ||||
:param data_bundle: 通过loader读取得到的data_bundle,里面包含了数据集的原始数据内容 | |||||
:param ~fastNLP.DataBundle data_bundle: 通过loader读取得到的data_bundle,里面包含了数据集的原始数据内容 | |||||
:return: data_bundle | :return: data_bundle | ||||
""" | """ | ||||
data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)], | data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)], | ||||
@@ -15,7 +15,7 @@ class Pipe: | |||||
""" | """ | ||||
对输入的DataBundle进行处理,然后返回该DataBundle。 | 对输入的DataBundle进行处理,然后返回该DataBundle。 | ||||
:param data_bundle: 需要处理的DataBundle对象 | |||||
:param ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象 | |||||
:return: | :return: | ||||
""" | """ | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@@ -92,7 +92,7 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con | |||||
""" | """ | ||||
在dataset中的field_name列建立词表,Const.TARGET列建立词表,并把词表加入到data_bundle中。 | 在dataset中的field_name列建立词表,Const.TARGET列建立词表,并把词表加入到data_bundle中。 | ||||
:param data_bundle: | |||||
:param ~fastNLP.DataBundle data_bundle: | |||||
:param: str,list input_field_names: | :param: str,list input_field_names: | ||||
:param: str,list target_field_names: 这一列的vocabulary没有unknown和padding | :param: str,list target_field_names: 这一列的vocabulary没有unknown和padding | ||||
:return: | :return: | ||||
@@ -154,7 +154,7 @@ def _drop_empty_instance(data_bundle, field_name): | |||||
""" | """ | ||||
删除data_bundle的DataSet中存在的某个field为空的情况 | 删除data_bundle的DataSet中存在的某个field为空的情况 | ||||
:param data_bundle: DataBundle | |||||
:param ~fastNLP.DataBundle data_bundle: | |||||
:param str field_name: 对哪个field进行检查,如果为None,则任意field为空都会删掉 | :param str field_name: 对哪个field进行检查,如果为None,则任意field为空都会删掉 | ||||
:return: 传入的DataBundle | :return: 传入的DataBundle | ||||
""" | """ | ||||