diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index a1ece0aa..0a94b26c 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -83,12 +83,12 @@ class FieldArray(object): raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") if len(val) == 0: raise RuntimeError("Cannot append an empty list.") - val_list_type = [type(_) for _ in val] # type check + val_list_type = set([type(_) for _ in val]) # type check if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: # up-cast int to float val_type = float elif len(val_list_type) == 1: - val_type = val_list_type[0] + val_type = val_list_type.pop() else: raise RuntimeError("Cannot append a list of {}".format(val_list_type)) else: diff --git a/requirements.txt b/requirements.txt index 91a3f040..60ab7849 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ numpy>=1.14.2 torch>=0.4.0 tensorboardX +tqdm \ No newline at end of file diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py index ed4cc38d..2b14aa11 100644 --- a/test/core/test_trainer.py +++ b/test/core/test_trainer.py @@ -1,8 +1,8 @@ import unittest import numpy as np -from torch import nn import torch.nn.functional as F +from torch import nn from fastNLP.core.dataset import DataSet from fastNLP.core.instance import Instance @@ -26,6 +26,7 @@ def prepare_fake_dataset(): [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) return data_set + def prepare_fake_dataset2(*args, size=100): ys = np.random.randint(4, size=100) data = {'y': ys} @@ -33,6 +34,7 @@ def prepare_fake_dataset2(*args, size=100): data[arg] = np.random.randn(size, 5) return DataSet(data=data) + class TrainerTestGround(unittest.TestCase): def test_case(self): data_set = prepare_fake_dataset() @@ -55,15 +57,20 @@ class TrainerTestGround(unittest.TestCase): check_code_level=2, use_tqdm=True) trainer.train() + """ + # 应该正确运行 + """ def test_trainer_suggestion1(self): # 检查报错提示能否正确提醒用户。 # 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。 dataset = prepare_fake_dataset2('x') + class Model(nn.Module): def __init__(self): super().__init__() self.fc = nn.Linear(5, 4) + def forward(self, x1, x2, y): x1 = self.fc(x1) x2 = self.fc(x2) @@ -72,10 +79,12 @@ class TrainerTestGround(unittest.TestCase): return {'loss': loss} model = Model() - trainer = Trainer( - train_data=dataset, - model=model - ) + + with self.assertRaises(NameError): + trainer = Trainer( + train_data=dataset, + model=model + ) """ # 应该获取到的报错提示 NameError: @@ -91,10 +100,12 @@ class TrainerTestGround(unittest.TestCase): # 这里传入forward需要的数据,看是否可以运行 dataset = prepare_fake_dataset2('x1', 'x2') dataset.set_input('x1', 'x2', 'y', flag=True) + class Model(nn.Module): def __init__(self): super().__init__() self.fc = nn.Linear(5, 4) + def forward(self, x1, x2, y): x1 = self.fc(x1) x2 = self.fc(x2) @@ -119,10 +130,12 @@ class TrainerTestGround(unittest.TestCase): # 这里传入forward需要的数据,但是forward没有返回loss这个key dataset = prepare_fake_dataset2('x1', 'x2') dataset.set_input('x1', 'x2', 'y', flag=True) + class Model(nn.Module): def __init__(self): super().__init__() self.fc = nn.Linear(5, 4) + def forward(self, x1, x2, y): x1 = self.fc(x1) x2 = self.fc(x2) @@ -142,7 +155,6 @@ class TrainerTestGround(unittest.TestCase): # 应该正确运行 """ - def test_case2(self): # check metrics Wrong data_set = prepare_fake_dataset2('x1', 'x2')