@@ -138,20 +138,30 @@ class Trainer(object):
        Start the training process. The main steps are::

            for epoch in range(num_epochs):
                # 1. Use Batch to draw data from the DataSet batch by batch, automatically padding the fields
                #    whose dtype is (float, int) and converting them to Tensors. Values that are neither float
                #    nor int are not converted to Tensors and are not padded.
                for batch_x, batch_y in Batch(DataSet):
                    # batch_x is a dict; every field set as input appears in it, with the field_name from the
                    # DataSet as key and that field's value as value.
                    # batch_y is also a dict; every field set as target appears in it, with the field_name from
                    # the DataSet as key and that field's value as value.
            2. Feed the data in batch_x into model.forward and collect the result. The arguments are passed by
               matching the keys in batch_x against the parameter names of forward. For example,
                   forward(self, x, seq_lens) # fastNLP finds the key "x" in batch_x and passes its value to x,
                       and the key "seq_lens" for seq_lens. If any required parameter cannot be found in batch_x,
                       an error is raised. If forward has a default parameter whose key is absent from batch_x,
                       the default value is used.
            3. Feed batch_y together with the result of model.forward into the loss to compute the loss. Computing
               a loss usually involves a pred and a target, but depending on the setting pred may be called output
               or prediction, and target may be called y or label. fastNLP locates pred and target through the
               mapping passed when the loss was initialized. For instance, if the Trainer is initialized with
               loss=CrossEntropyLoss(pred='output', target='y'), then when computing the loss fastNLP uses "output"
               to find pred in batch_y and the forward results, uses "y" to find target in the same places, and
               then computes the loss.
            4. Once the loss is obtained, back-propagate and update the parameters.
            Run validation on the dev set whenever scheduled:
                evaluate according to metrics, and decide whether to save the model depending on whether save_path
                was provided.

        :param bool load_best_model: Only takes effect when dev_data was provided at initialization. If True, the
            trainer reloads the parameters of the model that performed best on dev before returning.
        :return results: Returns a dict containing::

            seconds: float, the training time
@@ -196,8 +206,11 @@ class Trainer(object):
                results['best_step'] = self.best_dev_step
            if load_best_model:
                model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
                load_succeed = self._load_model(self.model, model_name)
                if load_succeed:
                    print("Reloaded the best model.")
                else:
                    print("Failed to reload the best model.")
        finally:
            self._summary_writer.close()
            del self._summary_writer
@@ -208,7 +221,7 @@ class Trainer(object):
    def _tqdm_train(self):
        self.step = 0
        data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                              as_numpy=False)
        total_steps = data_iterator.num_batches*self.n_epochs
        with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
            avg_loss = 0
@@ -297,7 +310,8 @@ class Trainer(object):
                if self.save_path is not None:
                    self._save_model(self.model,
                                     "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
                else:
                    self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()}
                self.best_dev_perf = res
                self.best_dev_epoch = epoch
                self.best_dev_step = step
@@ -356,7 +370,7 @@ class Trainer(object):
            torch.save(model, model_name)

    def _load_model(self, model, model_name, only_param=False):
        # returns a bool indicating whether the model was reloaded successfully
        if self.save_path is not None:
            model_path = os.path.join(self.save_path, model_name)
            if only_param:
@@ -364,6 +378,11 @@ class Trainer(object):
            else:
                states = torch.load(model_path).state_dict()
            model.load_state_dict(states)
        elif hasattr(self, "_best_model_states"):
            model.load_state_dict(self._best_model_states)
        else:
            return False
        return True
    def _better_eval_result(self, metrics):
        """Check if the current epoch yields better validation results.

@@ -469,7 +488,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
            break
    if dev_data is not None:
        tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
                        batch_size=batch_size, verbose=-1)
        evaluate_results = tester.test()
        _check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)
@@ -448,4 +448,33 @@ class BMES2OutputProcessor(Processor):
                    words.append(''.join(chars[start_idx:idx+1]))
                    start_idx = idx + 1
            return ' '.join(words)
        dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name)


class InputTargetProcessor(Processor):
    def __init__(self, input_fields, target_fields):
        """
        Operates on a DataSet, setting the fields in input_fields as input and the fields in target_fields as target.

        :param input_fields: List[str], field_names to set as input. If None, no field is set as input.
        :param target_fields: List[str], field_names to set as target. If None, no field is set as target.
        """
        super(InputTargetProcessor, self).__init__(None, None)

        if input_fields is not None and not isinstance(input_fields, list):
            raise TypeError("input_fields should be List[str], not {}.".format(type(input_fields)))
        else:
            self.input_fields = input_fields
        if target_fields is not None and not isinstance(target_fields, list):
            raise TypeError("target_fields should be List[str], not {}.".format(type(target_fields)))
        else:
            self.target_fields = target_fields

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
        if self.input_fields is not None:
            for field in self.input_fields:
                dataset.set_input(field)
        if self.target_fields is not None:
            for field in self.target_fields:
                dataset.set_target(field)
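
# A minimal usage sketch (the field names here are hypothetical): mark several fields as
# input/target in one call instead of scattering set_input/set_target over a script, e.g.
#
#     proc = InputTargetProcessor(input_fields=['chars', 'seq_lens'], target_fields=['target'])
#     proc(dataset)  # Processor.__call__ is assumed to dispatch to process(), as the training script below does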
@@ -6,7 +6,7 @@ from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegPr
from reproduction.chinese_word_segment.process.cws_processor import CWSBMESTagProcessor
from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
from reproduction.chinese_word_segment.process.cws_processor import InputTargetProcessor

from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF

@@ -39,6 +39,8 @@ bigram_vocab_proc = VocabIndexerProcessor('bigrams_lst', new_added_filed_name='b
seq_len_proc = SeqLenProcessor('chars')

input_target_proc = InputTargetProcessor(input_fields=['chars', 'bigrams', 'seq_lens', 'target'],
                                         target_fields=['target', 'seq_lens'])

# 2. apply the processors
fs2hs_proc(tr_dataset)
@@ -61,14 +63,11 @@ char_vocab_proc(dev_dataset)
bigram_vocab_proc(dev_dataset)
seq_len_proc(dev_dataset)

input_target_proc(tr_dataset)
input_target_proc(dev_dataset)

print("Finish preparing data.")

# 3. the datasets are now ready for training
# TODO: how are the pretrained embeddings handled?
@@ -86,80 +85,18 @@ cws_model = CWSBiLSTMCRF(char_vocab_proc.get_vocab_size(), embed_dim=100,
cws_model.cuda()

num_epochs = 5
optimizer = optim.Adagrad(cws_model.parameters(), lr=0.005)

from fastNLP.core.trainer import Trainer
from fastNLP.core.sampler import BucketSampler
from fastNLP.core.metrics import BMESF1PreRecMetric

metric = BMESF1PreRecMetric(target='tags')
trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=num_epochs,
                  batch_size=32, print_every=50, validate_every=-1, dev_data=dev_dataset, save_path=None,
                  optimizer=optimizer, check_code_level=0, metric_key='f', sampler=BucketSampler(), use_tqdm=True)

trainer.train()
# 4. assemble everything that needs to be saved
pp = Pipeline()
@@ -171,6 +108,7 @@ pp.add_processor(bigram_proc)
pp.add_processor(char_vocab_proc)
pp.add_processor(bigram_vocab_proc)
pp.add_processor(seq_len_proc)
pp.add_processor(input_target_proc)

# te_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
te_filename = '/home/hyan/ctb3/test.conllx'
@@ -181,6 +119,7 @@ from fastNLP.core.tester import Tester

tester = Tester(data=te_dataset, model=cws_model, metrics=metric, batch_size=64, use_cuda=False,
                verbose=1)
tester.test()

#
# batch_size = 64
# te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
@@ -193,7 +132,7 @@ tester = Tester(data=te_dataset, model=cws_model, metrics=metric, batch_size=64,
test_context_dict = {'pipeline': pp,
                     'model': cws_model}
# torch.save(test_context_dict, 'models/test_context_crf.pkl')

# 5. pp for dev
@@ -0,0 +1,353 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The following classes are involved\n",
    "\n",
    "#### DataSet\n",
    "#### Sampler\n",
    "#### Batch\n",
    "#### Model\n",
    "#### Loss\n",
    "#### Metric\n",
    "#### Trainer\n",
    "#### Tester"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### What each of them does"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### DataSet: holds the data.\n",
    "(1) Every element in a DataSet must be one of np.float64, np.int64, np.str. An int passed in is converted to np.int64, a float to np.float64. \n",
    "(2) Fields of a DataSet can be set as input or target. Fields set as input are passed to Model.forward, and the passing is done by key matching. For example, suppose the fields 'x1', 'x2', 'x3' in a DataSet are set as input, and \n",
    "   (2.1) the function is Model.forward(self, x1, x3): then 'x1' and 'x3' from the DataSet are passed to forward, and the unused 'x2' is ignored. \n",
    "   (2.2) the function is Model.forward(self, x1, x4): an extra 'x4' is required, but there is no such field among the DataSet's input fields, so an error is raised. \n",
    "   (2.3) the function is Model.forward(self, x1, **kwargs): then 'x1', 'x2', 'x3' are all passed in. But Model.forward(self, x4, **kwargs) raises an error, because there is no 'x4'. \n",
    "(3) For a field set as target we recommend the name 'target' (when there is only one value to predict), but it is not enforced; why it need not be enforced is explained later. \n",
    "DataSet should not need any further development on your side; if it cannot cover your scenario, please raise it in the dev group or open an issue on github. \n",
    "A small construction sketch follows in the next cell."
   ]
  },
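  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "# A minimal construction sketch (field names and values are made up; import paths are assumptions).\n",
    "from fastNLP.core.dataset import DataSet\n",
    "from fastNLP.core.instance import Instance\n",
    "\n",
    "dataset = DataSet()\n",
    "dataset.append(Instance(x1=[1, 2, 3], x2=[4, 5], x3=[6, 7], target=0))\n",
    "dataset.append(Instance(x1=[8], x2=[9, 10], x3=[11], target=1))\n",
    "dataset.set_input('x1', 'x2', 'x3')  # these fields will appear in batch_x\n",
    "dataset.set_target('target')         # this field will appear in batch_y"
   ]
  },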
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Sampler: given a DataSet, returns a list of indices; Batch emits the data in that order.\n",
    "A Sampler must subclass fastNLP.core.sampler.BaseSampler"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "class BaseSampler(object):\n",
    "    \"\"\"The base class of all samplers.\n",
    "\n",
    "    Sub-classes must implement the __call__ method.\n",
    "    __call__ takes a DataSet object and returns a list of int - the sampling indices.\n",
    "    \"\"\"\n",
    "    def __call__(self, *args, **kwargs):\n",
    "        raise NotImplementedError\n",
    "\n",
    "# Subclasses must override __call__. It may take only one required argument, which must be a\n",
    "# DataSet, otherwise the Trainer cannot invoke it.\n",
    "class SonSampler(BaseSampler):\n",
    "    def __init__(self, xxx):\n",
    "        pass  # implementing __init__ is optional\n",
    "    def __call__(self, data_set):\n",
    "        pass"
   ]
  },
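  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "# A concrete toy subclass (not part of fastNLP): sample the DataSet back to front,\n",
    "# returning the index list described by the contract above.\n",
    "class ReverseSampler(BaseSampler):\n",
    "    def __call__(self, data_set):\n",
    "        return list(range(len(data_set) - 1, -1, -1))"
   ]
  },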
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Batch: pulls the fields set as input and target out of the DataSet to build batch_x and batch_y,\n",
    "and, depending on the situation (mainly whether the data type can be converted to a Tensor), converts the data to pytorch Tensors. The order in which samples are drawn is decided by the Sampler. \n",
    "A Sampler takes a DataSet and returns an index list as long as the DataSet; Batch then draws batch_size samples at a time (the last batch may hold fewer than batch_size). \n",
    "Examples: \n",
    "(1) SequentialSampler samples in order.\n",
    "  Suppose the DataSet has length 100: SequentialSampler returns the index list [0, 1, ..., 98, 99]. With batch_size set to 4, the first batch holds instances [0, 1, 2, 3], the second batch holds [4, 5, 6, 7], and so on until every sample has been drawn. \n",
    "(2) RandomSampler samples randomly. \n",
    "  Suppose the DataSet has length 100: RandomSampler may return the index list [0, 99, 20, 5, 3, 1, ...], and samples are drawn from it batch_size at a time. \n",
    "Batch should not need subclassing or further development; if you have special needs please raise them in the dev group. \n",
    "A minimal iteration sketch follows in the next cell."
   ]
  },
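  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "# A minimal iteration sketch (the field names 'x' and 'y' are made up; it assumes 'x' was set\n",
    "# as input and 'y' as target on the dataset, and that the import paths below hold).\n",
    "from fastNLP.core.batch import Batch\n",
    "from fastNLP.core.sampler import SequentialSampler\n",
    "\n",
    "data_iterator = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=False)\n",
    "for batch_x, batch_y in data_iterator:\n",
    "    print(batch_x['x'])  # padded tensor of the input field\n",
    "    print(batch_y['y'])  # tensor of the target field"
   ]
  },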
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Model: the user-defined model\n",
    "It must be a subclass of nn.Module, and \n",
    "(1) it must implement forward, and forward must not take *args-style parameters. For example \n",
    "   def forward(self, word_seq, *args): # this is not allowed \n",
    "      xxx \n",
    "The return value must be a dict \n",
    "   def forward(self, word_seq, seq_lens): \n",
    "      xxxx \n",
    "   return {'pred': xxx} # the returned value must be a dict. The key 'pred' is recommended for the prediction, but not enforced; any number of entries may be returned. \n",
    "(2) If a predict method is implemented, it is called during evaluation instead of forward; without predict, evaluation calls forward. predict must not take *args-style parameters either, and must also return a dict, again with 'pred' as the recommended key. \n",
    "A minimal model sketch follows in the next cell."
   ]
  },
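  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "# A minimal model sketch following the rules above (all sizes and names are made up).\n",
    "import torch.nn as nn\n",
    "\n",
    "class ToyModel(nn.Module):\n",
    "    def __init__(self, vocab_size=100, embed_dim=16, num_classes=2):\n",
    "        super(ToyModel, self).__init__()\n",
    "        self.embed = nn.Embedding(vocab_size, embed_dim)\n",
    "        self.fc = nn.Linear(embed_dim, num_classes)\n",
    "\n",
    "    def forward(self, word_seq, seq_lens):\n",
    "        # word_seq and seq_lens are filled from batch_x by key matching\n",
    "        x = self.embed(word_seq).sum(dim=1) / seq_lens.float().unsqueeze(1)\n",
    "        return {'pred': self.fc(x)}  # must be a dict; 'pred' is the recommended key"
   ]
  },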
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loss: computes the loss from the prediction returned by model.forward() (a dict) and batch_y. \n",
    "(1) First, the \"key mapping\". As seen in the DataSet and Model sections, fastNLP restricts neither the return value of Model.forward() nor the keys of the target fields in the DataSet. So when the loss is computed, how does it know where to take the values from? \n",
    "Take CrossEntropyLoss as an example. Computing a cross entropy generally needs two values, prediction and target, and CrossEntropyLoss can be initialized with two parameters (pred=None, target=None) of type str. Given (pred='output', target='label'), CrossEntropyLoss uses the key 'output' to look for a value in forward's output and in batch_y, and likewise uses 'label' in forward's output and batch_y. Note that pred and target do not have to come one from model.forward and one from batch_y; both may come from forward's output alone. \n",
    "(2) How to create your own loss \n",
    "   (2.1) Use fastNLP.LossInForward and include a key named loss in the result of model.forward(). \n",
    "   (2.2) When the trainer uses a loss (say loss=CrossEntropyLoss()), what actually runs is \n",
    "    loss_value = loss(prediction, batch_y)\n",
    " i.e. it calls loss.\\__call__() directly. CrossEntropyLoss does not implement \\__call__ itself, because \\__call__ is implemented in LossBase. Every loss must subclass fastNLP.core.loss.LossBase; LossBase's methods are described in the next cell. \n",
    "(3) Avoid overriding the \\__call__() and _init_param_map() methods."
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "class LossBase():\n",
    "    def __init__(self):\n",
    "        self.param_map = {}  # usually no need to build this yourself; calling _init_param_map() is better\n",
    "        self._checked = False  # this attribute can be ignored\n",
    "\n",
    "    def _init_param_map(self, key_map=None, **kwargs):\n",
    "        # Registers the loss's \"key mapping\". Values can be passed in two ways:\n",
    "        # (a) a dict via key_map; the values are the keys used to look things up in forward's output and batch_y\n",
    "        #     key_map = {'pred': 'output', 'target': 'label'}\n",
    "        # (b) keyword arguments\n",
    "        #     _init_param_map(pred='output', target='label')\n",
    "        # Why provide this method? Calling it registers param_map automatically and runs some checks, e.g. that a\n",
    "        # passed-in key really is a parameter of get_loss. Only pass arguments that take part in the key mapping;\n",
    "        # do not pass other loss parameters. If called with (pred=None, target=None), __call__() will search\n",
    "        # pred_dict and target_dict for the keys 'pred' and 'target'.\n",
    "        # Calling this method is optional.\n",
    "        pass\n",
    "\n",
    "    def __call__(self, pred_dict, target_dict, check=False):  # ignore check=False; it will probably be removed\n",
    "        # Mainly performs checks, e.g. whether pred_dict and target_dict contain the keys required to compute\n",
    "        # the loss. If the checks pass, get_loss is called.\n",
    "        fast_param = self._fast_param_map(pred_dict, target_dict)\n",
    "        if fast_param:\n",
    "            return self.get_loss(**fast_param)\n",
    "        # without fast_param, the parameters are matched via the key mapping and get_loss is called\n",
    "        xxxx\n",
    "        return loss  # returns the loss as a Tensor\n",
    "\n",
    "    def _fast_param_map(self, pred_dict, target_dict):\n",
    "        # A fast path for computing the loss: in many cases the \"key mapping\" is not needed at all. For example,\n",
    "        # if pred_dict has exactly one element and target_dict has exactly one element, the prediction and the\n",
    "        # ground truth can be paired unambiguously; the base class detects this (and possibly other unambiguous\n",
    "        # cases). When _fast_param_map succeeds, the key mapping is skipped, so the loss can be computed even\n",
    "        # when the \"key mapping\" was omitted or passed incorrectly.\n",
    "        # Returns a dict: on success something like {'pred': value, 'target': value}; an empty dict means the\n",
    "        # match failed and __call__ continues.\n",
    "        pass\n",
    "\n",
    "    def get_loss(self, *args, **kwargs):\n",
    "        # This must be implemented; it is where the loss is computed.\n",
    "        # (1) get_loss must never use *args-style parameters.\n",
    "        # (2) With **kwargs, every entry of pred_dict and target_dict is passed in; this is discouraged.\n",
    "        raise NotImplementedError\n",
    "\n",
    "# L1Loss as an example\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class L1Loss(LossBase):  # subclass LossBase\n",
    "    # register the values to map; the names 'pred' and 'target' must match the parameter names of get_loss\n",
    "    def __init__(self, pred=None, target=None):\n",
    "        super(L1Loss, self).__init__()\n",
    "        # Passing pred and target to _init_param_map registers them correctly. This step is optional but\n",
    "        # recommended. Only pass the key/value pairs used for the \"key mapping\": if __init__(pred=None,\n",
    "        # target=None, threshold=0.1) had a threshold controlling the loss computation, threshold must not\n",
    "        # be passed to _init_param_map.\n",
    "        self._init_param_map(pred=pred, target=target)\n",
    "\n",
    "    def get_loss(self, pred, target):\n",
    "        # 'pred' and 'target' here must agree with the mapping registered at initialization\n",
    "        return F.l1_loss(input=pred, target=target)  # simply return a loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Metric: computes a metric from the result of Model.forward() or Model.predict() \n",
    "Metric is designed much like loss: both take pred_dict and target_dict. The difference is that a metric's pred_dict may come from Model.forward or from Model.predict (predict is called whenever the model has one); below, pred_dict refers to either. \n",
    "(1) The \"key mapping\" works like the loss's. For example, with Metric(pred='output', target='label'), the key 'output' is used to look up pred in pred_dict and target_dict, and 'label' to look up target. \n",
    "(2) How to create your own Metric \n",
    "  A metric differs from a loss in that its computation has two steps. \n",
    "  (2.1) <b>For every batch's output</b>, the Metric's \\__call__(pred_dict, target_dict) is invoked, and \\__call__ calls the evaluate() method (which you must implement). \n",
    "  (2.2) After all batches have been fed in, the Metric's get_metric() method is called to obtain the final metric value. \n",
    "  So in evaluate the Metric updates its own state from the data it receives (pred_dict and batch_y), e.g. accumulating the number of correct predictions and the total sample count, and produces the final result when get_metric() is called. \n",
    "Every Metric must subclass fastNLP.core.metrics.MetricBase; see the next cell for an example. \n",
    "(3) Avoid overriding the \\__call__() and _init_param_map() methods.\n"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "class MetricBase():\n",
    "    def __init__(self):\n",
    "        self.param_map = {}  # usually no need to build this yourself; calling _init_param_map() is better\n",
    "        self._checked = False  # this attribute can be ignored\n",
    "\n",
    "    def _init_param_map(self, key_map=None, **kwargs):\n",
    "        # Registers the Metric's \"key mapping\". Values can be passed in two ways:\n",
    "        # (a) a dict via key_map; the values are the keys used to look things up in forward's output and batch_y\n",
    "        #     key_map = {'pred': 'output', 'target': 'label'}\n",
    "        # (b) keyword arguments (the recommended style)\n",
    "        #     _init_param_map(pred='output', target='label')\n",
    "        # Why provide this method? Calling it registers param_map automatically and runs some checks, e.g. that a\n",
    "        # passed-in key really is a parameter of evaluate(). Only pass arguments that take part in the key mapping;\n",
    "        # do not pass other evaluate parameters. If called with (pred=None, target=None), __call__() will search\n",
    "        # pred_dict and target_dict for the keys 'pred' and 'target'.\n",
    "        # Calling this method is optional.\n",
    "        pass\n",
    "\n",
    "    def __call__(self, pred_dict, target_dict, check=False):  # ignore check=False; it will probably be removed\n",
    "        # Mainly performs checks, e.g. whether pred_dict and target_dict contain the keys that evaluate requires.\n",
    "        # If the checks pass, evaluate is called.\n",
    "        fast_param = self._fast_param_map(pred_dict, target_dict)\n",
    "        if fast_param:\n",
    "            return self.evaluate(**fast_param)\n",
    "        # without fast_param, the parameters are matched via the key mapping and evaluate is called\n",
    "        xxxx\n",
    "\n",
    "    def _fast_param_map(self, pred_dict, target_dict):\n",
    "        # A fast path: in many cases the \"key mapping\" is not needed at all. For example, if pred_dict has exactly\n",
    "        # one element and target_dict has exactly one element, the prediction and the ground truth can be paired\n",
    "        # unambiguously to compute the metric; the base class detects this (and possibly other unambiguous cases).\n",
    "        # When _fast_param_map succeeds, the key mapping is skipped, so the metric can be computed even when the\n",
    "        # \"key mapping\" was omitted or passed incorrectly.\n",
    "        # Returns a dict: on success something like {'pred': value, 'target': value}; an empty dict means the\n",
    "        # match failed and __call__ keeps trying to match.\n",
    "        pass\n",
    "\n",
    "    def evaluate(self, *args, **kwargs):\n",
    "        # This must be implemented; it accumulates the metric state.\n",
    "        # (1) evaluate() must never use *args-style parameters.\n",
    "        # (2) With **kwargs, every entry of pred_dict and target_dict is passed in; this is discouraged.\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def get_metric(self, reset=True):\n",
    "        # This must be implemented; it produces the final metric, which must be a dict. It is called after\n",
    "        # all batches have been fed in.\n",
    "        raise NotImplementedError\n",
    "\n",
    "# AccuracyMetric as an example\n",
    "import torch\n",
    "\n",
    "class AccuracyMetric(MetricBase):  # subclass MetricBase\n",
    "    # register the values to map; the names 'pred' and 'target' must match the parameter names of evaluate()\n",
    "    def __init__(self, pred=None, target=None):\n",
    "        super(AccuracyMetric, self).__init__()\n",
    "        # Passing pred and target to _init_param_map registers them correctly. This step is optional but\n",
    "        # recommended. Only pass the key/value pairs used for the \"key mapping\": if __init__(pred=None,\n",
    "        # target=None, threshold=0.1) had a threshold controlling the computation, threshold must not be\n",
    "        # passed to _init_param_map.\n",
    "        self._init_param_map(pred=pred, target=target)\n",
    "\n",
    "        self.total = 0  # accumulates the total number of samples\n",
    "        self.corr = 0   # accumulates the number of correct samples\n",
    "\n",
    "    def evaluate(self, pred, target):\n",
    "        # basic checks / preprocessing of pred and target\n",
    "        if pred.size() == target.size() and len(pred.size()) == 1:  # pred has already been argmax-ed\n",
    "            pass\n",
    "        elif len(pred.size()) == 2 and len(target.size()) == 1:  # pred has not been argmax-ed yet\n",
    "            pred = pred.argmax(dim=1)\n",
    "        else:\n",
    "            raise ValueError(\"The shape of pred and target should be ((B, n_classes), (B,)) or (\"\n",
    "                             \"(B,), (B,)).\")\n",
    "        assert pred.size(0) == target.size(0), \"Mismatch batch size.\"\n",
    "        # accumulate\n",
    "        self.total += pred.size(0)\n",
    "        self.corr += torch.sum(torch.eq(pred, target).float()).item()\n",
    "\n",
    "    def get_metric(self, reset=True):\n",
    "        # reset says whether to clear the accumulated state; defaults to True\n",
    "        # must return a dict, which may contain several metrics\n",
    "        metric = {}\n",
    "        metric['acc'] = self.corr / self.total\n",
    "        if reset:\n",
    "            self.total = 0\n",
    "            self.corr = 0\n",
    "        return metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Tester: runs evaluation; should not need modification\n",
    "The important initialization parameters are data, model, and metric. \n",
    "The most important function is test(). \n",
    "What happens inside test(): \n",
    "  predict_func = model.predict if the model has a predict method, else model.forward \n",
    "  for batch_x, batch_y in batch: \n",
    "    # (1) move the data to the model's device \n",
    "    # (2) take the values that predict_func's parameters require out of batch_x and pass them to predict_func, obtaining pred_dict \n",
    "    # (3) call metric(pred_dict, batch_y) \n",
    "    # (4) once every batch has run, call the metric's get_metric method and use its return value as the evaluation result \n",
    "  metric.get_metric()\n",
    "\n",
    "A usage sketch follows in the next cell."
   ]
  },
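  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "# A minimal usage sketch: dev_data, model, and the metric key names are assumed to exist as shown,\n",
    "# and the import paths are assumptions based on the module names used in this tutorial.\n",
    "from fastNLP.core.tester import Tester\n",
    "from fastNLP.core.metrics import AccuracyMetric\n",
    "\n",
    "tester = Tester(data=dev_data, model=model,\n",
    "                metrics=AccuracyMetric(pred='pred', target='target'),\n",
    "                batch_size=32, verbose=1)\n",
    "eval_results = tester.test()  # a dict produced by the metric's get_metric()\n",
    "print(eval_results)"
   ]
  },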
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Trainer: wraps the training process. \n",
    "The most important function is train(). \n",
    "What happens inside train(): \n",
    "  # (1) create the batches \n",
    "  batch = Batch(dataset, batch_size, sampler=sampler) \n",
    "  for batch_x, batch_y in batch: \n",
    "    \"\"\" \n",
    "    batch_x and batch_y are both dicts. batch_x holds the DataSet fields set as input; batch_y holds the fields set as target. \n",
    "    The keys of both dicts are the keys in the DataSet, and the values are tensors, padded where appropriate. \n",
    "    \"\"\" \n",
    "    # (2) move the tensors in batch_x and batch_y to the model's device \n",
    "    # (3) following model.forward's parameter list, take the data forward needs out of batch_x \n",
    "    # (4) obtain model.forward's output pred_dict and pass it together with batch_y to the loss function to compute the loss \n",
    "    # (5) back-propagate the loss and update the parameters \n",
    "  # (6) if there is a validation set, run validation \n",
    "  tester = Tester(model, dev_data, metric) \n",
    "  eval_results = tester.test() \n",
    "  # (7) if eval_results is the best so far, save the model. \n",
    "\n",
    "A usage sketch follows in the next cell."
   ]
  },
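  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "# A minimal usage sketch: train_data, dev_data, and model are assumed to exist, and the import\n",
    "# paths for the loss/metric classes are assumptions based on the module names used in this tutorial.\n",
    "from fastNLP.core.trainer import Trainer\n",
    "from fastNLP.core.losses import CrossEntropyLoss\n",
    "from fastNLP.core.metrics import AccuracyMetric\n",
    "\n",
    "trainer = Trainer(train_data=train_data, model=model,\n",
    "                  loss=CrossEntropyLoss(pred='pred', target='target'),\n",
    "                  metrics=AccuracyMetric(pred='pred', target='target'),\n",
    "                  n_epochs=5, batch_size=32, dev_data=dev_data,\n",
    "                  save_path=None, check_code_level=0, use_tqdm=True)\n",
    "trainer.train()  # returns a dict with seconds/best_eval and, given dev_data, best_epoch/best_step"
   ]
  },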
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "Beyond the above,\n",
    "the Trainer also provides a \"dry run\" feature, controlled by check_code_level; if check_code_level is -1, no dry run is performed.\n",
    "check_code_level=0,1,2 are different reporting levels. Currently the levels control how fields that are set as input or target in the DataSet but never used are reported:\n",
    "0 ignores them (the default); 1 warns when unused fields occur; 2 raises an error and aborts as soon as an unused field appears.\n",
    "The dry run has two main goals: (1) to catch errors that would otherwise only show up during evaluation after training, wasting the whole run;\n",
    "                (2) because of the \"key mapping\", errors from a direct run can be hard to debug, while errors raised during the dry run come with debugging hints.\n",
    "The dry run performs the following: (1) with a very small batch_size, it checks whether batch_x contains the parameters Model.forward needs; only two iterations are run.\n",
    "                (2) it feeds Model.forward's output pred_dict and batch_y into the loss and attempts a backward pass; parameters are not updated and gradients are zeroed afterwards.\n",
    "                    If dev_data was passed, the metric is exercised as well:\n",
    "                (3) it creates a Tester and runs it on a small amount of data to check that everything works.\n",
    "The dry run happens when the Trainer is initialized.\n",
    "Normally the dry run code should not need changes, but if you hit a bug or have a suggestion, you are welcome to raise it in the dev group or open a github issue."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}