Browse Source

1. trainer中losser修改为loss (rename the `losser` parameter to `loss` in Trainer)

tags/v0.2.0^2
yh 5 years ago
parent
commit
aea931812b
4 changed files with 19 additions and 19 deletions
  1. +3
    -3
      fastNLP/core/trainer.py
  2. +0
    -1
      fastNLP/core/utils.py
  3. +6
    -6
      test/core/test_tester.py
  4. +10
    -9
      test/core/test_trainer.py

+ 3
- 3
fastNLP/core/trainer.py View File

@@ -28,7 +28,7 @@ class Trainer(object):
"""Main Training Loop

"""
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
validate_every=-1, dev_data=None, use_cuda=False, save_path=None,
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
metric_key=None, sampler=RandomSampler(), use_tqdm=True):
@@ -36,7 +36,7 @@ class Trainer(object):

:param DataSet train_data: the training data
:param torch.nn.modules.module model: a PyTorch model
:param LossBase losser: a loss object
:param LossBase loss: a loss object
:param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
:param int n_epochs: the number of training epochs
:param int batch_size: batch size for training and validation
@@ -88,7 +88,7 @@ class Trainer(object):
self.metric_key = None

# prepare loss
losser = _prepare_losser(losser)
losser = _prepare_losser(loss)

# sampler check
if not isinstance(sampler, BaseSampler):


+ 0
- 1
fastNLP/core/utils.py View File

@@ -7,7 +7,6 @@ from collections import namedtuple

import numpy as np
import torch
from tqdm import tqdm

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs'], verbose=False)


+ 6
- 6
test/core/test_tester.py View File

@@ -42,7 +42,6 @@ def prepare_fake_dataset2(*args, size=100):
class TestTester(unittest.TestCase):
def test_case_1(self):
# 检查报错提示能否正确提醒用户
# 这里传入多余参数,让其duplicate
dataset = prepare_fake_dataset2('x1', 'x_unused')
dataset.rename_field('x_unused', 'x2')
dataset.set_input('x1', 'x2')
@@ -60,8 +59,9 @@ class TestTester(unittest.TestCase):
return {'preds': x}

model = Model()
tester = Tester(
data=dataset,
model=model,
metrics=AccuracyMetric())
tester.test()
with self.assertRaises(NameError):
tester = Tester(
data=dataset,
model=model,
metrics=AccuracyMetric())
tester.test()

+ 10
- 9
test/core/test_trainer.py View File

@@ -48,7 +48,7 @@ class TrainerTestGround(unittest.TestCase):
model = NaiveClassifier(2, 1)

trainer = Trainer(train_set, model,
losser=BCELoss(pred="predict", target="y"),
loss=BCELoss(pred="predict", target="y"),
metrics=AccuracyMetric(pred="predict", target="y"),
n_epochs=10,
batch_size=32,
@@ -227,14 +227,15 @@ class TrainerTestGround(unittest.TestCase):
return {'preds': x}

model = Model()
trainer = Trainer(
train_data=dataset,
model=model,
dev_data=dataset,
losser=CrossEntropyLoss(),
metrics=AccuracyMetric(),
use_tqdm=False,
print_every=2)
with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model,
dev_data=dataset,
loss=CrossEntropyLoss(),
metrics=AccuracyMetric(),
use_tqdm=False,
print_every=2)

def test_case2(self):
# check metrics Wrong


Loading…
Cancel
Save