From aea931812b75aa56106996906f647a1ac341aa30 Mon Sep 17 00:00:00 2001
From: yh
Date: Wed, 5 Dec 2018 20:23:40 +0800
Subject: [PATCH] =?UTF-8?q?1.=20trainer=E4=B8=ADlosser=E4=BF=AE=E6=94=B9?=
 =?UTF-8?q?=E4=B8=BAloss?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/trainer.py   |  6 +++---
 fastNLP/core/utils.py     |  1 -
 test/core/test_tester.py  | 12 ++++++------
 test/core/test_trainer.py | 19 ++++++++++---------
 4 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 8f676279..45055be5 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -28,7 +28,7 @@ class Trainer(object):
     """Main Training Loop

    """
-    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
+    def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, use_cuda=False, save_path=None,
                  optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
                  metric_key=None, sampler=RandomSampler(), use_tqdm=True):
@@ -36,7 +36,7 @@ class Trainer(object):

         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model
-        :param LossBase losser: a loss object
+        :param LossBase loss: a loss object
         :param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
         :param int n_epochs: the number of training epochs
         :param int batch_size: batch size for training and validation
@@ -88,7 +88,7 @@ class Trainer(object):
             self.metric_key = None

         # prepare loss
-        losser = _prepare_losser(losser)
+        losser = _prepare_losser(loss)

         # sampler check
         if not isinstance(sampler, BaseSampler):
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index 0e2bba07..508d5587 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -7,7 +7,6 @@ from collections import namedtuple

 import numpy as np
 import torch
-from tqdm import tqdm

 CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                    'varargs'], verbose=False)
diff --git a/test/core/test_tester.py b/test/core/test_tester.py
index 99a8000e..d606c0b8 100644
--- a/test/core/test_tester.py
+++ b/test/core/test_tester.py
@@ -42,7 +42,6 @@ def prepare_fake_dataset2(*args, size=100):
 class TestTester(unittest.TestCase):
     def test_case_1(self):
         # 检查报错提示能否正确提醒用户
-        # 这里传入多余参数,让其duplicate
         dataset = prepare_fake_dataset2('x1', 'x_unused')
         dataset.rename_field('x_unused', 'x2')
         dataset.set_input('x1', 'x2')
@@ -60,8 +59,9 @@
                 return {'preds': x}

         model = Model()
-        tester = Tester(
-            data=dataset,
-            model=model,
-            metrics=AccuracyMetric())
-        tester.test()
+        with self.assertRaises(NameError):
+            tester = Tester(
+                data=dataset,
+                model=model,
+                metrics=AccuracyMetric())
+            tester.test()
diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py
index a69438ae..6f6fbbf3 100644
--- a/test/core/test_trainer.py
+++ b/test/core/test_trainer.py
@@ -48,7 +48,7 @@ class TrainerTestGround(unittest.TestCase):
         model = NaiveClassifier(2, 1)

         trainer = Trainer(train_set, model,
-                          losser=BCELoss(pred="predict", target="y"),
+                          loss=BCELoss(pred="predict", target="y"),
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           n_epochs=10,
                           batch_size=32,
@@ -227,14 +227,15 @@
                 return {'preds': x}

         model = Model()
-        trainer = Trainer(
-            train_data=dataset,
-            model=model,
-            dev_data=dataset,
-            losser=CrossEntropyLoss(),
-            metrics=AccuracyMetric(),
-            use_tqdm=False,
-            print_every=2)
+        with self.assertRaises(NameError):
+            trainer = Trainer(
+                train_data=dataset,
+                model=model,
+                dev_data=dataset,
+                loss=CrossEntropyLoss(),
+                metrics=AccuracyMetric(),
+                use_tqdm=False,
+                print_every=2)

     def test_case2(self):
         # check metrics Wrong