- DataSet's __init__ takes a function as argument, rather than a class object
- Preprocessor is about to be removed. Don't use it anymore.
- Remove cross_validate in Trainer, because it is rarely used and weird
- Loader.load is expected to be a static method
- Delete unused modules (Transpose, WordDropout) in other_modules.py
- Add more tests
- Delete extra sample data
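With this change a DataSet is built from a plain load function instead of a loader object. A minimal sketch of the new calling convention, based on the constructors changed below (the data path is illustrative, and the import path for POSDataSetLoader is assumed):

    from fastNLP.loader.dataset_loader import POSDataSetLoader
    from fastNLP.core.dataset import SeqLabelDataSet

    # old: SeqLabelDataSet(loader=POSDataSetLoader())
    # new: pass the loader's load function itself
    data_set = SeqLabelDataSet(load_func=POSDataSetLoader().load)
    data_set.load("path/to/train_data.txt")  # illustrative path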
@@ -70,18 +70,18 @@ class DataSet(list):
     """
-    def __init__(self, name="", instances=None, loader=None):
+    def __init__(self, name="", instances=None, load_func=None):
         """
         :param name: str, the name of the dataset. (default: "")
         :param instances: list of Instance objects. (default: None)
+        :param load_func: a function that takes the dataset path (string) as input and returns multi-level lists.
         """
         list.__init__([])
         self.name = name
         if instances is not None:
             self.extend(instances)
-        self.dataset_loader = loader
+        self.data_set_load_func = load_func

     def index_all(self, vocab):
         for ins in self:
@@ -117,15 +117,15 @@ class DataSet(list):
         return lengths

     def convert(self, data):
-        """Convert lists of strings into Instances with Fields"""
+        """Convert lists of strings into Instances with Fields, creating Vocabulary for labeled data. Used in Training."""
         raise NotImplementedError

     def convert_with_vocabs(self, data, vocabs):
-        """Convert lists of strings into Instances with Fields, using existing Vocabulary. Useful in predicting."""
+        """Convert lists of strings into Instances with Fields, using existing Vocabulary, with labels. Used in Testing."""
         raise NotImplementedError

     def convert_for_infer(self, data, vocabs):
-        """Convert lists of strings into Instances with Fields."""
+        """Convert lists of strings into Instances with Fields, using existing Vocabulary, without labels. Used in predicting."""

     def load(self, data_path, vocabs=None, infer=False):
         """Load data from the given files.
@@ -135,7 +135,7 @@ class DataSet(list):
         :param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed.
         """
-        raw_data = self.dataset_loader.load(data_path)
+        raw_data = self.data_set_load_func(data_path)
         if infer is True:
             self.convert_for_infer(raw_data, vocabs)
         else:
@@ -145,7 +145,7 @@ class DataSet(list):
                 self.convert(raw_data)

     def load_raw(self, raw_data, vocabs):
-        """
+        """Load raw data without loader. Used in FastNLP class.

        :param raw_data:
        :param vocabs:
@@ -174,8 +174,8 @@ class DataSet(list):

 class SeqLabelDataSet(DataSet):
-    def __init__(self, instances=None, loader=POSDataSetLoader()):
-        super(SeqLabelDataSet, self).__init__(name="", instances=instances, loader=loader)
+    def __init__(self, instances=None, load_func=POSDataSetLoader().load):
+        super(SeqLabelDataSet, self).__init__(name="", instances=instances, load_func=load_func)
         self.word_vocab = Vocabulary()
         self.label_vocab = Vocabulary()
@@ -231,8 +231,8 @@ class SeqLabelDataSet(DataSet):

 class TextClassifyDataSet(DataSet):
-    def __init__(self, instances=None, loader=ClassDataSetLoader()):
-        super(TextClassifyDataSet, self).__init__(name="", instances=instances, loader=loader)
+    def __init__(self, instances=None, load_func=ClassDataSetLoader().load):
+        super(TextClassifyDataSet, self).__init__(name="", instances=instances, load_func=load_func)
         self.word_vocab = Vocabulary()
         self.label_vocab = Vocabulary(need_default=False)
@@ -285,10 +285,3 @@ def change_field_is_target(data_set, field_name, new_target):
     for inst in data_set:
         inst.fields[field_name].is_target = new_target
-
-
-if __name__ == "__main__":
-    data_set = SeqLabelDataSet()
-    data_set.load("../../test/data_for_tests/people.txt")
-    a, b = data_set.split(0.3)
-    print(type(data_set), type(a), type(b))
-    print(len(data_set), len(a), len(b))
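As the reworded docstrings above spell out, load() routes the raw data through one of three conversion paths. A rough sketch of the mapping (the middle branch is not fully visible in this hunk, so it is inferred from the docstrings; paths and the existing_vocabs dict are illustrative):

    data_set.load("train.txt")                                     # training: convert(), builds new vocabularies
    data_set.load("test.txt", vocabs=existing_vocabs)              # testing: presumably convert_with_vocabs(), labels indexed with existing vocabs
    data_set.load("raw.txt", vocabs=existing_vocabs, infer=True)   # prediction: convert_for_infer(), no labels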
@@ -78,6 +78,7 @@ class Preprocessor(object):
                 is only available when label_is_seq is True. Default: False.
         :param add_char_field: bool, whether to add character representations to all TextFields. Default: False.
         """
+        print("Preprocessor is about to deprecate. Please use DataSet class.")
         self.data_vocab = Vocabulary()
         if label_is_seq is True:
             if share_vocab is True:
@@ -307,11 +308,3 @@ class ClassPreprocess(Preprocessor):
         print("[FastNLP warning] ClassPreprocess is about to deprecate. Please use Preprocess directly.")
         super(ClassPreprocess, self).__init__()
-
-
-if __name__ == "__main__":
-    p = Preprocessor()
-    train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
-                      [["You", "are", "pretty", "."], "1"]
-                      ]
-    training_set = p.run(train_dev_data)
-    print(training_set)
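Since Preprocessor now only emits a deprecation notice, the DataSet classes are the replacement. A hedged sketch of the equivalent of the removed __main__ example, using the in-memory loader pattern from the new tests in this PR (the dummy path is simply ignored by the lambda):

    from fastNLP.core.dataset import TextClassifyDataSet

    train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
                      [["You", "are", "pretty", "."], "1"]]

    data_set = TextClassifyDataSet(load_func=lambda path: train_dev_data)
    data_set.load("unused_path")  # builds word/label vocabularies and Instances, roughly what Preprocessor.run did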
@@ -1,4 +1,3 @@
-import copy
 import os
 import time
 from datetime import timedelta
@@ -178,31 +177,6 @@ class Trainer(object):
                     logger.info(print_output)
                 step += 1
-
-    def cross_validate(self, network, train_data_cv, dev_data_cv):
-        """Training with cross validation.
-
-        :param network: the model
-        :param train_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
-        :param dev_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
-
-        """
-        if len(train_data_cv) != len(dev_data_cv):
-            logger.error("the number of folds in train and dev data unequals {}!={}".format(len(train_data_cv),
-                                                                                             len(dev_data_cv)))
-            raise RuntimeError("the number of folds in train and dev data unequals")
-        if self.validate is False:
-            logger.warn("Cross validation requires self.validate to be True. Please turn it on. ")
-            print("[warning] Cross validation requires self.validate to be True. Please turn it on. ")
-            self.validate = True
-
-        n_fold = len(train_data_cv)
-        logger.info("perform {} folds cross validation.".format(n_fold))
-        for i in range(n_fold):
-            print("CV:", i)
-            logger.info("running the {} of {} folds cross validation".format(i + 1, n_fold))
-            network_copy = copy.deepcopy(network)
-            self.train(network_copy, train_data_cv[i], dev_data_cv[i])

     def mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.
@@ -1,11 +1,10 @@
 import os

+from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
 from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
 from fastNLP.core.preprocess import load_pickle
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.model_loader import ModelLoader
-from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet

 """
 mapping from model name to [URL, file_name.class_name, model_pickle_name]
@@ -73,7 +72,7 @@ class FastNLP(object):
         :param model_dir: this directory should contain the following files:
             1. a trained model
             2. a config file, which is a fastNLP's configuration.
-            3. a Vocab file, which is a pickle object of a Vocab instance.
+            3. two Vocab files, which are pickle objects of Vocab instances, representing feature and label vocabs.
         """
         self.model_dir = model_dir
         self.model = None
@@ -192,7 +191,7 @@ class FastNLP(object):
     def _load(self, model_dir, model_name):
-        # To do
         return 0

     def _download(self, model_name, url):
@@ -202,7 +201,7 @@ class FastNLP(object):
        :param url:
         """
         print("Downloading {} from {}".format(model_name, url))
-        # To do
+        # TODO: download model via url

     def model_exist(self, model_dir):
         """
@@ -3,12 +3,14 @@ class BaseLoader(object):
     def __init__(self):
         super(BaseLoader, self).__init__()

-    def load_lines(self, data_path):
+    @staticmethod
+    def load_lines(data_path):
         with open(data_path, "r", encoding="utf=8") as f:
             text = f.readlines()
         return [line.strip() for line in text]

-    def load(self, data_path):
+    @staticmethod
+    def load(data_path):
         with open(data_path, "r", encoding="utf-8") as f:
             text = f.readlines()
         return [[word for word in sent.strip()] for sent in text]
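With load and load_lines now static, callers no longer need a BaseLoader instance. A quick sketch (the path is illustrative):

    from fastNLP.loader.base_loader import BaseLoader

    chars = BaseLoader.load("path/to/data.txt")      # call directly on the class
    same = BaseLoader().load("path/to/data.txt")     # calling on an instance still works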
@@ -84,7 +84,8 @@ class TokenizeDataSetLoader(DataSetLoader):
     def __init__(self):
         super(TokenizeDataSetLoader, self).__init__()

-    def load(self, data_path, max_seq_len=32):
+    @staticmethod
+    def load(data_path, max_seq_len=32):
         """
         load pku dataset for Chinese word segmentation
         CWS (Chinese Word Segmentation) pku training dataset format:
@@ -196,30 +196,3 @@ class BiAffine(nn.Module):
             output = output * mask_d.unsqueeze(1).unsqueeze(3) * mask_e.unsqueeze(1).unsqueeze(2)

         return output
-
-
-class Transpose(nn.Module):
-    def __init__(self, x, y):
-        super(Transpose, self).__init__()
-        self.x = x
-        self.y = y
-
-    def forward(self, x):
-        return x.transpose(self.x, self.y)
-
-
-class WordDropout(nn.Module):
-    def __init__(self, dropout_rate, drop_to_token):
-        super(WordDropout, self).__init__()
-        self.dropout_rate = dropout_rate
-        self.drop_to_token = drop_to_token
-
-    def forward(self, word_idx):
-        if not self.training:
-            return word_idx
-        drop_mask = torch.rand(word_idx.shape) < self.dropout_rate
-        if word_idx.device.type == 'cuda':
-            drop_mask = drop_mask.cuda()
-        drop_mask = drop_mask.long()
-        output = drop_mask * self.drop_to_token + (1 - drop_mask) * word_idx
-        return output
@@ -104,7 +104,8 @@ class ConfigSaver(object):
         :return:
         """
         section_file = self._get_section(section_name)
-        if len(section_file.__dict__.keys()) == 0:#the section not in file before
+        if len(section_file.__dict__.keys()) == 0:  # the section not in the file before
+            # append this section to config file
             with open(self.file_path, 'a') as f:
                 f.write('[' + section_name + ']\n')
                 for k in section.__dict__.keys():
@@ -114,9 +115,11 @@ class ConfigSaver(object):
                     else:
                         f.write(str(section[k]) + '\n\n')
         else:
+            # the section exists
             change_file = False
             for k in section.__dict__.keys():
                 if k not in section_file:
+                    # find a new key in this section
                     change_file = True
                     break
                 if section_file[k] != section[k]:
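The comments added above mark the two code paths of save_config_file: appending a section that the file does not have yet, and rewriting only when an existing section gained or changed a key. The call pattern is the one exercised by test_case_2 added below; a minimal sketch (file name illustrative, and the .cfg file is created beforehand as in the test):

    from fastNLP.loader.config_loader import ConfigSection
    from fastNLP.saver.config_saver import ConfigSaver

    saver = ConfigSaver("./example.cfg")
    section = ConfigSection()
    section["doubles"] = 0.8
    saver.save_config_file("section_A", section)  # appends [section_A] if absent, updates it otherwise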
@@ -0,0 +1,243 @@
+import unittest
+
+from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
+from fastNLP.core.dataset import create_dataset_from_lists
+
+
+class TestDataSet(unittest.TestCase):
+    labeled_data_list = [
+        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
+        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
+        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
+    ]
+    unlabeled_data_list = [
+        ["a", "b", "e", "d"],
+        ["a", "b", "e", "d"],
+        ["a", "b", "e", "d"]
+    ]
+    word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
+    label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}
+
+    def test_case_1(self):
+        data_set = create_dataset_from_lists(self.labeled_data_list, self.word_vocab, has_target=True,
+                                             label_vocab=self.label_vocab)
+        self.assertEqual(len(data_set), len(self.labeled_data_list))
+        self.assertTrue(len(data_set) > 0)
+        self.assertTrue(hasattr(data_set[0], "fields"))
+        self.assertTrue("word_seq" in data_set[0].fields)
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
+        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
+        self.assertEqual(data_set[0].fields["word_seq"]._index,
+                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])
+
+        self.assertTrue("label_seq" in data_set[0].fields)
+        self.assertTrue(hasattr(data_set[0].fields["label_seq"], "text"))
+        self.assertTrue(hasattr(data_set[0].fields["label_seq"], "_index"))
+        self.assertEqual(data_set[0].fields["label_seq"].text, self.labeled_data_list[0][1])
+        self.assertEqual(data_set[0].fields["label_seq"]._index,
+                         [self.label_vocab[c] for c in self.labeled_data_list[0][1]])
+
+    def test_case_2(self):
+        data_set = create_dataset_from_lists(self.unlabeled_data_list, self.word_vocab, has_target=False)
+        self.assertEqual(len(data_set), len(self.unlabeled_data_list))
+        self.assertTrue(len(data_set) > 0)
+        self.assertTrue(hasattr(data_set[0], "fields"))
+        self.assertTrue("word_seq" in data_set[0].fields)
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
+        self.assertEqual(data_set[0].fields["word_seq"].text, self.unlabeled_data_list[0])
+        self.assertEqual(data_set[0].fields["word_seq"]._index,
+                         [self.word_vocab[c] for c in self.unlabeled_data_list[0]])
+
+
+class TestDataSetConvertion(unittest.TestCase):
+    labeled_data_list = [
+        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
+        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
+        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
+    ]
+    unlabeled_data_list = [
+        ["a", "b", "e", "d"],
+        ["a", "b", "e", "d"],
+        ["a", "b", "e", "d"]
+    ]
+    word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
+    label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}
+
+    def test_case_1(self):
+        def loader(path):
+            labeled_data_list = [
+                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
+                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
+                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
+            ]
+            return labeled_data_list
+
+        data_set = SeqLabelDataSet(load_func=loader)
+        data_set.load("any_path")
+
+        self.assertEqual(len(data_set), len(self.labeled_data_list))
+        self.assertTrue(len(data_set) > 0)
+        self.assertTrue(hasattr(data_set[0], "fields"))
+        self.assertTrue("word_seq" in data_set[0].fields)
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
+        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
+
+        self.assertTrue("truth" in data_set[0].fields)
+        self.assertTrue(hasattr(data_set[0].fields["truth"], "text"))
+        self.assertTrue(hasattr(data_set[0].fields["truth"], "_index"))
+        self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1])
+
+        self.assertTrue("word_seq_origin_len" in data_set[0].fields)
+
+    def test_case_2(self):
+        def loader(path):
+            unlabeled_data_list = [
+                ["a", "b", "e", "d"],
+                ["a", "b", "e", "d"],
+                ["a", "b", "e", "d"]
+            ]
+            return unlabeled_data_list
+
+        data_set = SeqLabelDataSet(load_func=loader)
+        data_set.load("any_path", vocabs={"word_vocab": self.word_vocab}, infer=True)
+
+        self.assertEqual(len(data_set), len(self.labeled_data_list))
+        self.assertTrue(len(data_set) > 0)
+        self.assertTrue(hasattr(data_set[0], "fields"))
+        self.assertTrue("word_seq" in data_set[0].fields)
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
+        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
+        self.assertEqual(data_set[0].fields["word_seq"]._index,
+                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])
+
+        self.assertTrue("word_seq_origin_len" in data_set[0].fields)
+
+    def test_case_3(self):
+        def loader(path):
+            labeled_data_list = [
+                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
+                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
+                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
+            ]
+            return labeled_data_list
+
+        data_set = SeqLabelDataSet(load_func=loader)
+        data_set.load("any_path", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab})
+
+        self.assertEqual(len(data_set), len(self.labeled_data_list))
+        self.assertTrue(len(data_set) > 0)
+        self.assertTrue(hasattr(data_set[0], "fields"))
+        self.assertTrue("word_seq" in data_set[0].fields)
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
+        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
+        self.assertEqual(data_set[0].fields["word_seq"]._index,
+                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])
+
+        self.assertTrue("truth" in data_set[0].fields)
+        self.assertTrue(hasattr(data_set[0].fields["truth"], "text"))
+        self.assertTrue(hasattr(data_set[0].fields["truth"], "_index"))
+        self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1])
+        self.assertEqual(data_set[0].fields["truth"]._index,
+                         [self.label_vocab[c] for c in self.labeled_data_list[0][1]])
+
+        self.assertTrue("word_seq_origin_len" in data_set[0].fields)
+
+
+class TestDataSetConvertionHHH(unittest.TestCase):
+    labeled_data_list = [
+        [["a", "b", "e", "d"], "A"],
+        [["a", "b", "e", "d"], "C"],
+        [["a", "b", "e", "d"], "B"],
+    ]
+    unlabeled_data_list = [
+        ["a", "b", "e", "d"],
+        ["a", "b", "e", "d"],
+        ["a", "b", "e", "d"]
+    ]
+    word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
+    label_vocab = {"A": 1, "B": 2, "C": 3}
+
+    def test_case_1(self):
+        def loader(path):
+            labeled_data_list = [
+                [["a", "b", "e", "d"], "A"],
+                [["a", "b", "e", "d"], "C"],
+                [["a", "b", "e", "d"], "B"],
+            ]
+            return labeled_data_list
+
+        data_set = TextClassifyDataSet(load_func=loader)
+        data_set.load("xxx")
+
+        self.assertEqual(len(data_set), len(self.labeled_data_list))
+        self.assertTrue(len(data_set) > 0)
+        self.assertTrue(hasattr(data_set[0], "fields"))
+        self.assertTrue("word_seq" in data_set[0].fields)
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
+        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
+
+        self.assertTrue("label" in data_set[0].fields)
+        self.assertTrue(hasattr(data_set[0].fields["label"], "label"))
+        self.assertTrue(hasattr(data_set[0].fields["label"], "_index"))
+        self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1])
+
+    def test_case_2(self):
+        def loader(path):
+            labeled_data_list = [
+                [["a", "b", "e", "d"], "A"],
+                [["a", "b", "e", "d"], "C"],
+                [["a", "b", "e", "d"], "B"],
+            ]
+            return labeled_data_list
+
+        data_set = TextClassifyDataSet(load_func=loader)
+        data_set.load("xxx", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab})
+
+        self.assertEqual(len(data_set), len(self.labeled_data_list))
+        self.assertTrue(len(data_set) > 0)
+        self.assertTrue(hasattr(data_set[0], "fields"))
+        self.assertTrue("word_seq" in data_set[0].fields)
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
+        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
+        self.assertEqual(data_set[0].fields["word_seq"]._index,
+                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])
+
+        self.assertTrue("label" in data_set[0].fields)
+        self.assertTrue(hasattr(data_set[0].fields["label"], "label"))
+        self.assertTrue(hasattr(data_set[0].fields["label"], "_index"))
+        self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1])
+        self.assertEqual(data_set[0].fields["label"]._index, self.label_vocab[self.labeled_data_list[0][1]])
+
+    def test_case_3(self):
+        def loader(path):
+            unlabeled_data_list = [
+                ["a", "b", "e", "d"],
+                ["a", "b", "e", "d"],
+                ["a", "b", "e", "d"]
+            ]
+            return unlabeled_data_list
+
+        data_set = TextClassifyDataSet(load_func=loader)
+        data_set.load("xxx", vocabs={"word_vocab": self.word_vocab}, infer=True)
+
+        self.assertEqual(len(data_set), len(self.labeled_data_list))
+        self.assertTrue(len(data_set) > 0)
+        self.assertTrue(hasattr(data_set[0], "fields"))
+        self.assertTrue("word_seq" in data_set[0].fields)
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
+        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
+        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
+        self.assertEqual(data_set[0].fields["word_seq"]._index,
+                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])
@@ -1,13 +1,13 @@
 import os
 import unittest

-from fastNLP.core.predictor import Predictor
 from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet
+from fastNLP.core.predictor import Predictor
 from fastNLP.core.preprocess import save_pickle
 from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.loader.base_loader import BaseLoader
-from fastNLP.models.sequence_modeling import SeqLabeling
 from fastNLP.models.cnn_text_classification import CNNText
+from fastNLP.models.sequence_modeling import SeqLabeling


 class TestPredictor(unittest.TestCase):
@@ -42,7 +42,7 @@ class TestPredictor(unittest.TestCase):
         predictor = Predictor("./save/", pre.text_classify_post_processor)

         # Load infer data
-        infer_data_set = TextClassifyDataSet(loader=BaseLoader())
+        infer_data_set = TextClassifyDataSet(load_func=BaseLoader.load)
         infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})

         results = predictor.predict(network=model, data=infer_data_set)
@@ -59,7 +59,7 @@ class TestPredictor(unittest.TestCase):
         model = SeqLabeling(model_args)
         predictor = Predictor("./save/", pre.seq_label_post_processor)

-        infer_data_set = SeqLabelDataSet(loader=BaseLoader())
+        infer_data_set = SeqLabelDataSet(load_func=BaseLoader.load)
         infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})

         results = predictor.predict(network=model, data=infer_data_set)
@@ -53,7 +53,7 @@ def infer():
     print("model loaded!")

     # Data Loader
-    infer_data = SeqLabelDataSet(loader=BaseLoader())
+    infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
     infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True)
     print("data set prepared")
@@ -37,7 +37,7 @@ def infer():
     print("model loaded!")

     # Load infer data
-    infer_data = SeqLabelDataSet(loader=BaseLoader())
+    infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
     infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True)

     # inference
@@ -52,7 +52,7 @@ def train_test():
     ConfigLoader().load_config(config_path, {"POS_infer": train_args})

     # define dataset
-    data_train = SeqLabelDataSet(loader=TokenizeDataSetLoader())
+    data_train = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load)
     data_train.load(cws_data_path)
     train_args["vocab_size"] = len(data_train.word_vocab)
     train_args["num_classes"] = len(data_train.label_vocab)
@@ -40,7 +40,7 @@ def infer():
     print("vocabulary size:", len(word_vocab))
     print("number of classes:", len(label_vocab))

-    infer_data = TextClassifyDataSet(loader=ClassDataSetLoader())
+    infer_data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
     infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab})

     model_args = ConfigSection()
@@ -67,7 +67,7 @@ def train():
     # load dataset
     print("Loading data...")
-    data = TextClassifyDataSet(loader=ClassDataSetLoader())
+    data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
     data.load(train_data_dir)

     print("vocabulary size:", len(data.word_vocab))
@@ -2,7 +2,7 @@ import unittest

 import torch

-from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear
+from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine


 class TestGroupNorm(unittest.TestCase):
@@ -27,3 +27,25 @@ class TestBiLinear(unittest.TestCase):
         y = bl(x_left, x_right)
         print(bl)
         bl2 = BiLinear(n_left=15, n_right=15, n_out=10, bias=True)
+
+
+class TestBiAffine(unittest.TestCase):
+    def test_case_1(self):
+        batch_size = 16
+        encoder_length = 21
+        decoder_length = 32
+        layer = BiAffine(10, 10, 25, biaffine=True)
+        decoder_input = torch.randn((batch_size, encoder_length, 10))
+        encoder_input = torch.randn((batch_size, decoder_length, 10))
+        y = layer(decoder_input, encoder_input)
+        self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, decoder_length))
+
+    def test_case_2(self):
+        batch_size = 16
+        encoder_length = 21
+        decoder_length = 32
+        layer = BiAffine(10, 10, 25, biaffine=False)
+        decoder_input = torch.randn((batch_size, encoder_length, 10))
+        encoder_input = torch.randn((batch_size, decoder_length, 10))
+        y = layer(decoder_input, encoder_input)
+        self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, 1))
@@ -1,8 +1,5 @@
 import os
 import unittest
-import configparser
-import json

 from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
 from fastNLP.saver.config_saver import ConfigSaver
@@ -10,7 +7,7 @@ from fastNLP.saver.config_saver import ConfigSaver

 class TestConfigSaver(unittest.TestCase):
     def test_case_1(self):
-        config_file_dir = "./test/loader/"
+        config_file_dir = "test/loader/"
         config_file_name = "config"
         config_file_path = os.path.join(config_file_dir, config_file_name)
@@ -80,3 +77,37 @@ class TestConfigSaver(unittest.TestCase):
             tmp_config_saver = ConfigSaver("file-NOT-exist")
         except Exception as e:
             pass
+
+    def test_case_2(self):
+        config = "[section_A]\n[section_B]\n"
+
+        with open("./test.cfg", "w", encoding="utf-8") as f:
+            f.write(config)
+        saver = ConfigSaver("./test.cfg")
+
+        section = ConfigSection()
+        section["doubles"] = 0.8
+        section["tt"] = [1, 2, 3]
+        section["test"] = 105
+        section["str"] = "this is a str"
+
+        saver.save_config_file("section_A", section)
+
+        os.system("rm ./test.cfg")
+
+    def test_case_3(self):
+        config = "[section_A]\ndoubles = 0.9\ntt = [1, 2, 3]\n[section_B]\n"
+
+        with open("./test.cfg", "w", encoding="utf-8") as f:
+            f.write(config)
+        saver = ConfigSaver("./test.cfg")
+
+        section = ConfigSection()
+        section["doubles"] = 0.8
+        section["tt"] = [1, 2, 3]
+        section["test"] = 105
+        section["str"] = "this is a str"
+
+        saver.save_config_file("section_A", section)
+
+        os.system("rm ./test.cfg")