Browse Source

* add tqdm in requirements.txt

* fix FieldArray type check bugs
tags/v0.2.0^2
FengZiYjun 5 years ago
parent
commit
4b099bb0dd
3 changed files with 21 additions and 8 deletions
  1. +2
    -2
      fastNLP/core/fieldarray.py
  2. +1
    -0
      requirements.txt
  3. +18
    -6
      test/core/test_trainer.py

+ 2
- 2
fastNLP/core/fieldarray.py View File

@@ -83,12 +83,12 @@ class FieldArray(object):
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.")
if len(val) == 0:
raise RuntimeError("Cannot append an empty list.")
val_list_type = [type(_) for _ in val] # type check
val_list_type = set([type(_) for _ in val]) # type check
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type:
# up-cast int to float
val_type = float
elif len(val_list_type) == 1:
val_type = val_list_type[0]
val_type = val_list_type.pop()
else:
raise RuntimeError("Cannot append a list of {}".format(val_list_type))
else:


+ 1
- 0
requirements.txt View File

@@ -1,3 +1,4 @@
numpy>=1.14.2
torch>=0.4.0
tensorboardX
tqdm

+ 18
- 6
test/core/test_trainer.py View File

@@ -1,8 +1,8 @@
import unittest

import numpy as np
from torch import nn
import torch.nn.functional as F
from torch import nn

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
@@ -26,6 +26,7 @@ def prepare_fake_dataset():
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
return data_set


def prepare_fake_dataset2(*args, size=100):
ys = np.random.randint(4, size=100)
data = {'y': ys}
@@ -33,6 +34,7 @@ def prepare_fake_dataset2(*args, size=100):
data[arg] = np.random.randn(size, 5)
return DataSet(data=data)


class TrainerTestGround(unittest.TestCase):
def test_case(self):
data_set = prepare_fake_dataset()
@@ -55,15 +57,20 @@ class TrainerTestGround(unittest.TestCase):
check_code_level=2,
use_tqdm=True)
trainer.train()
"""
# 应该正确运行
"""

def test_trainer_suggestion1(self):
# 检查报错提示能否正确提醒用户。
# 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。
dataset = prepare_fake_dataset2('x')

class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 4)

def forward(self, x1, x2, y):
x1 = self.fc(x1)
x2 = self.fc(x2)
@@ -72,10 +79,12 @@ class TrainerTestGround(unittest.TestCase):
return {'loss': loss}

model = Model()
trainer = Trainer(
train_data=dataset,
model=model
)

with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model
)
"""
# 应该获取到的报错提示
NameError:
@@ -91,10 +100,12 @@ class TrainerTestGround(unittest.TestCase):
# 这里传入forward需要的数据,看是否可以运行
dataset = prepare_fake_dataset2('x1', 'x2')
dataset.set_input('x1', 'x2', 'y', flag=True)

class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 4)

def forward(self, x1, x2, y):
x1 = self.fc(x1)
x2 = self.fc(x2)
@@ -119,10 +130,12 @@ class TrainerTestGround(unittest.TestCase):
# 这里传入forward需要的数据,但是forward没有返回loss这个key
dataset = prepare_fake_dataset2('x1', 'x2')
dataset.set_input('x1', 'x2', 'y', flag=True)

class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 4)

def forward(self, x1, x2, y):
x1 = self.fc(x1)
x2 = self.fc(x2)
@@ -142,7 +155,6 @@ class TrainerTestGround(unittest.TestCase):
# 应该正确运行
"""


def test_case2(self):
# check metrics Wrong
data_set = prepare_fake_dataset2('x1', 'x2')

Loading…
Cancel
Save