|
|
@@ -1,8 +1,8 @@ |
|
|
|
import unittest |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
from torch import nn |
|
|
|
import torch.nn.functional as F |
|
|
|
from torch import nn |
|
|
|
|
|
|
|
from fastNLP.core.dataset import DataSet |
|
|
|
from fastNLP.core.instance import Instance |
|
|
@@ -26,6 +26,7 @@ def prepare_fake_dataset(): |
|
|
|
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) |
|
|
|
return data_set |
|
|
|
|
|
|
|
|
|
|
|
def prepare_fake_dataset2(*args, size=100): |
|
|
|
ys = np.random.randint(4, size=100) |
|
|
|
data = {'y': ys} |
|
|
@@ -33,6 +34,7 @@ def prepare_fake_dataset2(*args, size=100): |
|
|
|
data[arg] = np.random.randn(size, 5) |
|
|
|
return DataSet(data=data) |
|
|
|
|
|
|
|
|
|
|
|
class TrainerTestGround(unittest.TestCase): |
|
|
|
def test_case(self): |
|
|
|
data_set = prepare_fake_dataset() |
|
|
@@ -55,15 +57,20 @@ class TrainerTestGround(unittest.TestCase): |
|
|
|
check_code_level=2, |
|
|
|
use_tqdm=True) |
|
|
|
trainer.train() |
|
|
|
""" |
|
|
|
# 应该正确运行 |
|
|
|
""" |
|
|
|
|
|
|
|
def test_trainer_suggestion1(self): |
|
|
|
# 检查报错提示能否正确提醒用户。 |
|
|
|
# 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。 |
|
|
|
dataset = prepare_fake_dataset2('x') |
|
|
|
|
|
|
|
class Model(nn.Module): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.fc = nn.Linear(5, 4) |
|
|
|
|
|
|
|
def forward(self, x1, x2, y): |
|
|
|
x1 = self.fc(x1) |
|
|
|
x2 = self.fc(x2) |
|
|
@@ -72,10 +79,12 @@ class TrainerTestGround(unittest.TestCase): |
|
|
|
return {'loss': loss} |
|
|
|
|
|
|
|
model = Model() |
|
|
|
trainer = Trainer( |
|
|
|
train_data=dataset, |
|
|
|
model=model |
|
|
|
) |
|
|
|
|
|
|
|
with self.assertRaises(NameError): |
|
|
|
trainer = Trainer( |
|
|
|
train_data=dataset, |
|
|
|
model=model |
|
|
|
) |
|
|
|
""" |
|
|
|
# 应该获取到的报错提示 |
|
|
|
NameError: |
|
|
@@ -91,10 +100,12 @@ class TrainerTestGround(unittest.TestCase): |
|
|
|
# 这里传入forward需要的数据,看是否可以运行 |
|
|
|
dataset = prepare_fake_dataset2('x1', 'x2') |
|
|
|
dataset.set_input('x1', 'x2', 'y', flag=True) |
|
|
|
|
|
|
|
class Model(nn.Module): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.fc = nn.Linear(5, 4) |
|
|
|
|
|
|
|
def forward(self, x1, x2, y): |
|
|
|
x1 = self.fc(x1) |
|
|
|
x2 = self.fc(x2) |
|
|
@@ -119,10 +130,12 @@ class TrainerTestGround(unittest.TestCase): |
|
|
|
# 这里传入forward需要的数据,但是forward没有返回loss这个key |
|
|
|
dataset = prepare_fake_dataset2('x1', 'x2') |
|
|
|
dataset.set_input('x1', 'x2', 'y', flag=True) |
|
|
|
|
|
|
|
class Model(nn.Module): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.fc = nn.Linear(5, 4) |
|
|
|
|
|
|
|
def forward(self, x1, x2, y): |
|
|
|
x1 = self.fc(x1) |
|
|
|
x2 = self.fc(x2) |
|
|
@@ -142,7 +155,6 @@ class TrainerTestGround(unittest.TestCase): |
|
|
|
# 应该正确运行 |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def test_case2(self): |
|
|
|
# check metrics Wrong |
|
|
|
data_set = prepare_fake_dataset2('x1', 'x2') |