diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 370026c7..c1bb4ec9 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -138,20 +138,30 @@ class Trainer(object):
         Start the training process. It mainly involves the following steps::

-            for each iteration
-                1. Use Batch to draw data from the DataSet batch by batch; fields whose dtype is float or int are padded automatically and converted to Tensor.
+            for epoch in range(num_epochs):
+                # Use Batch to draw data from the DataSet batch by batch; fields whose dtype is (float, int) are padded automatically and converted to Tensor.
                 Parameters whose dtype is neither float nor int are neither converted to Tensor nor padded.
                 for batch_x, batch_y in Batch(DataSet)
-                    # batch_x holds the fields set as input
-                    # batch_y holds the fields set as target
-                2. Feed the data in batch_x to model.forward and collect the results
-                3. Feed batch_y together with the results of model.forward into the loss to compute the loss
+                    # batch_x is a dict; every field set as input appears in it,
+                        with the DataSet field_name as key and that field's value as value
+                    # batch_y is also a dict; every field set as target appears in it,
+                        with the DataSet field_name as key and that field's value as value
+                2. Feed the data in batch_x to model.forward and collect the results. Arguments are passed by
+                    matching the keys of batch_x against the parameter names of forward. For example,
+                        forward(self, x, seq_lens)  # fastNLP finds the value under the key "x" in batch_x and
+                            passes it to x, and the value under "seq_lens" to seq_lens. If a required parameter
+                            cannot be found in batch_x, an error is raised. If forward declares a default for a
+                            parameter and its key is absent from batch_x, the default is used.
+                3. Feed batch_y together with the results of model.forward into the loss to compute the loss.
+                    Loss computation normally involves a pred and a target, but depending on the model the pred
+                    may be called output or prediction, and the target may be called y or label. fastNLP locates
+                    pred and target via the mapping passed when the loss was initialized. For instance, if the
+                    Trainer was initialized with loss=CrossEntropyLoss(pred='output', target='y'), then when
+                    computing the loss fastNLP uses "output" to find pred in batch_y and the forward results,
+                    uses "y" to find target in batch_y and the forward results, and computes the loss from them.
                 4. After obtaining the loss, backpropagate and update the parameters
-            if the test set is not empty
-                evaluate according to metrics, and decide whether to save the model based on whether save_path was provided
+            run validation on the dev set when scheduled
+                evaluate according to metrics, and decide whether to save the model based on whether save_path was provided

-        :param bool load_best_model: Only takes effect when dev_data was provided at initialization. If True, the trainer reloads the parameters of the model that performed best
-            on dev before returning.
+        :param bool load_best_model: Only takes effect when dev_data was provided at initialization. If True,
+            the trainer reloads the parameters of the model that performed best on dev before returning.
         :return results: Returns a dict containing the following entries::

             seconds: float, the training duration in seconds
@@ -196,8 +206,11 @@ class Trainer(object):
                     results['best_step'] = self.best_dev_step
                 if load_best_model:
                     model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
-                    # self._load_model(self.model, model_name)
-                    print("Reloaded the best model.")
+                    load_succeed = self._load_model(self.model, model_name)
+                    if load_succeed:
+                        print("Reloaded the best model.")
+                    else:
+                        print("Failed to reload the best model.")
         finally:
             self._summary_writer.close()
             del self._summary_writer
@@ -208,7 +221,7 @@ class Trainer(object):
     def _tqdm_train(self):
         self.step = 0
         data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
-                                        as_numpy=False)
+                              as_numpy=False)
         total_steps = data_iterator.num_batches*self.n_epochs
         with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
@@ -297,7 +310,8 @@ class Trainer(object):
             if self.save_path is not None:
                 self._save_model(self.model,
                         "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
-
+            else:
+                self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()}
             self.best_dev_perf = res
             self.best_dev_epoch = epoch
             self.best_dev_step = step
@@ -356,7 +370,7 @@ class Trainer(object):
             torch.save(model, model_name)

     def _load_model(self, model, model_name, only_param=False):
-        # TODO: is there a problem with this?
+        # Returns a bool indicating whether the model was reloaded successfully
         if self.save_path is not None:
             model_path = os.path.join(self.save_path, model_name)
             if only_param:
@@ -364,6 +378,11 @@ class Trainer(object):
             else:
                 states = torch.load(model_path).state_dict()
             model.load_state_dict(states)
+        elif hasattr(self, "_best_model_states"):
+            model.load_state_dict(self._best_model_states)
+        else:
+            return False
+        return True

     def _better_eval_result(self, metrics):
         """Check if the current epoch yields better validation results.
@@ -469,7 +488,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
             break

     if dev_data is not None:
-        tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
+        tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
                         batch_size=batch_size, verbose=-1)
         evaluate_results = tester.test()
         _check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)
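The key-matching rule described in the new train() docstring (look up each forward parameter by name in batch_x, fall back to declared defaults, and raise when a required argument is missing) can be pictured with a small standalone helper. This is a hypothetical sketch built on Python's inspect module, not fastNLP's actual implementation:

```python
import inspect

def call_forward_by_name(forward, batch_x):
    """Hypothetical helper: bind entries of batch_x to forward's parameters by name."""
    sig = inspect.signature(forward)
    kwargs = {}
    for name, param in sig.parameters.items():
        if name in batch_x:
            kwargs[name] = batch_x[name]                # key present in batch_x: pass it through
        elif param.default is not inspect.Parameter.empty:
            continue                                    # absent but defaulted: let the default apply
        else:
            raise TypeError("missing required forward argument: {!r}".format(name))
    return forward(**kwargs)

# For forward(self, x, seq_lens, dropout=0.3), batch_x must supply "x" and "seq_lens",
# while "dropout" silently falls back to 0.3 whenever the key is absent.
```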
diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/chinese_word_segment/process/cws_processor.py
index 3f7b6176..fa9d7b2c 100644
--- a/reproduction/chinese_word_segment/process/cws_processor.py
+++ b/reproduction/chinese_word_segment/process/cws_processor.py
@@ -448,4 +448,33 @@ class BMES2OutputProcessor(Processor):
                     words.append(''.join(chars[start_idx:idx+1]))
                     start_idx = idx + 1
             return ' '.join(words)
-        dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name)
\ No newline at end of file
+        dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name)
+
+
+class InputTargetProcessor(Processor):
+    def __init__(self, input_fields, target_fields):
+        """
+        Operates on a DataSet: sets the fields in input_fields as input and the fields in target_fields as target.
+
+        :param input_fields: List[str], field names to set as input. If None, no field is set as input.
+        :param target_fields: List[str], field names to set as target. If None, no field is set as target.
+        """
+        super(InputTargetProcessor, self).__init__(None, None)
+
+        if input_fields is not None and not isinstance(input_fields, list):
+            raise TypeError("input_fields should be List[str], not {}.".format(type(input_fields)))
+        else:
+            self.input_fields = input_fields
+        if target_fields is not None and not isinstance(target_fields, list):
+            raise TypeError("target_fields should be List[str], not {}.".format(type(target_fields)))
+        else:
+            self.target_fields = target_fields
+
+    def process(self, dataset):
+        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        if self.input_fields is not None:
+            for field in self.input_fields:
+                dataset.set_input(field)
+        if self.target_fields is not None:
+            for field in self.target_fields:
+                dataset.set_target(field)
\ No newline at end of file
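A minimal usage sketch for the new processor, assuming a DataSet can be built from a dict of lists as in fastNLP at the time; the field names and values here are purely illustrative:

```python
from fastNLP.core.dataset import DataSet
from reproduction.chinese_word_segment.process.cws_processor import InputTargetProcessor

# Toy data with illustrative field names and values.
ds = DataSet({'chars': [[1, 2], [3, 4, 5]],
              'target': [[0, 1], [1, 0, 1]],
              'seq_lens': [2, 3]})

proc = InputTargetProcessor(input_fields=['chars', 'seq_lens'],
                            target_fields=['target', 'seq_lens'])
proc(ds)  # Processor.__call__ dispatches to process(); fields are marked in place
```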
diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py
index 93e3de50..83243863 100644
--- a/reproduction/chinese_word_segment/train_context.py
+++ b/reproduction/chinese_word_segment/train_context.py
@@ -6,7 +6,7 @@ from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegPr
 from reproduction.chinese_word_segment.process.cws_processor import CWSBMESTagProcessor
 from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor
 from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
-
+from reproduction.chinese_word_segment.process.cws_processor import InputTargetProcessor
 from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
 from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF

@@ -39,6 +39,8 @@ bigram_vocab_proc = VocabIndexerProcessor('bigrams_lst', new_added_filed_name='b

 seq_len_proc = SeqLenProcessor('chars')

+input_target_proc = InputTargetProcessor(input_fields=['chars', 'bigrams', 'seq_lens', "target"],
+                                         target_fields=['target', 'seq_lens'])
 # 2. apply the processors
 fs2hs_proc(tr_dataset)

@@ -61,14 +63,11 @@ char_vocab_proc(dev_dataset)
 bigram_vocab_proc(dev_dataset)
 seq_len_proc(dev_dataset)

-dev_dataset.set_input('chars', 'bigrams', 'target')
-tr_dataset.set_input('chars', 'bigrams', 'target')
-dev_dataset.set_target('seq_lens')
-tr_dataset.set_target('seq_lens')
+input_target_proc(tr_dataset)
+input_target_proc(dev_dataset)

 print("Finish preparing data.")

-
 # 3. the dataset can now be used for training

 # TODO how are pretrained embeddings handled?
@@ -86,80 +85,18 @@ cws_model = CWSBiLSTMCRF(char_vocab_proc.get_vocab_size(), embed_dim=100,
 cws_model.cuda()

 num_epochs = 5
-optimizer = optim.Adagrad(cws_model.parameters(), lr=0.02)
+optimizer = optim.Adagrad(cws_model.parameters(), lr=0.005)

 from fastNLP.core.trainer import Trainer
 from fastNLP.core.sampler import BucketSampler
 from fastNLP.core.metrics import BMESF1PreRecMetric

 metric = BMESF1PreRecMetric(target='tags')
-trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=3,
+trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=num_epochs,
                   batch_size=32, print_every=50, validate_every=-1, dev_data=dev_dataset, save_path=None,
                   optimizer=optimizer, check_code_level=0, metric_key='f', sampler=BucketSampler(),
                   use_tqdm=True)

 trainer.train()
-exit(0)
-
-#
-# print_every = 50
-# batch_size = 32
-# tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
-# dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
-# num_batch_per_epoch = len(tr_dataset) // batch_size
-# best_f1 = 0
-# best_epoch = 0
-# for num_epoch in range(num_epochs):
-#     print('X' * 10 + ' Epoch: {}/{} '.format(num_epoch + 1, num_epochs) + 'X' * 10)
-#     sys.stdout.flush()
-#     avg_loss = 0
-#     with tqdm(total=num_batch_per_epoch, leave=True) as pbar:
-#         pbar.set_description_str('Epoch:%d' % (num_epoch + 1))
-#         cws_model.train()
-#         for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
-#             optimizer.zero_grad()
-#
-#             tags = batch_y['tags'].long()
-#             pred_dict = cws_model(**batch_x, tags=tags)  # B x L x tag_size
-#
-#             seq_lens = pred_dict['seq_lens']
-#             masks = seq_lens_to_mask(seq_lens).float()
-#             tags = tags.to(seq_lens.device)
-#
-#             loss = pred_dict['loss']
-#
-#             # loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
-#             #                          tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
-#             # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())
-#
-#             avg_loss += loss.item()
-#
-#             loss.backward()
-#             for group in optimizer.param_groups:
-#                 for param in group['params']:
-#                     param.grad.clamp_(-5, 5)
-#
-#             optimizer.step()
-#
-#             if batch_idx % print_every == 0:
-#                 pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
-#                 avg_loss = 0
-#                 pbar.update(print_every)
-#     tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
-#     # validation set
-#     pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher, type='bmes')
-#     print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
-#                                                      pre*100,
-#                                                      rec*100))
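With save_path=None as above, the Trainer relies on the new in-memory fallback added to trainer.py: the best parameters are kept as CPU clones and restored at the end of training. The round trip looks roughly like this in plain PyTorch (a sketch with a stand-in model, not the Trainer code itself):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for any model

# Snapshot: clone every parameter to CPU so later training steps cannot mutate the copy.
best_states = {name: param.cpu().clone() for name, param in model.named_parameters()}

# ... training continues and the parameters drift ...
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

# Restore: load_state_dict copies the snapshot back into the live model.
model.load_state_dict(best_states)
```

Note that named_parameters() covers parameters only; a model with buffers (e.g. BatchNorm running statistics) would need those snapshotted as well.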
#### Metric
Each batch's output triggers a call to the Metric's \__call__(pred_dict, target_dict) method, and \__call__ in turn calls the evaluate() method (which must be implemented). After all batches have been passed in, the Metric's get_metric() method is called to obtain the final metric value.
So when evaluate is called, the Metric updates its own state from the data it receives (pred_dict and batch_y), for example by accumulating the number of correct predictions and the total number of samples; when get_metric() is called, it produces the final result.
All Metrics must inherit from fastNLP.core.metrics.MetricBase; see the next cell for an example.
Try not to override the \__call__() and _init_param_map() methods.

MetricBase:
    def __init__(self):
        self.param_map = {}    # normally you do not need to create this yourself; calling _init_param_map() is better
        self._checked = False  # this attribute can be ignored

    def _init_param_map(self, key_map=None, **kwargs):
        # Registers the Metric's "key mapping". There are two ways to pass values:
        # 1. pass a dict via key_map; the values are the keys used to look things up in the forward
        #    results and in batch_y, e.g. key_map = {'pred': 'output', 'target': 'label'}
        # 2. spell it out yourself (recommended): _init_param_map(pred='output', target='label')
        # Why provide this method? Calling it registers param_map automatically and runs some checks, e.g.
        # against passing a key that is not actually a parameter of evaluate(). Only pass arguments that
        # need key mapping; do not pass other evaluate() parameters. If you pass (pred=None, target=None),
        # __call__() looks for the keys 'pred' and 'target' in pred_dict and target_dict.
        # Calling this method is optional.

    def __call__(self, pred_dict, target_dict, check=False):  # ignore check=False; it will probably be removed
        # Mostly does checking work, e.g. whether pred_dict and target_dict contain all the keys that
        # evaluate needs. If the checks pass, evaluate is called.
        fast_param = self._fast_param_map(pred_dict, target_dict)
        if fast_param:
            return self.evaluate(**fast_param)
        # if there is no fast_param, match the parameters via the key mapping and then call evaluate
        ...

    def _fast_param_map(self, pred_dict, target_dict):
        # A fast path for computing the metric: in many cases no "key mapping" is needed at all. For
        # example, when pred_dict has a single element and target_dict also has a single element, the
        # prediction and the ground truth can be paired unambiguously; the base class detects this case
        # (and possibly other unambiguous ones). If _fast_param_map succeeds, the key mapping is skipped,
        # so the metric can be computed even when the mapping is missing or wrong. Returns a dict: on a
        # successful match, something like {'pred': value, 'target': value}; an empty dict means the match
        # failed and __call__ keeps trying.

    def evaluate(self, *args, **kwargs):
        # Must be implemented; accumulates the metric state.
        # (1) evaluate() must not declare a *args-style parameter.
        # (2) If it declares **kwargs, every entry of pred_dict and target_dict is passed in;
        #     this parameter is not recommended.
        raise NotImplementedError

    def get_metric(self, reset=True):
        # Must be implemented; returns the final metric as a dict. Called after all batches have been passed in.
        raise NotImplementedError

Below, AccuracyMetric is used as an example:

class AccuracyMetric(MetricBase):
    # Register the values that need mapping; the names 'pred' and 'target' here must correspond to the
    # parameter names of evaluate().
    def __init__(self, pred=None, target=None):
        super(AccuracyMetric, self).__init__()
        # Passing pred and target to _init_param_map registers them correctly; this step is optional but
        # recommended. Only "key mapping" pairs belong here: if __init__(pred=None, target=None,
        # threshold=0.1) had a threshold argument controlling the metric computation, threshold should not
        # be passed to _init_param_map.
        self._init_param_map(pred=pred, target=target)

        self.total = 0  # accumulates the total number of samples
        self.corr = 0   # accumulates the number of correct samples

    def evaluate(self, pred, target):
        # basic checks / preprocessing of pred and target
        if pred.size() == target.size() and len(pred.size()) == 1:  # pred has already been argmax-ed
            pass
        elif len(pred.size()) == 2 and len(target.size()) == 1:  # pred has not been argmax-ed yet
            pred = pred.argmax(dim=1)
        else:
            raise ValueError("The shape of pred and target should be ((B, n_classes), (B,)) or "
                             "((B,), (B,)).")
        assert pred.size(0) == target.size(0), "Mismatch batch size."
        # accumulate accordingly
        self.total += pred.size(0)
        self.corr += torch.sum(torch.eq(pred, target).float()).item()

    def get_metric(self, reset=True):
        # reset indicates whether the accumulated state should be cleared; defaults to True.
        # Must return a dict; it may contain several metrics.
        metric = {}
        metric['acc'] = self.corr / self.total
        if reset:
            self.total = 0
            self.corr = 0
        return metric
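To see the accumulate-then-aggregate pattern in action without the MetricBase machinery, here is a self-contained mock that mirrors the evaluate()/get_metric() pair above (the class name and data are made up for the demonstration):

```python
import torch

class SimpleAccuracy:
    """Standalone mock of the accumulate/aggregate pattern; no key mapping, no MetricBase."""
    def __init__(self):
        self.total, self.corr = 0, 0

    def evaluate(self, pred, target):
        if len(pred.size()) == 2:          # logits of shape (B, n_classes): reduce to class ids
            pred = pred.argmax(dim=1)
        self.total += pred.size(0)
        self.corr += torch.sum(torch.eq(pred, target).float()).item()

    def get_metric(self, reset=True):
        metric = {'acc': self.corr / self.total}
        if reset:
            self.total, self.corr = 0, 0
        return metric

m = SimpleAccuracy()
m.evaluate(torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 0]))  # one hit, one miss
m.evaluate(torch.tensor([1, 1]), torch.tensor([1, 1]))                    # already argmax-ed: two hits
print(m.get_metric())  # {'acc': 0.75}
```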
#### Tester: used for evaluation; it should not need modification
The important initialization arguments are data, model, and metrics.
The important function is test().
What happens inside test():
    predict_func = model.predict if the model defines predict, otherwise model.forward
    for batch_x, batch_y in batch:
        # (1) move the data to the model's device
        # (2) pick the arguments predict_func needs out of batch_x and pass them in, obtaining pred_dict
        # (3) call metric(pred_dict, batch_y)
        # (4) once every batch has run, call the metric's get_metric method and use its return value as the evaluation result
    metric.get_metric()

#### Trainer: a wrapper around the training process.
The important function is train().
What happens inside train():
    # (1) create the batches
    batch = Batch(dataset, batch_size, sampler=sampler)
    for batch_x, batch_y in batch:
        """
        batch_x and batch_y are both dicts. batch_x holds the DataSet fields set as input; batch_y holds
        the fields set as target. The keys of both dicts are the DataSet field names, and the values are
        tensors, padded as needed.
        """
        # (2) move the tensors in batch_x and batch_y to the device the model is on
        # (3) based on model.forward's parameter list, pick the data that forward needs out of batch_x
        # (4) take model.forward's output pred_dict and pass it to the loss function together with batch_y to compute the loss
        # (5) backpropagate the loss and update the parameters
    # (6) if there is a dev set, run validation
    tester = Tester(model, dev_data, metric)
    eval_results = tester.test()
    # (7) if eval_results is the best result so far, save the model

Beyond the above, the Trainer also provides a "dry run" feature, controlled by check_code_level: if
check_code_level is -1, no dry run is performed. check_code_level = 0, 1, 2 correspond to different
reporting levels; currently they govern how fields that are set as input or target but never used are
reported: 0 ignores them (the default); 1 warns when unused fields occur; 2 raises an error and aborts
when unused fields occur.
The dry run has two main purposes: (1) to prevent a crash during the evaluation that follows training,
which would waste the whole run; (2) because of the "key mapping", errors from a direct run can be hard
to debug, whereas errors raised during the dry run come with debugging hints.
The dry run performs the following: (1) with a very small batch_size, it checks whether batch_x contains
the parameters Model.forward requires; only two iterations are run. (2) It feeds Model.forward's output
pred_dict and batch_y into the loss and attempts a backward pass; no parameters are updated and the
gradients are zeroed afterwards. If dev_data was passed, the metric is exercised as well. (3) It creates
a Tester, passes in a small amount of data, and checks that it runs normally.
The dry run is executed when the Trainer is initialized. Normally the dry-run code should not need to be
changed, but if you run into a bug or have suggestions, please raise them in the developer group or open
an issue on GitHub.
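The dry-run behaviour described above can be pictured as a small standalone routine: push a couple of tiny batches through forward and the loss, attempt a backward pass, then discard the gradients. This is a sketch in plain PyTorch with a stand-in model, not fastNLP's actual _check_code:

```python
import torch
import torch.nn as nn

def dry_run(model, batches, loss_fn, num_checks=2):
    """Sanity-check forward/loss/backward on a few tiny batches without training."""
    for i, (batch_x, batch_y) in enumerate(batches):
        if i >= num_checks:                # only two iterations, as described above
            break
        pred = model(**batch_x)            # fails early if forward's arguments don't match batch_x
        loss = loss_fn(pred, batch_y['target'])
        loss.backward()                    # check that the graph is differentiable
        model.zero_grad()                  # discard gradients: no parameter is updated

model = nn.Linear(4, 2)  # stand-in model; nn.Linear.forward's parameter is named "input"
batches = [({'input': torch.randn(2, 4)}, {'target': torch.randint(0, 2, (2,))})
           for _ in range(3)]
dry_run(model, batches, nn.CrossEntropyLoss())
print("dry run passed")
```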