From 56e7641eb8bf7e637cadf4edc5a3a8066377deff Mon Sep 17 00:00:00 2001 From: yh Date: Tue, 8 Jan 2019 21:49:31 +0800 Subject: [PATCH] =?UTF-8?q?1.=20=E4=BF=AE=E5=A4=8DTrainer=20check=5Fcode?= =?UTF-8?q?=E4=B8=AD=E6=A3=80=E6=9F=A5evaluate=E6=97=B6=E4=BD=BF=E7=94=A8t?= =?UTF-8?q?rain=5Fdata=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/trainer.py | 51 ++++++++---- .../process/cws_processor.py | 31 ++++++- .../chinese_word_segment/train_context.py | 81 +++---------------- 3 files changed, 75 insertions(+), 88 deletions(-) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 370026c7..c1bb4ec9 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -138,20 +138,30 @@ class Trainer(object): 开始训练过程。主要有以下几个步骤:: - 对于每次循环 - 1. 使用Batch从DataSet中按批取出数据,并自动对DataSet中dtype为float, int的fields进行padding。并转换为Tensor。 + for epoch in range(num_epochs): + # 使用Batch从DataSet中按批取出数据,并自动对DataSet中dtype为(float, int)的fields进行padding。并转换为Tensor。 非float,int类型的参数将不会被转换为Tensor,且不进行padding。 for batch_x, batch_y in Batch(DataSet) - # batch_x中为设置为input的field - # batch_y中为设置为target的field - 2. 将batch_x的数据送入到model.forward函数中,并获取结果 - 3. 将batch_y与model.forward的结果一并送入loss中计算loss + # batch_x是一个dict, 被设为input的field会出现在这个dict中, + key为DataSet中的field_name, value为该field的value + # batch_y也是一个dict,被设为target的field会出现在这个dict中, + key为DataSet中的field_name, value为该field的value + 2. 将batch_x的数据送入到model.forward函数中,并获取结果。这里我们就是通过匹配batch_x中的key与forward函数的形 + 参完成参数传递。例如, + forward(self, x, seq_lens) # fastNLP会在batch_x中找到key为"x"的value传递给x,key为"seq_lens"的 + value传递给seq_lens。若在batch_x中没有找到所有必须要传递的参数,就会报错。如果forward存在默认参数 + 而且默认参数这个key没有在batch_x中,则使用默认参数。 + 3. 将batch_y与model.forward的结果一并送入loss中计算loss。loss计算时一般都涉及到pred与target。但是在不同情况 + 中,可能pred称为output或prediction, target称为y或label。fastNLP通过初始化loss时传入的映射找到pred或 + target。比如在初始化Trainer时初始化loss为CrossEntropyLoss(pred='output', target='y'), 那么fastNLP计 + 算loss时,就会使用"output"在batch_y与forward的结果中找到pred;使用"y"在batch_y与forward的结果中找target + , 并完成loss的计算。 4. 获取到loss之后,进行反向求导并更新梯度 - 如果测试集不为空 - 根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型 + 根据需要适时进行验证机测试 + 根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型 - :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现最好的 - 模型参数。 + :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 + 最好的模型参数。 :return results: 返回一个字典类型的数据, 内含以下内容:: seconds: float, 表示训练时长 @@ -196,8 +206,11 @@ class Trainer(object): results['best_step'] = self.best_dev_step if load_best_model: model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) - # self._load_model(self.model, model_name) - print("Reloaded the best model.") + load_succeed = self._load_model(self.model, model_name) + if load_succeed: + print("Reloaded the best model.") + else: + print("Fail to reload best model.") finally: self._summary_writer.close() del self._summary_writer @@ -208,7 +221,7 @@ class Trainer(object): def _tqdm_train(self): self.step = 0 data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, - as_numpy=False) + as_numpy=False) total_steps = data_iterator.num_batches*self.n_epochs with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: avg_loss = 0 @@ -297,7 +310,8 @@ class Trainer(object): if self.save_path is not None: self._save_model(self.model, "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) - + else: + self._best_model_states = {name:param.cpu().clone() for name, param in self.model.named_parameters()} self.best_dev_perf = res self.best_dev_epoch = epoch self.best_dev_step = step @@ -356,7 +370,7 @@ class Trainer(object): torch.save(model, model_name) def _load_model(self, model, model_name, only_param=False): - # TODO: 这个是不是有问题? + # 返回bool值指示是否成功reload模型 if self.save_path is not None: model_path = os.path.join(self.save_path, model_name) if only_param: @@ -364,6 +378,11 @@ class Trainer(object): else: states = torch.load(model_path).state_dict() model.load_state_dict(states) + elif hasattr(self, "_best_model_states"): + model.load_state_dict(self._best_model_states) + else: + return False + return True def _better_eval_result(self, metrics): """Check if the current epoch yields better validation results. @@ -469,7 +488,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ break if dev_data is not None: - tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, + tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, batch_size=batch_size, verbose=-1) evaluate_results = tester.test() _check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics) diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/chinese_word_segment/process/cws_processor.py index 3f7b6176..fa9d7b2c 100644 --- a/reproduction/chinese_word_segment/process/cws_processor.py +++ b/reproduction/chinese_word_segment/process/cws_processor.py @@ -448,4 +448,33 @@ class BMES2OutputProcessor(Processor): words.append(''.join(chars[start_idx:idx+1])) start_idx = idx + 1 return ' '.join(words) - dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) \ No newline at end of file + dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) + + +class InputTargetProcessor(Processor): + def __init__(self, input_fields, target_fields): + """ + 对DataSet操作,将input_fields中的field设置为input,target_fields的中field设置为target + + :param input_fields: List[str], 设置为input_field的field_name。如果为None,则不将任何field设置为target。 + :param target_fields: List[str], 设置为target_field的field_name。 如果为None,则不将任何field设置为target。 + """ + super(InputTargetProcessor, self).__init__(None, None) + + if input_fields is not None and not isinstance(input_fields, list): + raise TypeError("input_fields should be List[str], not {}.".format(type(input_fields))) + else: + self.input_fields = input_fields + if target_fields is not None and not isinstance(target_fields, list): + raise TypeError("target_fiels should be List[str], not{}.".format(type(target_fields))) + else: + self.target_fields = target_fields + + def process(self, dataset): + assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) + if self.input_fields is not None: + for field in self.input_fields: + dataset.set_input(field) + if self.target_fields is not None: + for field in self.target_fields: + dataset.set_target(field) \ No newline at end of file diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py index 93e3de50..83243863 100644 --- a/reproduction/chinese_word_segment/train_context.py +++ b/reproduction/chinese_word_segment/train_context.py @@ -6,7 +6,7 @@ from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegPr from reproduction.chinese_word_segment.process.cws_processor import CWSBMESTagProcessor from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor - +from reproduction.chinese_word_segment.process.cws_processor import InputTargetProcessor from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF @@ -39,6 +39,8 @@ bigram_vocab_proc = VocabIndexerProcessor('bigrams_lst', new_added_filed_name='b seq_len_proc = SeqLenProcessor('chars') +input_target_proc = InputTargetProcessor(input_fields=['chars', 'bigrams', 'seq_lens', "target"], + target_fields=['target', 'seq_lens']) # 2. 使用processor fs2hs_proc(tr_dataset) @@ -61,14 +63,11 @@ char_vocab_proc(dev_dataset) bigram_vocab_proc(dev_dataset) seq_len_proc(dev_dataset) -dev_dataset.set_input('chars', 'bigrams', 'target') -tr_dataset.set_input('chars', 'bigrams', 'target') -dev_dataset.set_target('seq_lens') -tr_dataset.set_target('seq_lens') +input_target_proc(tr_dataset) +input_target_proc(dev_dataset) print("Finish preparing data.") - # 3. 得到数据集可以用于训练了 # TODO pretrain的embedding是怎么解决的? @@ -86,80 +85,18 @@ cws_model = CWSBiLSTMCRF(char_vocab_proc.get_vocab_size(), embed_dim=100, cws_model.cuda() num_epochs = 5 -optimizer = optim.Adagrad(cws_model.parameters(), lr=0.02) +optimizer = optim.Adagrad(cws_model.parameters(), lr=0.005) from fastNLP.core.trainer import Trainer from fastNLP.core.sampler import BucketSampler from fastNLP.core.metrics import BMESF1PreRecMetric metric = BMESF1PreRecMetric(target='tags') -trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=3, +trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=num_epochs, batch_size=32, print_every=50, validate_every=-1, dev_data=dev_dataset, save_path=None, optimizer=optimizer, check_code_level=0, metric_key='f', sampler=BucketSampler(), use_tqdm=True) trainer.train() -exit(0) - -# -# print_every = 50 -# batch_size = 32 -# tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False) -# dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False) -# num_batch_per_epoch = len(tr_dataset) // batch_size -# best_f1 = 0 -# best_epoch = 0 -# for num_epoch in range(num_epochs): -# print('X' * 10 + ' Epoch: {}/{} '.format(num_epoch + 1, num_epochs) + 'X' * 10) -# sys.stdout.flush() -# avg_loss = 0 -# with tqdm(total=num_batch_per_epoch, leave=True) as pbar: -# pbar.set_description_str('Epoch:%d' % (num_epoch + 1)) -# cws_model.train() -# for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1): -# optimizer.zero_grad() -# -# tags = batch_y['tags'].long() -# pred_dict = cws_model(**batch_x, tags=tags) # B x L x tag_size -# -# seq_lens = pred_dict['seq_lens'] -# masks = seq_lens_to_mask(seq_lens).float() -# tags = tags.to(seq_lens.device) -# -# loss = pred_dict['loss'] -# -# # loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size), -# # tags.view(-1)) * masks.view(-1)) / torch.sum(masks) -# # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float()) -# -# avg_loss += loss.item() -# -# loss.backward() -# for group in optimizer.param_groups: -# for param in group['params']: -# param.grad.clamp_(-5, 5) -# -# optimizer.step() -# -# if batch_idx % print_every == 0: -# pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every)) -# avg_loss = 0 -# pbar.update(print_every) -# tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False) -# # 验证集 -# pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher, type='bmes') -# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100, -# pre*100, -# rec*100)) -# if best_f1