- DataSet's __init__ takes a function as its argument, rather than a class object
- Preprocessor is about to be removed. Don't use it anymore.
- Remove cross_validate from the trainer, because it is rarely used and weird
- Loader.load is expected to be a static method
- Delete the Transpose and WordDropout classes in other_modules.py
- Add more tests
- Delete extra sample data
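To make the first change concrete, here is a minimal before/after sketch of the new constructor, based on the SeqLabelDataSet diff below; the import paths and the data path are assumptions for illustration only.

```python
from fastNLP.core.dataset import SeqLabelDataSet
from fastNLP.loader.dataset_loader import POSDataSetLoader

# Before this change: pass a loader *object*; the DataSet called loader.load(path) itself.
#   data_set = SeqLabelDataSet(loader=POSDataSetLoader())
# After this change: pass the load *function*; any callable that takes a path
# and returns multi-level lists works.
data_set = SeqLabelDataSet(load_func=POSDataSetLoader().load)
data_set.load("path/to/train.txt")  # hypothetical path
```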
@@ -70,18 +70,18 @@ class DataSet(list):
"""
def __init__(self, name="", instances=None, loader=None):
def __init__(self, name="", instances=None, load_func=None):
"""
:param name: str, the name of the dataset. (default: "")
:param instances: list of Instance objects. (default: None)
:param load_func: a function that takes the dataset path (string) as input and returns multi-level lists.
"""
list.__init__([])
self.name = name
if instances is not None:
self.extend(instances)
self.dataset_loader = loader
self.data_set_load_func = load_func
def index_all(self, vocab):
for ins in self:
@@ -117,15 +117,15 @@ class DataSet(list):
return lengths
def convert(self, data):
"""Convert lists of strings into Instances with Fields"""
"""Convert lists of strings into Instances with Fields, creating Vocabulary for labeled data. Used in Training."""
raise NotImplementedError
def convert_with_vocabs(self, data, vocabs):
"""Convert lists of strings into Instances with Fields, using existing Vocabulary. Useful in predicting."""
"""Convert lists of strings into Instances with Fields, using existing Vocabulary, with labels. Used in Testing."""
raise NotImplementedError
def convert_for_infer(self, data, vocabs):
"""Convert lists of strings into Instances with Fields."""
"""Convert lists of strings into Instances with Fields, using existing Vocabulary, without labels. Used in predicting."""
def load(self, data_path, vocabs=None, infer=False):
"""Load data from the given files.
@@ -135,7 +135,7 @@ class DataSet(list):
:param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed.
"""
raw_data = self.dataset_loader.load(data_path)
raw_data = self.data_set_load_func(data_path)
if infer is True:
self.convert_for_infer(raw_data, vocabs)
else:
@@ -145,7 +145,7 @@ class DataSet(list):
self.convert(raw_data)
def load_raw(self, raw_data, vocabs):
"""
"""Load raw data without loader. Used in FastNLP class.
:param raw_data:
:param vocabs:
@@ -174,8 +174,8 @@ class DataSet(list):
class SeqLabelDataSet(DataSet):
def __init__(self, instances=None, loader=POSDataSetLoader()):
super(SeqLabelDataSet, self).__init__(name="", instances=instances, loader=loader)
def __init__(self, instances=None, load_func=POSDataSetLoader().load):
super(SeqLabelDataSet, self).__init__(name="", instances=instances, load_func=load_func)
self.word_vocab = Vocabulary()
self.label_vocab = Vocabulary()
@@ -231,8 +231,8 @@ class SeqLabelDataSet(DataSet):
class TextClassifyDataSet(DataSet):
def __init__(self, instances=None, loader=ClassDataSetLoader()):
super(TextClassifyDataSet, self).__init__(name="", instances=instances, loader=loader)
def __init__(self, instances=None, load_func=ClassDataSetLoader().load):
super(TextClassifyDataSet, self).__init__(name="", instances=instances, load_func=load_func)
self.word_vocab = Vocabulary()
self.label_vocab = Vocabulary(need_default=False)
@@ -285,10 +285,3 @@ def change_field_is_target(data_set, field_name, new_target):
for inst in data_set:
inst.fields[field_name].is_target = new_target
if __name__ == "__main__":
data_set = SeqLabelDataSet()
data_set.load("../../test/data_for_tests/people.txt")
a, b = data_set.split(0.3)
print(type(data_set), type(a), type(b))
print(len(data_set), len(a), len(b))
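The load() entry point now distinguishes three conversion paths (training, testing, inference), matching the reworded docstrings above. A rough usage sketch follows; it assumes word_vocab and label_vocab are vocabularies kept from training and that the file paths exist.

```python
from fastNLP.core.dataset import SeqLabelDataSet
from fastNLP.loader.dataset_loader import POSDataSetLoader

# Training: no vocabs supplied, so new Vocabulary objects are built from the data.
train_set = SeqLabelDataSet(load_func=POSDataSetLoader().load)
train_set.load("path/to/train.txt")

# Testing: reuse existing vocabularies; labels are still indexed.
test_set = SeqLabelDataSet(load_func=POSDataSetLoader().load)
test_set.load("path/to/test.txt",
              vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab})

# Inference: reuse the word vocabulary only; no labels are expected.
infer_set = SeqLabelDataSet(load_func=POSDataSetLoader().load)
infer_set.load("path/to/infer.txt", vocabs={"word_vocab": word_vocab}, infer=True)
```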
@@ -78,6 +78,7 @@ class Preprocessor(object):
is only available when label_is_seq is True. Default: False.
:param add_char_field: bool, whether to add character representations to all TextFields. Default: False.
"""
print("Preprocessor is about to be deprecated. Please use the DataSet class.")
self.data_vocab = Vocabulary()
if label_is_seq is True:
if share_vocab is True:
@@ -307,11 +308,3 @@ class ClassPreprocess(Preprocessor):
print("[FastNLP warning] ClassPreprocess is about to deprecate. Please use Preprocess directly.")
super(ClassPreprocess, self).__init__()
if __name__ == "__main__":
p = Preprocessor()
train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
[["You", "are", "pretty", "."], "1"]
]
training_set = p.run(train_dev_data)
print(training_set)
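With Preprocessor on its way out, the deleted __main__ example above translates to the DataSet API roughly as follows; the pattern mirrors the new tests added at the end of this change, and the throwaway loader plus the unused path argument are purely illustrative.

```python
from fastNLP.core.dataset import TextClassifyDataSet

def toy_loader(path):
    # the same toy data the removed Preprocessor example used
    return [[["I", "am", "a", "good", "student", "."], "0"],
            [["You", "are", "pretty", "."], "1"]]

data_set = TextClassifyDataSet(load_func=toy_loader)
data_set.load("unused_path")  # builds word/label vocabularies from the data
print(len(data_set), len(data_set.word_vocab), len(data_set.label_vocab))
```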
@@ -1,4 +1,3 @@
import copy
import os
import time
from datetime import timedelta
@@ -178,31 +177,6 @@ class Trainer(object):
logger.info(print_output)
step += 1
def cross_validate(self, network, train_data_cv, dev_data_cv):
"""Training with cross validation.
:param network: the model
:param train_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
:param dev_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
"""
if len(train_data_cv) != len(dev_data_cv):
logger.error("the number of folds in train and dev data unequals {}!={}".format(len(train_data_cv),
len(dev_data_cv)))
raise RuntimeError("the number of folds in train and dev data unequals")
if self.validate is False:
logger.warn("Cross validation requires self.validate to be True. Please turn it on. ")
print("[warning] Cross validation requires self.validate to be True. Please turn it on. ")
self.validate = True
n_fold = len(train_data_cv)
logger.info("perform {} folds cross validation.".format(n_fold))
for i in range(n_fold):
print("CV:", i)
logger.info("running the {} of {} folds cross validation".format(i + 1, n_fold))
network_copy = copy.deepcopy(network)
self.train(network_copy, train_data_cv[i], dev_data_cv[i])
def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.
@@ -1,11 +1,10 @@
import os
from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
from fastNLP.core.preprocess import load_pickle
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
"""
mapping from model name to [URL, file_name.class_name, model_pickle_name]
@@ -73,7 +72,7 @@ class FastNLP(object):
:param model_dir: this directory should contain the following files:
1. a trained model
2. a config file, which is a fastNLP's configuration.
3. a Vocab file, which is a pickle object of a Vocab instance.
3. two Vocab files, which are pickle objects of Vocab instances, representing feature and label vocabs.
"""
self.model_dir = model_dir
self.model = None
@@ -192,7 +191,7 @@ class FastNLP(object):
def _load(self, model_dir, model_name):
# To do
return 0
def _download(self, model_name, url):
@@ -202,7 +201,7 @@ class FastNLP(object):
:param url:
"""
print("Downloading {} from {}".format(model_name, url))
# To do
# TODO: download model via url
def model_exist(self, model_dir):
"""
@@ -3,12 +3,14 @@ class BaseLoader(object):
def __init__(self):
super(BaseLoader, self).__init__()
def load_lines(self, data_path):
@staticmethod
def load_lines(data_path):
with open(data_path, "r", encoding="utf-8") as f:
text = f.readlines()
return [line.strip() for line in text]
def load(self, data_path):
@staticmethod
def load(data_path):
with open(data_path, "r", encoding="utf-8") as f:
text = f.readlines()
return [[word for word in sent.strip()] for sent in text]
@@ -84,7 +84,8 @@ class TokenizeDataSetLoader(DataSetLoader):
def __init__(self):
super(TokenizeDataSetLoader, self).__init__()
def load(self, data_path, max_seq_len=32):
@staticmethod
def load(data_path, max_seq_len=32):
"""
load pku dataset for Chinese word segmentation
CWS (Chinese Word Segmentation) pku training dataset format:
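Since load is now a static method, callers no longer need a loader instance: the function can be invoked directly on the class, or handed to a DataSet as load_func, which is exactly what the updated example scripts further down do. A small sketch, with an assumed data path:

```python
from fastNLP.core.dataset import SeqLabelDataSet
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader

# Call the loader without instantiating it ...
raw_data = TokenizeDataSetLoader.load("path/to/pku_training.txt", max_seq_len=32)

# ... or pass the unbound function straight into a DataSet.
data_train = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load)
data_train.load("path/to/pku_training.txt")
```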
@@ -196,30 +196,3 @@ class BiAffine(nn.Module):
output = output * mask_d.unsqueeze(1).unsqueeze(3) * mask_e.unsqueeze(1).unsqueeze(2)
return output
class Transpose(nn.Module):
def __init__(self, x, y):
super(Transpose, self).__init__()
self.x = x
self.y = y
def forward(self, x):
return x.transpose(self.x, self.y)
class WordDropout(nn.Module):
def __init__(self, dropout_rate, drop_to_token):
super(WordDropout, self).__init__()
self.dropout_rate = dropout_rate
self.drop_to_token = drop_to_token
def forward(self, word_idx):
if not self.training:
return word_idx
drop_mask = torch.rand(word_idx.shape) < self.dropout_rate
if word_idx.device.type == 'cuda':
drop_mask = drop_mask.cuda()
drop_mask = drop_mask.long()
output = drop_mask * self.drop_to_token + (1 - drop_mask) * word_idx
return output
@@ -104,7 +104,8 @@ class ConfigSaver(object):
:return:
"""
section_file = self._get_section(section_name)
if len(section_file.__dict__.keys()) == 0:#the section not in file before
if len(section_file.__dict__.keys()) == 0:  # the section not in the file before
# append this section to config file
with open(self.file_path, 'a') as f:
f.write('[' + section_name + ']\n')
for k in section.__dict__.keys():
@@ -114,9 +115,11 @@ class ConfigSaver(object):
else:
f.write(str(section[k]) + '\n\n')
else:
# the section exists
change_file = False
for k in section.__dict__.keys():
if k not in section_file:
# find a new key in this section
change_file = True
break
if section_file[k] != section[k]:
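For reference, the two branches touched here (appending a brand-new section versus updating one that already exists) are driven by save_config_file; the new test cases at the end of this change exercise both, and in isolation the call looks roughly like this (file name and keys are made up):

```python
from fastNLP.loader.config_loader import ConfigSection
from fastNLP.saver.config_saver import ConfigSaver

saver = ConfigSaver("./example.cfg")   # the config file must already exist
section = ConfigSection()
section["doubles"] = 0.8
section["tt"] = [1, 2, 3]
# Appends [section_A] if it is missing; otherwise rewrites the file only when
# a key was added or its value changed.
saver.save_config_file("section_A", section)
```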
@@ -0,0 +1,243 @@
import unittest
from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
from fastNLP.core.dataset import create_dataset_from_lists
class TestDataSet(unittest.TestCase):
labeled_data_list = [
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
]
unlabeled_data_list = [
["a", "b", "e", "d"],
["a", "b", "e", "d"],
["a", "b", "e", "d"]
]
word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}
def test_case_1(self):
data_set = create_dataset_from_lists(self.labeled_data_list, self.word_vocab, has_target=True,
label_vocab=self.label_vocab)
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertEqual(data_set[0].fields["word_seq"]._index,
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
self.assertTrue("label_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["label_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["label_seq"], "_index"))
self.assertEqual(data_set[0].fields["label_seq"].text, self.labeled_data_list[0][1])
self.assertEqual(data_set[0].fields["label_seq"]._index,
[self.label_vocab[c] for c in self.labeled_data_list[0][1]])
def test_case_2(self):
data_set = create_dataset_from_lists(self.unlabeled_data_list, self.word_vocab, has_target=False)
self.assertEqual(len(data_set), len(self.unlabeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.unlabeled_data_list[0])
self.assertEqual(data_set[0].fields["word_seq"]._index,
[self.word_vocab[c] for c in self.unlabeled_data_list[0]])
class TestDataSetConvertion(unittest.TestCase):
labeled_data_list = [
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
]
unlabeled_data_list = [
["a", "b", "e", "d"],
["a", "b", "e", "d"],
["a", "b", "e", "d"]
]
word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}
def test_case_1(self):
def loader(path):
labeled_data_list = [
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
]
return labeled_data_list
data_set = SeqLabelDataSet(load_func=loader)
data_set.load("any_path")
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertTrue("truth" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["truth"], "text"))
self.assertTrue(hasattr(data_set[0].fields["truth"], "_index"))
self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1])
self.assertTrue("word_seq_origin_len" in data_set[0].fields)
def test_case_2(self):
def loader(path):
unlabeled_data_list = [
["a", "b", "e", "d"],
["a", "b", "e", "d"],
["a", "b", "e", "d"]
]
return unlabeled_data_list
data_set = SeqLabelDataSet(load_func=loader)
data_set.load("any_path", vocabs={"word_vocab": self.word_vocab}, infer=True)
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertEqual(data_set[0].fields["word_seq"]._index,
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
self.assertTrue("word_seq_origin_len" in data_set[0].fields)
def test_case_3(self):
def loader(path):
labeled_data_list = [
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
]
return labeled_data_list
data_set = SeqLabelDataSet(load_func=loader)
data_set.load("any_path", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab})
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertEqual(data_set[0].fields["word_seq"]._index,
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
self.assertTrue("truth" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["truth"], "text"))
self.assertTrue(hasattr(data_set[0].fields["truth"], "_index"))
self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1])
self.assertEqual(data_set[0].fields["truth"]._index,
[self.label_vocab[c] for c in self.labeled_data_list[0][1]])
self.assertTrue("word_seq_origin_len" in data_set[0].fields)
class TestDataSetConvertionHHH(unittest.TestCase):
labeled_data_list = [
[["a", "b", "e", "d"], "A"],
[["a", "b", "e", "d"], "C"],
[["a", "b", "e", "d"], "B"],
]
unlabeled_data_list = [
["a", "b", "e", "d"],
["a", "b", "e", "d"],
["a", "b", "e", "d"]
]
word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
label_vocab = {"A": 1, "B": 2, "C": 3}
def test_case_1(self):
def loader(path):
labeled_data_list = [
[["a", "b", "e", "d"], "A"],
[["a", "b", "e", "d"], "C"],
[["a", "b", "e", "d"], "B"],
]
return labeled_data_list
data_set = TextClassifyDataSet(load_func=loader)
data_set.load("xxx")
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertTrue("label" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["label"], "label"))
self.assertTrue(hasattr(data_set[0].fields["label"], "_index"))
self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1])
def test_case_2(self):
def loader(path):
labeled_data_list = [
[["a", "b", "e", "d"], "A"],
[["a", "b", "e", "d"], "C"],
[["a", "b", "e", "d"], "B"],
]
return labeled_data_list
data_set = TextClassifyDataSet(load_func=loader)
data_set.load("xxx", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab})
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertEqual(data_set[0].fields["word_seq"]._index,
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
self.assertTrue("label" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["label"], "label"))
self.assertTrue(hasattr(data_set[0].fields["label"], "_index"))
self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1])
self.assertEqual(data_set[0].fields["label"]._index, self.label_vocab[self.labeled_data_list[0][1]])
def test_case_3(self):
def loader(path):
unlabeled_data_list = [
["a", "b", "e", "d"],
["a", "b", "e", "d"],
["a", "b", "e", "d"]
]
return unlabeled_data_list
data_set = TextClassifyDataSet(load_func=loader)
data_set.load("xxx", vocabs={"word_vocab": self.word_vocab}, infer=True)
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertEqual(data_set[0].fields["word_seq"]._index,
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
@@ -1,13 +1,13 @@
import os
import unittest
from fastNLP.core.predictor import Predictor
from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet
from fastNLP.core.predictor import Predictor
from fastNLP.core.preprocess import save_pickle
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.base_loader import BaseLoader
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.models.sequence_modeling import SeqLabeling
class TestPredictor(unittest.TestCase):
@@ -42,7 +42,7 @@ class TestPredictor(unittest.TestCase):
predictor = Predictor("./save/", pre.text_classify_post_processor)
# Load infer data
infer_data_set = TextClassifyDataSet(loader=BaseLoader())
infer_data_set = TextClassifyDataSet(load_func=BaseLoader.load)
infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
results = predictor.predict(network=model, data=infer_data_set)
@@ -59,7 +59,7 @@ class TestPredictor(unittest.TestCase):
model = SeqLabeling(model_args)
predictor = Predictor("./save/", pre.seq_label_post_processor)
infer_data_set = SeqLabelDataSet(loader=BaseLoader())
infer_data_set = SeqLabelDataSet(load_func=BaseLoader.load)
infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
results = predictor.predict(network=model, data=infer_data_set)
@@ -53,7 +53,7 @@ def infer():
print("model loaded!")
# Data Loader
infer_data = SeqLabelDataSet(loader=BaseLoader())
infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True)
print("data set prepared")
@@ -37,7 +37,7 @@ def infer():
print("model loaded!")
# Load infer data
infer_data = SeqLabelDataSet(loader=BaseLoader())
infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True)
# inference
@@ -52,7 +52,7 @@ def train_test():
ConfigLoader().load_config(config_path, {"POS_infer": train_args})
# define dataset
data_train = SeqLabelDataSet(loader=TokenizeDataSetLoader())
data_train = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load)
data_train.load(cws_data_path)
train_args["vocab_size"] = len(data_train.word_vocab)
train_args["num_classes"] = len(data_train.label_vocab)
@@ -40,7 +40,7 @@ def infer():
print("vocabulary size:", len(word_vocab))
print("number of classes:", len(label_vocab))
infer_data = TextClassifyDataSet(loader=ClassDataSetLoader())
infer_data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab})
model_args = ConfigSection()
@@ -67,7 +67,7 @@ def train():
# load dataset
print("Loading data...")
data = TextClassifyDataSet(loader=ClassDataSetLoader())
data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
data.load(train_data_dir)
print("vocabulary size:", len(data.word_vocab))
@@ -2,7 +2,7 @@ import unittest
import torch
from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear
from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine
class TestGroupNorm(unittest.TestCase):
@@ -27,3 +27,25 @@ class TestBiLinear(unittest.TestCase):
y = bl(x_left, x_right)
print(bl)
bl2 = BiLinear(n_left=15, n_right=15, n_out=10, bias=True)
class TestBiAffine(unittest.TestCase):
def test_case_1(self):
batch_size = 16
encoder_length = 21
decoder_length = 32
layer = BiAffine(10, 10, 25, biaffine=True)
decoder_input = torch.randn((batch_size, encoder_length, 10))
encoder_input = torch.randn((batch_size, decoder_length, 10))
y = layer(decoder_input, encoder_input)
self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, decoder_length))
def test_case_2(self):
batch_size = 16
encoder_length = 21
decoder_length = 32
layer = BiAffine(10, 10, 25, biaffine=False)
decoder_input = torch.randn((batch_size, encoder_length, 10))
encoder_input = torch.randn((batch_size, decoder_length, 10))
y = layer(decoder_input, encoder_input)
self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, 1))
@@ -1,8 +1,5 @@
import os
import unittest
import configparser
import json
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.saver.config_saver import ConfigSaver
@@ -10,7 +7,7 @@ from fastNLP.saver.config_saver import ConfigSaver
class TestConfigSaver(unittest.TestCase):
def test_case_1(self):
config_file_dir = "./test/loader/"
config_file_dir = "test/loader/"
config_file_name = "config"
config_file_path = os.path.join(config_file_dir, config_file_name)
@@ -80,3 +77,37 @@ class TestConfigSaver(unittest.TestCase):
tmp_config_saver = ConfigSaver("file-NOT-exist")
except Exception as e:
pass
def test_case_2(self):
config = "[section_A]\n[section_B]\n"
with open("./test.cfg", "w", encoding="utf-8") as f:
f.write(config)
saver = ConfigSaver("./test.cfg")
section = ConfigSection()
section["doubles"] = 0.8
section["tt"] = [1, 2, 3]
section["test"] = 105
section["str"] = "this is a str"
saver.save_config_file("section_A", section)
os.system("rm ./test.cfg")
def test_case_3(self):
config = "[section_A]\ndoubles = 0.9\ntt = [1, 2, 3]\n[section_B]\n"
with open("./test.cfg", "w", encoding="utf-8") as f:
f.write(config)
saver = ConfigSaver("./test.cfg")
section = ConfigSection()
section["doubles"] = 0.8
section["tt"] = [1, 2, 3]
section["test"] = 105
section["str"] = "this is a str"
saver.save_config_file("section_A", section)
os.system("rm ./test.cfg")