From 62ea4f7fed30671d816364ff40bc937daf7d97a5 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Sat, 19 Jan 2019 18:40:43 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E6=B7=BB=E5=8A=A0LR=20finder=EF=BC=8C?= =?UTF-8?q?=E7=94=A8=E7=AC=AC=E4=B8=80=E4=B8=AAepoch=E6=89=BE=E6=9C=80?= =?UTF-8?q?=E4=BD=B3lr,=E4=BB=8E=E7=AC=AC=E4=BA=8C=E4=B8=AAepoch=E5=BC=80?= =?UTF-8?q?=E5=A7=8B=E8=AE=AD=E7=BB=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callback.py | 80 +++++++++++++++++++++++++++++++++++++ test/core/test_callbacks.py | 23 ++++++++--- 2 files changed, 98 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index f354ffc6..e0053124 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -1,3 +1,8 @@ +import torch + +from fastNLP.io.model_io import ModelSaver, ModelLoader + + class Callback(object): """An Interface for all callbacks. @@ -315,6 +320,81 @@ class ControlC(Callback): raise exception # 抛出陌生Error +class SmoothValue(object): + def __init__(self, beta: float): + self.beta, self.n, self.mov_avg = beta, 0, 0 + self.smooth = None + + def add_value(self, val: float) -> None: + "Add `val` to calculate updated smoothed value." + self.n += 1 + self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val + self.smooth = self.mov_avg / (1 - self.beta ** self.n) + + +class LRFinder(Callback): + """fastai lr_finder""" + + def __init__(self, n_batch, start_lr=1e-6, end_lr=10): + """用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 + + :param n_batch: 一个epoch内的iteration数 + :param start_lr: 学习率下界 + :param end_lr: 学习率上界 + """ + super(LRFinder, self).__init__() + self.start_lr, self.end_lr = start_lr, end_lr + self.num_it = n_batch + self.stop = False + self.best_loss = 0. + self.best_lr = None + self.loss_history = [] + self.smooth_value = SmoothValue(0.8) + self.opt = None + scale = (self.end_lr - self.start_lr) / self.num_it + + self.lr_gen = (self.start_lr + scale * (step + 1) for step in range(self.num_it)) + self.find = None + self.loader = ModelLoader() + + def before_epoch(self, cur_epoch, total_epoch): + if cur_epoch == 1: + self.opt = self.trainer.optimizer # pytorch optimizer + self.opt.param_groups[0]["lr"] = self.start_lr + # save model + ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True) + self.find = True + + def before_backward(self, loss, model): + if self.find: + if torch.isnan(loss) or self.stop is True: + self.stop = True + return + loss_val = loss.detach().cpu().data + self.loss_history.append(loss_val) + self.smooth_value.add_value(loss_val) + if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss: + self.best_loss = self.smooth_value.smooth + self.best_lr = self.opt.param_groups[0]["lr"] + + def after_batch(self, *args): + if self.find: + lr = next(self.lr_gen, None) + if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss: + self.stop = True + return + self.opt.param_groups[0]["lr"] = lr + # self.loader.load_pytorch(self.trainer.model, "tmp") + + def after_epoch(self, cur_epoch, n_epoch, optimizer): + if cur_epoch == 1: + self.opt.param_groups[0]["lr"] = self.best_lr + self.find = False + # reset model + ModelLoader().load_pytorch(self.trainer.model, "tmp") + print("Model reset. \nFind best lr={}".format(self.best_lr)) + + if __name__ == "__main__": manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) manager.before_train(10, 11, 12) diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index 59f2be1b..d0c1fb13 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -3,7 +3,7 @@ import unittest import numpy as np import torch -from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC +from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, LRFinder from fastNLP.core.dataset import DataSet from fastNLP.core.instance import Instance from fastNLP.core.losses import BCELoss @@ -52,7 +52,7 @@ class TestCallback(unittest.TestCase): data_set, model = prepare_env() trainer = Trainer(data_set, model, loss=BCELoss(pred="predict", target="y"), - n_epochs=30, + n_epochs=20, batch_size=32, print_every=50, optimizer=SGD(lr=0.1), @@ -67,7 +67,7 @@ class TestCallback(unittest.TestCase): data_set, model = prepare_env() trainer = Trainer(data_set, model, loss=BCELoss(pred="predict", target="y"), - n_epochs=50, + n_epochs=20, batch_size=32, print_every=50, optimizer=SGD(lr=0.01), @@ -83,7 +83,7 @@ class TestCallback(unittest.TestCase): optimizer = torch.optim.SGD(model.parameters(), lr=0.01) trainer = Trainer(data_set, model, loss=BCELoss(pred="predict", target="y"), - n_epochs=50, + n_epochs=5, batch_size=32, print_every=50, optimizer=optimizer, @@ -98,7 +98,7 @@ class TestCallback(unittest.TestCase): data_set, model = prepare_env() trainer = Trainer(data_set, model, loss=BCELoss(pred="predict", target="y"), - n_epochs=50, + n_epochs=5, batch_size=32, print_every=50, optimizer=SGD(lr=0.1), @@ -106,3 +106,16 @@ class TestCallback(unittest.TestCase): use_tqdm=False, callbacks=[ControlC(False)]) trainer.train() + + def test_LRFinder(self): + data_set, model = prepare_env() + trainer = Trainer(data_set, model, + loss=BCELoss(pred="predict", target="y"), + n_epochs=5, + batch_size=32, + print_every=50, + optimizer=SGD(lr=0.1), + check_code_level=2, + use_tqdm=False, + callbacks=[LRFinder(len(data_set) // 32)]) + trainer.train() From b14dd588285d0452722b6529991e181fa3e65219 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Sat, 19 Jan 2019 18:48:57 +0800 Subject: [PATCH 2/4] Update POS API --- fastNLP/api/api.py | 2 +- fastNLP/api/examples.py | 6 +++++- reproduction/POS_tagging/train_pos_tag.py | 22 +++++++++++----------- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 38af57b3..0c5f17bc 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -18,7 +18,7 @@ from fastNLP.api.processor import IndexerProcessor # TODO add pretrain urls model_urls = { "cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl", - "pos": "http://123.206.98.91:8888/download/pos_tag_model_20190108-f3c60ee5.pkl", + "pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl", "parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl" } diff --git a/fastNLP/api/examples.py b/fastNLP/api/examples.py index 10cc6edc..447d127a 100644 --- a/fastNLP/api/examples.py +++ b/fastNLP/api/examples.py @@ -16,6 +16,10 @@ def chinese_word_segmentation(): def pos_tagging(): + # 输入已分词序列 + text = ['编者 按: 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一款 高科技 隐形 无人机 雷电之神 。'] + text = [text[0].split()] + print(text) pos = POS(device='cpu') print(pos.predict(text)) @@ -26,4 +30,4 @@ def syntactic_parsing(): if __name__ == "__main__": - syntactic_parsing() + pos_tagging() diff --git a/reproduction/POS_tagging/train_pos_tag.py b/reproduction/POS_tagging/train_pos_tag.py index 6448c32b..06547701 100644 --- a/reproduction/POS_tagging/train_pos_tag.py +++ b/reproduction/POS_tagging/train_pos_tag.py @@ -14,7 +14,7 @@ from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.core.trainer import Trainer from fastNLP.io.config_io import ConfigLoader, ConfigSection from fastNLP.models.sequence_modeling import AdvSeqLabel -from fastNLP.io.dataset_loader import ZhConllPOSReader, ConllxDataLoader +from fastNLP.io.dataset_loader import ConllxDataLoader from fastNLP.api.processor import ModelProcessor, Index2WordProcessor @@ -35,7 +35,7 @@ def load_tencent_embed(embed_path, word2id): return embedding_tensor -def train(train_data_path, dev_data_path, checkpoint=None): +def train(train_data_path, dev_data_path, checkpoint=None, save=None): # load config train_param = ConfigSection() model_param = ConfigSection() @@ -44,9 +44,9 @@ def train(train_data_path, dev_data_path, checkpoint=None): # Data Loader print("loading training set...") - dataset = ConllxDataLoader().load(train_data_path) + dataset = ConllxDataLoader().load(train_data_path, return_dataset=True) print("loading dev set...") - dev_data = ConllxDataLoader().load(dev_data_path) + dev_data = ConllxDataLoader().load(dev_data_path, return_dataset=True) print(dataset) print("================= dataset ready =====================") @@ -54,9 +54,9 @@ def train(train_data_path, dev_data_path, checkpoint=None): dev_data.rename_field("tag", "truth") vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq") - tag_proc = VocabIndexerProcessor("truth") + tag_proc = VocabIndexerProcessor("truth", is_input=True) seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True) - set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len", "truth") + set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len") vocab_proc(dataset) tag_proc(dataset) @@ -93,7 +93,7 @@ def train(train_data_path, dev_data_path, checkpoint=None): target="truth", seq_lens="word_seq_origin_len"), dev_data=dev_data, metric_key="f", - use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path="./save_0117") + use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save) trainer.train(load_best_model=True) # save model & pipeline @@ -102,12 +102,12 @@ def train(train_data_path, dev_data_path, checkpoint=None): pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag]) save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} - torch.save(save_dict, "model_pp_0117.pkl") + torch.save(save_dict, os.path.join(save, "model_pp.pkl")) print("pipeline saved") def run_test(test_path): - test_data = ZhConllPOSReader().load(test_path) + test_data = ConllxDataLoader().load(test_path, return_dataset=True) with open("model_pp_0117.pkl", "rb") as f: save_dict = torch.load(f) @@ -157,7 +157,7 @@ if __name__ == "__main__": # 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl if args.checkpoint is None: raise RuntimeError("Please provide the checkpoint. -cp ") - train(args.train, args.dev, args.checkpoint) + train(args.train, args.dev, args.checkpoint, save=args.save) else: # 一次训练 python train_pos_tag.py - train(args.train, args.dev) + train(args.train, args.dev, save=args.save) From f3cb8125544199fa51d8329c54e8bdecc4218fe4 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Sun, 20 Jan 2019 16:37:58 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E5=B0=86tesorboardX=E5=A4=84=E7=90=86?= =?UTF-8?q?=E4=B8=BAcallback,=20=E4=BB=8Etrainer=E7=A7=BB=E9=99=A4tensorbo?= =?UTF-8?q?ardX=E7=9B=B8=E5=85=B3=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callback.py | 71 +++++++++++++++++++++++++++++++++++-- fastNLP/core/trainer.py | 31 +++------------- test/core/test_batch.py | 14 ++++---- test/core/test_callbacks.py | 19 +++++++++- 4 files changed, 99 insertions(+), 36 deletions(-) diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index e0053124..48d7333c 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -1,4 +1,7 @@ +import os + import torch +from tensorboardX import SummaryWriter from fastNLP.io.model_io import ModelSaver, ModelLoader @@ -12,6 +15,7 @@ class Callback(object): def __init__(self): super(Callback, self).__init__() + self.trainer = None # 在Trainer内部被重新赋值 def before_train(self): # before the main training loop @@ -333,8 +337,6 @@ class SmoothValue(object): class LRFinder(Callback): - """fastai lr_finder""" - def __init__(self, n_batch, start_lr=1e-6, end_lr=10): """用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 @@ -395,6 +397,71 @@ class LRFinder(Callback): print("Model reset. \nFind best lr={}".format(self.best_lr)) +class TensorboardCallback(Callback): + """ + 接受以下一个或多个字符串作为参数: + - "model" + - "loss" + - "metric" + """ + + def __init__(self, *options): + super(TensorboardCallback, self).__init__() + args = {"model", "loss", "metric"} + for opt in options: + if opt not in args: + raise ValueError("Unrecognized argument {}. Expect one of {}".format(opt, args)) + self.options = options + self._summary_writer = None + self.graph_added = False + + def before_train(self): + save_dir = self.trainer.save_path + if save_dir is None: + path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time)) + else: + path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time)) + self._summary_writer = SummaryWriter(path) + + def before_batch(self, batch_x, batch_y, indices): + if "model" in self.options and self.graph_added is False: + # tesorboardX 这里有大bug,暂时没法画模型图 + # from fastNLP.core.utils import _build_args + # inputs = _build_args(self.trainer.model, **batch_x) + # args = tuple([value for value in inputs.values()]) + # args = args[0] if len(args) == 1 else args + # self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2)) + self.graph_added = True + + def before_backward(self, loss, model): + if "loss" in self.options: + self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step) + + if "model" in self.options: + for name, param in self.trainer.model.named_parameters(): + if param.requires_grad: + self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.trainer.step) + # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.trainer.step) + self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(), + global_step=self.trainer.step) + + def after_valid(self, eval_result, metric_key, optimizer): + if "metric" in self.options: + for name, metric in eval_result.items(): + for metric_key, metric_val in metric.items(): + self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val, + global_step=self.trainer.step) + + def after_train(self, model): + self._summary_writer.close() + del self._summary_writer + + def on_exception(self, exception, model): + if hasattr(self, "_summary_writer"): + self._summary_writer.close() + del self._summary_writer + + if __name__ == "__main__": manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) manager.before_train(10, 11, 12) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index a5861091..b7a8f72b 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -5,7 +5,6 @@ from datetime import timedelta import numpy as np import torch -from tensorboardX import SummaryWriter from torch import nn try: @@ -195,21 +194,9 @@ class Trainer(object): self._model_device = self.model.parameters().__next__().device self._mode(self.model, is_test=False) - self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S')) + self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) start_time = time.time() print("training epochs started " + self.start_time, flush=True) - if self.save_path is None: - class psudoSW: - def __getattr__(self, item): - def pass_func(*args, **kwargs): - pass - - return pass_func - - self._summary_writer = psudoSW() - else: - path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) - self._summary_writer = SummaryWriter(path) try: self.callback_manager.before_train() @@ -232,8 +219,7 @@ class Trainer(object): else: print("Fail to reload best model.") finally: - self._summary_writer.close() - del self._summary_writer + pass results['seconds'] = round(time.time() - start_time, 2) return results @@ -261,7 +247,7 @@ class Trainer(object): # negative sampling; replace unknown; re-weight batch_y self.callback_manager.before_batch(batch_x, batch_y, indices) _move_dict_value_to_device(batch_x, batch_y, device=self._model_device, - non_blocking=self.pin_memory) # pin_memory, use non_blockling. + non_blocking=self.pin_memory) # pin_memory, use non_blocking. prediction = self._data_forward(self.model, batch_x) # edit prediction @@ -279,12 +265,6 @@ class Trainer(object): # lr scheduler; lr_finder; one_cycle self.callback_manager.after_step(self.optimizer) - self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) - for name, param in self.model.named_parameters(): - if param.requires_grad: - self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step) - # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) - # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) if (self.step+1) % self.print_every == 0: if self.use_tqdm: print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every) @@ -319,10 +299,7 @@ class Trainer(object): def _do_validation(self, epoch, step): res = self.tester.test() - for name, metric in res.items(): - for metric_key, metric_val in metric.items(): - self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val, - global_step=self.step) + if self._better_eval_result(res): if self.save_path is not None: self._save_model(self.model, diff --git a/test/core/test_batch.py b/test/core/test_batch.py index 29a48559..e1561942 100644 --- a/test/core/test_batch.py +++ b/test/core/test_batch.py @@ -1,3 +1,4 @@ +import time import unittest import numpy as np @@ -8,7 +9,7 @@ from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import construct_dataset from fastNLP.core.instance import Instance from fastNLP.core.sampler import SequentialSampler -import time + def generate_fake_dataset(num_samples=1000): """ @@ -161,12 +162,13 @@ class TestCase1(unittest.TestCase): dataset = generate_fake_dataset(num_samples) batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), pin_memory=True) - for batch_x, batch_y in batch: - time.sleep(pause_seconds) + # 这里发生OOM + # for batch_x, batch_y in batch: + # time.sleep(pause_seconds) num_workers = 2 batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers, pin_memory=True) - for batch_x, batch_y in batch: - time.sleep(pause_seconds) - + # 这里发生OOM + # for batch_x, batch_y in batch: + # time.sleep(pause_seconds) diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index d0c1fb13..74ce4876 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -3,7 +3,9 @@ import unittest import numpy as np import torch -from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, LRFinder +from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ + LRFinder, \ + TensorboardCallback from fastNLP.core.dataset import DataSet from fastNLP.core.instance import Instance from fastNLP.core.losses import BCELoss @@ -119,3 +121,18 @@ class TestCallback(unittest.TestCase): use_tqdm=False, callbacks=[LRFinder(len(data_set) // 32)]) trainer.train() + + def test_TensorboardCallback(self): + data_set, model = prepare_env() + trainer = Trainer(data_set, model, + loss=BCELoss(pred="predict", target="y"), + n_epochs=5, + batch_size=32, + print_every=50, + optimizer=SGD(lr=0.1), + check_code_level=2, + use_tqdm=False, + dev_data=data_set, + metrics=AccuracyMetric(pred="predict", target="y"), + callbacks=[TensorboardCallback("loss", "metric")]) + trainer.train() From 47ec69ea96b484458448f4a0d0eda4de8e8b5562 Mon Sep 17 00:00:00 2001 From: yh Date: Mon, 21 Jan 2019 14:44:31 +0800 Subject: [PATCH 4/4] =?UTF-8?q?trainer=E6=A0=B9=E6=8D=AEsyf=E7=9A=84?= =?UTF-8?q?=E5=A4=9A=E8=BF=9B=E7=A8=8Bbatch=E8=BF=9B=E8=A1=8C=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/trainer.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index b7a8f72b..8ca3d22a 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -33,8 +33,8 @@ from fastNLP.core.utils import get_func_signature class Trainer(object): def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), - check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, pin_memory=False, - timeout=0, use_tqdm=True, use_cuda=False, callbacks=None): + check_code_level=0, metric_key=None, sampler=RandomSampler(), prefetch=False, use_tqdm=True, + use_cuda=False, callbacks=None): """ :param DataSet train_data: the training data :param torch.nn.modules.module model: a PyTorch model @@ -58,12 +58,7 @@ class Trainer(object): metric_key="-PPL" # language model gets better as perplexity gets smaller :param BaseSampler sampler: method used to generate batch data. - :param num_workers: int, 使用多少个进程来准备数据。默认为0, 即使用主线程生成数据。 特性处于实验阶段,谨慎使用。 - 如果DataSet较大,且每个batch的准备时间很短,使用多进程可能并不能提速。 - :param pin_memory: bool, 默认为False. 当设置为True时,会使用锁页内存,可能导致内存占用变多。如果内存比较充足, - 可以考虑设置为True进行加速, 当pin_memory为True时,默认使用non_blocking=True的方式将数据从cpu移动到gpu。 - :param timeout: float, 大于0的数,只有在num_workers>0时才有用。超过该时间仍然没有获取到一个batch则报错,可以用于 - 检测是否出现了batch产生阻塞的情况。 + :param prefetch: bool, 是否使用额外的进程对产生batch数据。 :param bool use_tqdm: whether to use tqdm to show train progress. :param callbacks: List[Callback]. 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 通过callback机制实现。 @@ -125,9 +120,7 @@ class Trainer(object): self.best_dev_step = None self.best_dev_perf = None self.sampler = sampler - self.num_workers = num_workers - self.pin_memory = pin_memory - self.timeout = timeout + self.prefetch = prefetch self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) if isinstance(optimizer, torch.optim.Optimizer): @@ -236,8 +229,7 @@ class Trainer(object): with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: avg_loss = 0 data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, - num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout, - keep_process=True) + prefetch=self.prefetch, device=self._model_device) for epoch in range(1, self.n_epochs+1): pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) # early stopping @@ -246,8 +238,6 @@ class Trainer(object): indices = data_iterator.get_batch_indices() # negative sampling; replace unknown; re-weight batch_y self.callback_manager.before_batch(batch_x, batch_y, indices) - _move_dict_value_to_device(batch_x, batch_y, device=self._model_device, - non_blocking=self.pin_memory) # pin_memory, use non_blocking. prediction = self._data_forward(self.model, batch_x) # edit prediction