2. Remove Trainer's prefetch parameter; document the num_workers parameter in the docstring. 3. Trainer now defaults to RandomSampler when no sampler is given. (tag: v0.4.10)
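For reference, a minimal sketch of constructing a Trainer after this change (the load_my_data_and_model helper and the pred/target field names are illustrative placeholders, not part of this commit): prefetch is gone, background batch preparation is requested through num_workers, and drop_last/sampler are forwarded to the internal DataSetIter; passing sampler=None now falls back to RandomSampler.

    from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric

    # Hypothetical helper standing in for real DataSet/model construction.
    train_data, dev_data, model = load_my_data_and_model()

    trainer = Trainer(train_data=train_data, model=model,
                      loss=CrossEntropyLoss(pred="output", target="label_seq"),
                      batch_size=32,
                      num_workers=2,    # replaces the removed prefetch=True
                      drop_last=True,   # new: discard a final short batch
                      sampler=None,     # None now defaults to RandomSampler()
                      dev_data=dev_data,
                      metrics=AccuracyMetric(pred="predict", target="label_seq"))
    trainer.train()
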
@@ -351,6 +351,8 @@ class Trainer(object):
     :param int batch_size: batch size used during training and validation.
     :param loss: the :class:`~fastNLP.core.losses.LossBase` object to use. When None, defaults to :class:`~fastNLP.LossInForward`
     :param sampler: the order in which batches are generated, a :class:`~fastNLP.Sampler`. If None, defaults to :class:`~fastNLP.RandomSampler`
+    :param drop_last: if the last batch does not contain exactly batch_size samples, drop it.
+    :param num_workers: int, number of threads used for padding the batch data.
     :param update_every: int, update the gradients once every this many steps. Intended for gradient accumulation: e.g. when a
         batch_size of 128 is needed but setting it directly runs out of memory, use batch_size=32 with update_every=4 instead. Has no effect when optimizer is None.
     :param int n_epochs: number of training epochs.
@@ -367,7 +369,6 @@ class Trainer(object):
     :param int validate_every: evaluate on the dev set every this many steps; if -1, evaluate once at the end of each epoch. Only effective when dev_data is passed.
     :param str,None save_path: path under which to save the model. If None, the model is not saved. If dev_data is None, the model from the last iteration is saved.
         Saving stores not only the parameters but also the model structure. Even when DataParallel is used, only the model itself is saved.
-    :param prefetch: bool, whether to use an extra process to produce batch data. In theory this makes batch iteration faster.
     :param bool use_tqdm: whether to use tqdm to show training progress; if False, the loss is printed to the terminal instead.
     :param str,int,torch.device,list(int) device: the device to load the model onto. Defaults to None, i.e. the Trainer does not manage
         where the model is computed. The following inputs are supported:
@@ -394,16 +395,12 @@ class Trainer(object):
     """
     def __init__(self, train_data, model, optimizer=None, loss=None,
-                 batch_size=32, sampler=None, update_every=1, num_workers=0,
-                 n_epochs=10, print_every=5,
+                 batch_size=32, sampler=None, drop_last=False, update_every=1,
+                 num_workers=0, n_epochs=10, print_every=5,
                  dev_data=None, metrics=None, metric_key=None,
-                 validate_every=-1, save_path=None,
-                 prefetch=False, use_tqdm=True, device=None,
-                 callbacks=None,
-                 check_code_level=0):
+                 validate_every=-1, save_path=None, use_tqdm=True, device=None,
+                 callbacks=None, check_code_level=0):
         super(Trainer, self).__init__()
-        if not isinstance(train_data, DataSet):
-            raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
         if not isinstance(model, nn.Module):
             raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
@@ -440,9 +437,12 @@ class Trainer(object):
         if sampler is not None and not isinstance(sampler, Sampler):
             raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))
+        if sampler is None:
+            sampler = RandomSampler()
         if isinstance(train_data, DataSet):
             self.data_iterator = DataSetIter(
-                dataset=train_data, batch_size=batch_size, num_workers=num_workers)
+                dataset=train_data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, drop_last=drop_last)
         elif isinstance(train_data, BatchIter):
             self.data_iterator = train_data
         else:
@@ -470,8 +470,6 @@ class Trainer(object):
         self.best_dev_epoch = None
         self.best_dev_step = None
         self.best_dev_perf = None
-        self.sampler = sampler if sampler is not None else RandomSampler()
-        self.prefetch = prefetch
         self.n_steps = (len(self.train_data) // self.batch_size + int(
             len(self.train_data) % self.batch_size != 0)) * self.n_epochs
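
Callers that previously passed prefetch must migrate; a hedged before/after sketch (train_data, model, and loss here are placeholders):

    # Before this commit: an extra process prefetched batch data.
    # trainer = Trainer(train_data, model, loss=loss, prefetch=True)

    # After: request workers for batch preparation explicitly instead.
    trainer = Trainer(train_data, model, loss=loss, num_workers=1)
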
@@ -184,11 +184,8 @@ def train(path):
             m.weight.requires_grad = True

     # Trainer
-    trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data,
-                      loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
-                      **train_args.data,
-                      optimizer=fastNLP.Adam(**optim_args.data),
-                      save_path=path,
+    trainer = Trainer(train_data=train_data, model=model, optimizer=fastNLP.Adam(**optim_args.data), loss=ParserLoss(),
+                      dev_data=dev_data, metrics=ParserMetric(), metric_key='UAS', save_path=path,
                       callbacks=[MyCallback()])

     # Start training
@@ -89,11 +89,11 @@ def train(train_data_path, dev_data_path, checkpoint=None, save=None):
         model = torch.load(checkpoint)

     # call trainer to train
-    trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
-                                                                           target="truth",
-                                                                           seq_lens="word_seq_origin_len"),
-                      dev_data=dev_data, metric_key="f",
-                      use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save)
+    trainer = Trainer(dataset, model, loss=None, n_epochs=20, print_every=10, dev_data=dev_data,
+                      metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
+                                                target="truth",
+                                                seq_lens="word_seq_origin_len"), metric_key="f", save_path=save,
+                      use_tqdm=True)
     trainer.train(load_best_model=True)

     # save model & pipeline
@@ -149,14 +149,10 @@ def train():
                 ) if x.requires_grad and x.size(0) != len(word_v)]
     optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1},
                  {'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ]
-    trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data,
-                         loss=loss, metrics=metric, metric_key=metric_key,
-                         optimizer=torch.optim.Adam(optim_cfg),
-                         n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=10, validate_every=3000,
-                         device=device,
-                         use_tqdm=False, prefetch=False,
-                         save_path=g_args.log,
-                         callbacks=[MyCallback()])
+    trainer = FN.Trainer(train_data=train_data, model=model, optimizer=torch.optim.Adam(optim_cfg), loss=loss,
+                         batch_size=g_args.bsz, n_epochs=g_args.ep, print_every=10, dev_data=dev_data, metrics=metric,
+                         metric_key=metric_key, validate_every=3000, save_path=g_args.log, use_tqdm=False,
+                         device=device, callbacks=[MyCallback()])
     trainer.train()
     tester = FN.Tester(data=test_data, model=model, metrics=metric,
@@ -70,19 +70,10 @@ test_data = preprocess_data(test_data, bert_dirs)
 model = BertForNLI(bert_dir=bert_dirs)

-trainer = Trainer(
-    train_data=train_data,
-    model=model,
-    optimizer=Adam(lr=2e-5, model_params=model.parameters()),
-    batch_size=torch.cuda.device_count() * 12,
-    n_epochs=4,
-    print_every=-1,
-    dev_data=dev_data,
-    metrics=AccuracyMetric(),
-    metric_key='acc',
-    device=[i for i in range(torch.cuda.device_count())],
-    check_code_level=-1
-)
+trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=2e-5, model_params=model.parameters()),
+                  batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, dev_data=dev_data,
+                  metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())],
+                  check_code_level=-1)
 trainer.train(load_best_model=True)

 tester = Tester(
@@ -57,12 +57,8 @@ callbacks = [clipper]
 # if pretrain:
 #     fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until)
 #     callbacks.append(fixer)
-trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None,
-                  batch_size=32, sampler=sampler, update_every=5,
-                  n_epochs=3, print_every=5,
-                  dev_data=data.datasets['dev'], metrics=RelayMetric(), metric_key='f',
-                  validate_every=-1, save_path=None,
-                  prefetch=True, use_tqdm=True, device=device,
-                  callbacks=callbacks,
+trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None, batch_size=32, sampler=sampler,
+                  update_every=5, n_epochs=3, print_every=5, dev_data=data.datasets['dev'], metrics=RelayMetric(),
+                  metric_key='f', validate_every=-1, save_path=None, use_tqdm=True, device=device, callbacks=callbacks,
                   check_code_level=0)
 trainer.train()
@@ -40,89 +40,50 @@ class TestCallback(unittest.TestCase):
     def test_gradient_clip(self):
         data_set, model = prepare_env()
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=20,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=SGD(lr=0.1),
-                          check_code_level=2,
-                          use_tqdm=False,
-                          dev_data=data_set,
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          callbacks=[GradientClipCallback(model.parameters(), clip_value=2)])
+        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=20, print_every=50, dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False,
+                          callbacks=[GradientClipCallback(model.parameters(), clip_value=2)], check_code_level=2)
         trainer.train()

     def test_early_stop(self):
         data_set, model = prepare_env()
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=20,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=SGD(lr=0.01),
-                          check_code_level=2,
-                          use_tqdm=False,
-                          dev_data=data_set,
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          callbacks=[EarlyStopCallback(5)])
+        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.01), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=20, print_every=50, dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False,
+                          callbacks=[EarlyStopCallback(5)], check_code_level=2)
         trainer.train()

     def test_lr_scheduler(self):
         data_set, model = prepare_env()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=5,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=optimizer,
-                          check_code_level=2,
-                          use_tqdm=False,
-                          dev_data=data_set,
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))])
+        trainer = Trainer(data_set, model, optimizer=optimizer, loss=BCELoss(pred="predict", target="y"), batch_size=32,
+                          n_epochs=5, print_every=50, dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False,
+                          callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))],
+                          check_code_level=2)
         trainer.train()

     def test_KeyBoardInterrupt(self):
         data_set, model = prepare_env()
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=5,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=SGD(lr=0.1),
-                          check_code_level=2,
-                          use_tqdm=False,
-                          callbacks=[ControlC(False)])
+        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=5, print_every=50, use_tqdm=False, callbacks=[ControlC(False)],
+                          check_code_level=2)
         trainer.train()

     def test_LRFinder(self):
         data_set, model = prepare_env()
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=5,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=SGD(lr=0.1),
-                          check_code_level=2,
-                          use_tqdm=False,
-                          callbacks=[LRFinder(len(data_set) // 32)])
+        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=5, print_every=50, use_tqdm=False,
+                          callbacks=[LRFinder(len(data_set) // 32)], check_code_level=2)
         trainer.train()

     def test_TensorboardCallback(self):
         data_set, model = prepare_env()
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=5,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=SGD(lr=0.1),
-                          check_code_level=2,
-                          use_tqdm=False,
-                          dev_data=data_set,
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          callbacks=[TensorboardCallback("loss", "metric")])
+        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=5, print_every=50, dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False,
+                          callbacks=[TensorboardCallback("loss", "metric")], check_code_level=2)
         trainer.train()

     def test_readonly_property(self):
@@ -141,16 +102,9 @@ class TestCallback(unittest.TestCase):
                 print(self.optimizer)
         data_set, model = prepare_env()
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=total_epochs,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=SGD(lr=0.1),
-                          check_code_level=2,
-                          use_tqdm=False,
-                          dev_data=data_set,
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          callbacks=[MyCallback()])
+        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=total_epochs, print_every=50, dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, callbacks=[MyCallback()],
+                          check_code_level=2)
         trainer.train()
         assert passed_epochs == list(range(1, total_epochs + 1))
@@ -46,18 +46,10 @@ class TrainerTestGround(unittest.TestCase):
         model = NaiveClassifier(2, 1)

-        trainer = Trainer(train_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          n_epochs=10,
-                          batch_size=32,
-                          print_every=50,
-                          validate_every=-1,
-                          dev_data=dev_set,
-                          optimizer=SGD(lr=0.1),
-                          check_code_level=2,
-                          use_tqdm=True,
-                          save_path=None)
+        trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+                          use_tqdm=True, check_code_level=2)
         trainer.train()
         """
         # should run without errors
@@ -83,10 +75,7 @@ class TrainerTestGround(unittest.TestCase):
         model = Model()

         with self.assertRaises(RuntimeError):
-            trainer = Trainer(
-                train_data=dataset,
-                model=model
-            )
+            trainer = Trainer(train_data=dataset, model=model)
         """
         # the error message that should be produced
         NameError:
@@ -116,12 +105,7 @@ class TrainerTestGround(unittest.TestCase):
             return {'loss': loss}

         model = Model()
-        trainer = Trainer(
-            train_data=dataset,
-            model=model,
-            use_tqdm=False,
-            print_every=2
-        )
+        trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False)
         trainer.train()
         """
         # should run without errors
@@ -147,12 +131,7 @@ class TrainerTestGround(unittest.TestCase):
         model = Model()

         with self.assertRaises(NameError):
-            trainer = Trainer(
-                train_data=dataset,
-                model=model,
-                use_tqdm=False,
-                print_every=2
-            )
+            trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False)
             trainer.train()

     def test_trainer_suggestion4(self):
@@ -175,12 +154,7 @@ class TrainerTestGround(unittest.TestCase):
         model = Model()

         with self.assertRaises(NameError):
-            trainer = Trainer(
-                train_data=dataset,
-                model=model,
-                use_tqdm=False,
-                print_every=2
-            )
+            trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False)

     def test_trainer_suggestion5(self):
         # check that the error message correctly alerts the user
@@ -203,12 +177,7 @@ class TrainerTestGround(unittest.TestCase):
             return {'loss': loss}

         model = Model()
-        trainer = Trainer(
-            train_data=dataset,
-            model=model,
-            use_tqdm=False,
-            print_every=2
-        )
+        trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False)

     def test_trainer_suggestion6(self):
         # check that the error message correctly alerts the user
@@ -233,14 +202,8 @@ class TrainerTestGround(unittest.TestCase):
         model = Model()

         with self.assertRaises(NameError):
-            trainer = Trainer(
-                train_data=dataset,
-                model=model,
-                dev_data=dataset,
-                loss=CrossEntropyLoss(),
-                metrics=AccuracyMetric(),
-                use_tqdm=False,
-                print_every=2)
+            trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2, dev_data=dataset,
+                              metrics=AccuracyMetric(), use_tqdm=False)
         """

     def test_trainer_multiprocess(self):
@@ -130,11 +130,8 @@ class ModelRunner():
         tester = Tester(data=data, model=model, metrics=metrics,
                         batch_size=BATCH_SIZE, verbose=0)
         before_train = tester.test()
-        trainer = Trainer(model=model, train_data=data, dev_data=None,
-                          n_epochs=N_EPOCHS, batch_size=BATCH_SIZE,
-                          loss=loss,
-                          save_path=None,
-                          use_tqdm=False)
+        trainer = Trainer(train_data=data, model=model, loss=loss, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS,
+                          dev_data=None, save_path=None, use_tqdm=False)
         trainer.train(load_best_model=False)
         after_train = tester.test()
         for metric_name, v1 in before_train.items():
@@ -60,10 +60,10 @@ class TestTutorial(unittest.TestCase):
         print(test_data[0])

         # If you work on projects such as reinforcement learning or GANs, you can also use these data preprocessing tools
-        from fastNLP.core.batch import Batch
+        from fastNLP.core.batch import DataSetIter
         from fastNLP.core.sampler import RandomSampler

-        batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())
+        batch_iterator = DataSetIter(dataset=train_data, batch_size=2, sampler=RandomSampler())
         for batch_x, batch_y in batch_iterator:
             print("batch_x has: ", batch_x)
             print("batch_y has: ", batch_y)
@@ -85,21 +85,14 @@ class TestTutorial(unittest.TestCase):
         # Instantiate a Trainer, pass in the model and data, and train
         # First overfit on test_data (to make sure the model implementation is correct)
         copy_model = deepcopy(model)
-        overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,
-                                  loss=loss,
-                                  metrics=metric,
-                                  save_path=None,
-                                  batch_size=32,
-                                  n_epochs=5)
+        overfit_trainer = Trainer(train_data=test_data, model=copy_model, loss=loss, batch_size=32, n_epochs=5,
+                                  dev_data=test_data, metrics=metric, save_path=None)
         overfit_trainer.train()

         # Train on train_data, validate on test_data
-        trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,
-                          loss=CrossEntropyLoss(pred="output", target="label_seq"),
-                          metrics=AccuracyMetric(pred="predict", target="label_seq"),
-                          save_path=None,
-                          batch_size=32,
-                          n_epochs=5)
+        trainer = Trainer(train_data=train_data, model=model, loss=CrossEntropyLoss(pred="output", target="label_seq"),
+                          batch_size=32, n_epochs=5, dev_data=test_data,
+                          metrics=AccuracyMetric(pred="predict", target="label_seq"), save_path=None)
         trainer.train()
         print('Train finished!')
@@ -147,13 +140,8 @@ class TestTutorial(unittest.TestCase):
         from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam

-        trainer = Trainer(model=model,
-                          train_data=train_data,
-                          dev_data=dev_data,
-                          loss=CrossEntropyLoss(),
-                          optimizer= Adam(),
-                          metrics=AccuracyMetric(target='target')
-                          )
+        trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(), loss=CrossEntropyLoss(),
+                          dev_data=dev_data, metrics=AccuracyMetric(target='target'))
         trainer.train()
         print('Train finished!')