Browse Source

Merge pull request #51 from xuyige/test_code

Test code
tags/v0.1.0
Coet GitHub 6 years ago
parent
commit
b76d3d0827
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 282 additions and 1 deletions
  1. +2
    -0
      fastNLP/models/char_language_model.py
  2. +1
    -1
      fastNLP/modules/encoder/masked_rnn.py
  3. +18
    -0
      test/core/test_action.py
  4. +43
    -0
      test/core/test_preprocess.py
  5. +33
    -0
      test/core/test_trainer.py
  6. +7
    -0
      test/loader/config
  7. +75
    -0
      test/loader/test_loader.py
  8. +27
    -0
      test/modules/test_masked_rnn.py
  9. +30
    -0
      test/modules/test_other_modules.py
  10. +18
    -0
      test/modules/test_utils.py
  11. +28
    -0
      test/modules/test_variational_rnn.py

+ 2
- 0
fastNLP/models/char_language_model.py View File

@@ -142,6 +142,8 @@ class CharLM(BaseModel):
"char_dict": char_dict,
"reverse_word_dict": reverse_word_dict,
}
if not os.path.exists("cache"):
os.mkdir("cache")
torch.save(objects, "cache/prep.pt")
print("Preprocess done.")



+ 1
- 1
fastNLP/modules/encoder/masked_rnn.py View File

@@ -273,7 +273,7 @@ class MaskedRNNBase(nn.Module):
hx = (hx, hx)

func = AutogradMaskedStep(num_layers=self.num_layers,
dropout=self.dropout,
dropout=self.step_dropout,
train=self.training,
lstm=lstm)



+ 18
- 0
test/core/test_action.py View File

@@ -0,0 +1,18 @@
import os

import unittest

from fastNLP.core.action import Action, Batchifier, SequentialSampler

class TestAction(unittest.TestCase):
    """Smoke test for batching toy samples through ``Action.make_batch``."""

    def test_case_1(self):
        """Build eight toy samples, batch them in pairs and print every batch."""
        features = [1, 2, 3, 4, 5, 6, 7, 8]
        labels = [1, 1, 1, 1, 2, 2, 2, 2]
        # Each sample is a pair of single-element lists: [[feature], [label]].
        samples = [[[feature], [label]] for feature, label in zip(features, labels)]
        batches = Batchifier(SequentialSampler(samples), batch_size=2, drop_last=False)
        action = Action()
        batch_iter = action.make_batch(batches, use_cuda=False, output_length=True, max_len=None)
        for batch_x in batch_iter:
            print(batch_x)


+ 43
- 0
test/core/test_preprocess.py View File

@@ -0,0 +1,43 @@
import os
import unittest

from fastNLP.core.preprocess import SeqLabelPreprocess


class TestSeqLabelPreprocess(unittest.TestCase):
    """Smoke tests for ``SeqLabelPreprocess.run`` on a tiny tagged corpus."""

    @staticmethod
    def _clear_directory(path):
        """Delete every file and sub-directory below *path*, keeping *path* itself."""
        if not os.path.exists(path):
            return
        # Walk bottom-up so each directory is already empty when it is removed.
        for root, dirs, files in os.walk(path, topdown=False):
            for name in files:
                os.remove(os.path.join(root, name))
            for name in dirs:
                os.rmdir(os.path.join(root, name))

    def test_case_1(self):
        """Run preprocessing twice per configuration, starting from an empty directory.

        Each configuration is executed two times in a row after the pickle
        directory has been emptied.
        NOTE(review): presumably the second call exercises the path where
        ``pickle_path`` is already populated -- confirm against
        ``SeqLabelPreprocess``.
        """
        pickle_path = "./save"

        # Ten samples: two [words, tags] pairs alternated five times. Fresh
        # list objects are built on every iteration so no sample is aliased.
        data = []
        for _ in range(5):
            data.append([['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']])
            data.append([['Hello', 'world', '!'], ['a', 'n', '.']])

        # Configuration 1: train/dev split only.
        self._clear_directory(pickle_path)
        for _ in range(2):
            SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4,
                                     pickle_path=pickle_path)

        # Configuration 2: additional test data plus cross validation.
        self._clear_directory(pickle_path)
        for _ in range(2):
            SeqLabelPreprocess().run(test_data=data, train_dev_data=data,
                                     pickle_path=pickle_path, train_dev_split=0.4,
                                     cross_val=True)

+ 33
- 0
test/core/test_trainer.py View File

@@ -0,0 +1,33 @@
import os

import torch.nn as nn
import unittest

from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.core.loss import Loss
from fastNLP.core.optimizer import Optimizer
from fastNLP.models.sequence_modeling import SeqLabeling

class TestTrainer(unittest.TestCase):
    """Smoke test that trains a ``SeqLabeling`` model with ``SeqLabelTrainer``."""

    def test_case_1(self):
        """Train on six toy sequences, reusing the same data for validation."""
        # NOTE(review): "use_cuda" is True here; whether that requires a GPU
        # depends on how the model/trainer consume it -- confirm.
        args = dict(
            epochs=3,
            batch_size=8,
            validate=True,
            use_cuda=True,
            pickle_path="./save/",
            save_best_dev=True,
            model_name="default_model_name.pkl",
            loss=Loss(None),
            optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
            vocab_size=20,
            word_emb_dim=100,
            rnn_hidden_units=100,
            num_classes=3,
        )
        # The trainer is created without arguments; ``args`` only goes to the model.
        trainer = SeqLabelTrainer()

        # Three [word ids, label ids] patterns, repeated twice (six samples).
        patterns = (
            ((1, 2, 3, 4, 5, 6), (1, 0, 1, 0, 1, 2)),
            ((2, 3, 4, 5, 1, 6), (0, 1, 0, 1, 0, 2)),
            ((1, 4, 1, 4, 1, 6), (1, 0, 1, 0, 1, 2)),
        )
        train_data = [
            [list(word_ids), list(label_ids)]
            for _ in range(2)
            for word_ids, label_ids in patterns
        ]
        dev_data = train_data

        model = SeqLabeling(args)
        trainer.train(network=model, train_data=train_data, dev_data=dev_data)

+ 7
- 0
test/loader/config View File

@@ -0,0 +1,7 @@
[test]
x = 1
y = 2
z = 3
input = [1,2,3]
text = "this is text"
doubles = 0.5

+ 75
- 0
test/loader/test_loader.py View File

@@ -0,0 +1,75 @@
import os
import configparser

import json
import unittest


from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, POSDatasetLoader, LMDatasetLoader

class TestConfigLoader(unittest.TestCase):
    """Checks ``ConfigLoader.load_config`` against a direct configparser/json read."""

    def test_case_ConfigLoader(self):
        """Load the ``[test]`` section both ways and require identical contents."""

        def read_section_from_config(config_path, section_name):
            """Return *section_name* of the INI file as a dict of JSON-decoded values.

            Raises:
                FileNotFoundError: if *config_path* does not exist.
                AttributeError: if the section is missing or a value is not
                    valid JSON.
            """
            if not os.path.exists(config_path):
                raise FileNotFoundError("config file {} NOT found.".format(config_path))
            cfg = configparser.ConfigParser()
            cfg.read(config_path)
            if section_name not in cfg:
                raise AttributeError("config file {} do NOT have section {}".format(
                    config_path, section_name
                ))
            section = cfg[section_name]
            values = {}
            for key in section.keys():
                try:
                    values[key] = json.loads(section[key])
                except (ValueError, configparser.Error) as err:
                    # Keep the original cause attached for easier debugging.
                    raise AttributeError("json can NOT load {} in section {}, config file {}".format(
                        key, section_name, config_path
                    )) from err
            return values

        # The path is relative to the current working directory.
        config_path = os.path.join("./test/loader", "config")

        test_arg = ConfigSection()
        ConfigLoader("config", "").load_config(config_path, {"test": test_arg})
        expected = read_section_from_config(config_path, "test")

        # Every key read directly from the file must be present in the loaded section ...
        for key in expected:
            self.assertIn(key, test_arg)
            self.assertEqual(expected[key], test_arg[key])

        # ... and the loaded section must not contain anything extra.
        for key in test_arg.__dict__.keys():
            self.assertIn(key, expected)
            self.assertEqual(expected[key], test_arg[key])

        # Deliberately best-effort: looking up a missing key may raise, and
        # whatever exception ConfigSection uses for that is tolerated here.
        try:
            test_arg["NOT EXIST"]
        except Exception:
            pass

        print("pass config test!")


class TestDatasetLoader(unittest.TestCase):
    """Smoke tests that each dataset loader can read its fixture file.

    Fixture paths are relative to the current working directory.
    """

    def test_case_TokenizeDatasetLoader(self):
        """Load the PKU word-segmentation fixture via ``load_pku``."""
        loader = TokenizeDatasetLoader("cws_pku_utf_8", "./test/data_for_tests/cws_pku_utf_8")
        loader.load_pku(max_seq_len=32)
        print("pass TokenizeDatasetLoader test!")

    def test_case_POSDatasetLoader(self):
        """Load the POS fixture via both ``load`` and ``load_lines``."""
        loader = POSDatasetLoader("people", "./test/data_for_tests/people.txt")
        loader.load()
        loader.load_lines()
        print("pass POSDatasetLoader test!")

    def test_case_LMDatasetLoader(self):
        """Load the language-model fixture via both ``load`` and ``load_lines``."""
        loader = LMDatasetLoader("cws_pku_utf_8", "./test/data_for_tests/cws_pku_utf_8")
        loader.load()
        loader.load_lines()
        # The message previously named TokenizeDatasetLoader (copy-paste slip).
        print("pass LMDatasetLoader test!")

+ 27
- 0
test/modules/test_masked_rnn.py View File

@@ -0,0 +1,27 @@

import torch
import unittest

from fastNLP.modules.encoder.masked_rnn import MaskedRNN

class TestMaskedRnn(unittest.TestCase):
    """Smoke tests for ``MaskedRNN`` forward passes and single steps."""

    def test_case_1(self):
        """Run a bidirectional MaskedRNN without a mask and with two masks."""
        rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
        inputs = torch.tensor([[[1.0], [2.0]]])
        print(inputs.size())
        rnn(inputs)
        # First mask is all ones, second has a zero at the last position.
        for mask_values in ([[[1], [1]]], [[[1], [0]]]):
            rnn(inputs, mask=torch.tensor(mask_values))

    def test_case_2(self):
        """Run a unidirectional MaskedRNN forward, then call ``step`` twice."""
        rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=False, batch_first=True)
        inputs = torch.tensor([[[1.0], [2.0]]])
        print(inputs.size())
        rnn(inputs)
        mask = torch.tensor([[[1], [1]]])
        rnn(inputs, mask=mask)
        # A single-element input for step(); the mask above is reused as-is.
        step_input = torch.tensor([[[1.0]]])
        rnn.step(step_input)
        rnn.step(step_input, mask=mask)

+ 30
- 0
test/modules/test_other_modules.py View File

@@ -0,0 +1,30 @@


import torch
import unittest

from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear


class TestGroupNorm(unittest.TestCase):
    """Smoke test: a ``GroupNorm`` forward pass on a random 3-D tensor."""

    def test_case_1(self):
        """Construct the layer and apply it to a tensor of size (20, 50, 10)."""
        layer = GroupNorm(num_features=1, num_groups=10, eps=1.5e-5)
        sample = torch.randn(20, 50, 10)
        layer(sample)


class TestLayerNormalization(unittest.TestCase):
    """Smoke test: a ``LayerNormalization`` forward pass on a random 3-D tensor."""

    def test_case_1(self):
        """Construct the layer with d_hid=5 and apply it to a (20, 50, 5) tensor."""
        layer = LayerNormalization(d_hid=5, eps=2e-3)
        sample = torch.randn(20, 50, 5)
        layer(sample)


class TestBiLinear(unittest.TestCase):
    """Smoke test for constructing, applying and printing ``BiLinear``."""

    def test_case_1(self):
        """Apply a BiLinear layer to two random 4-D tensors, then build a second layer."""
        bilinear = BiLinear(n_left=5, n_right=5, n_out=10, bias=True)
        left_input = torch.randn(7, 10, 20, 5)
        right_input = torch.randn(7, 10, 20, 5)
        bilinear(left_input, right_input)
        print(bilinear)
        # The second instance is only constructed, never applied.
        BiLinear(n_left=15, n_right=15, n_out=10, bias=True)

+ 18
- 0
test/modules/test_utils.py View File

@@ -0,0 +1,18 @@

import torch
import numpy as np
import unittest

import fastNLP.modules.utils as utils

class TestUtils(unittest.TestCase):
    """Smoke tests for helpers in ``fastNLP.modules.utils``."""

    def test_case_1(self):
        """Call ``utils.orthogonal`` on a 2x5 integer tensor."""
        rows = [
            [1, 2, 3, 4, 5],
            [2, 3, 4, 5, 6],
        ]
        utils.orthogonal(torch.tensor(rows))

    def test_case_2(self):
        """Call ``utils.mst`` on a random 100x100 NumPy array."""
        side = 100
        scores = np.random.rand(side, side)
        utils.mst(scores)


+ 28
- 0
test/modules/test_variational_rnn.py View File

@@ -0,0 +1,28 @@

import torch
import unittest

from fastNLP.modules.encoder.variational_rnn import VarMaskedFastLSTM

class TestMaskedRnn(unittest.TestCase):
    """Smoke tests for ``VarMaskedFastLSTM`` forward passes."""

    def test_case_1(self):
        """Run a bidirectional VarMaskedFastLSTM without a mask and with two masks."""
        lstm = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
        inputs = torch.tensor([[[1.0], [2.0]]])
        print(inputs.size())
        lstm(inputs)
        # First mask is all ones, second has a zero at the last position.
        for mask_values in ([[[1], [1]]], [[[1], [0]]]):
            lstm(inputs, mask=torch.tensor(mask_values))

    def test_case_2(self):
        """Run a unidirectional VarMaskedFastLSTM without and with a mask."""
        lstm = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=False, batch_first=True)
        inputs = torch.tensor([[[1.0], [2.0]]])
        print(inputs.size())
        lstm(inputs)
        mask = torch.tensor([[[1], [1]]])
        lstm(inputs, mask=mask)
        # Single-element input kept ready for the step() calls that are
        # currently disabled.
        step_input = torch.tensor([[[1.0]]])
        # TODO: step() still has a bug; re-enable once it is fixed:
        #   lstm.step(step_input)
        #   lstm.step(step_input, mask=mask)

Loading…
Cancel
Save