
Merge Preprocessor into DataSet.

- DataSet's __init__ takes a function as its argument, rather than a loader class object (see the sketch after this list)
- Preprocessor is about to be removed. Don't use it anymore.
- Remove cross_validate in trainer, because it is rarely used and weird
- Loader.load is expected to be a static method
- Delete unused modules (Transpose, WordDropout) in other_modules.py
- Add more tests
- Delete extra sample data
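
For illustration, a minimal sketch of the API change, using names from the diffs below (any callable that maps a path to nested lists works as load_func):

    # before this commit: DataSet held a loader *object* and called its .load()
    data_set = SeqLabelDataSet(loader=TokenizeDataSetLoader())

    # after this commit: DataSet holds a plain *function*
    data_set = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load)
    data_set.load("path/to/train_data")  # calls load_func(path), then converts to Instances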
FengZiYjun committed 6 years ago, commit 5be4cb7bb5 (tags/v0.1.0^2)
17 changed files with 337 additions and 17,872 deletions:

  1. fastNLP/core/dataset.py (+12 -19)
  2. fastNLP/core/preprocess.py (+1 -8)
  3. fastNLP/core/trainer.py (+0 -26)
  4. fastNLP/fastnlp.py (+4 -5)
  5. fastNLP/loader/base_loader.py (+4 -2)
  6. fastNLP/loader/dataset_loader.py (+2 -1)
  7. fastNLP/modules/other_modules.py (+0 -27)
  8. fastNLP/saver/config_saver.py (+4 -1)
  9. test/core/test_dataset.py (+243 -0)
  10. test/core/test_predictor.py (+4 -4)
  11. test/data_for_tests/cws_test (+0 -8286)
  12. test/data_for_tests/cws_train (+0 -9483)
  13. test/model/seq_labeling.py (+1 -1)
  14. test/model/test_cws.py (+2 -2)
  15. test/model/text_classify.py (+2 -2)
  16. test/modules/test_other_modules.py (+23 -1)
  17. test/saver/test_config_saver.py (+35 -4)

fastNLP/core/dataset.py (+12 -19)

@@ -70,18 +70,18 @@ class DataSet(list):
     """
 
-    def __init__(self, name="", instances=None, loader=None):
+    def __init__(self, name="", instances=None, load_func=None):
         """
 
         :param name: str, the name of the dataset. (default: "")
         :param instances: list of Instance objects. (default: None)
+        :param load_func: a function that takes the dataset path (string) as input and returns multi-level lists.
         """
         list.__init__([])
         self.name = name
         if instances is not None:
             self.extend(instances)
-        self.dataset_loader = loader
+        self.data_set_load_func = load_func
 
     def index_all(self, vocab):
         for ins in self:
@@ -117,15 +117,15 @@ class DataSet(list):
         return lengths
 
     def convert(self, data):
-        """Convert lists of strings into Instances with Fields"""
+        """Convert lists of strings into Instances with Fields, creating Vocabulary for labeled data. Used in Training."""
         raise NotImplementedError
 
     def convert_with_vocabs(self, data, vocabs):
-        """Convert lists of strings into Instances with Fields, using existing Vocabulary. Useful in predicting."""
+        """Convert lists of strings into Instances with Fields, using existing Vocabulary, with labels. Used in Testing."""
         raise NotImplementedError
 
     def convert_for_infer(self, data, vocabs):
-        """Convert lists of strings into Instances with Fields."""
+        """Convert lists of strings into Instances with Fields, using existing Vocabulary, without labels. Used in predicting."""
 
     def load(self, data_path, vocabs=None, infer=False):
         """Load data from the given files.
@@ -135,7 +135,7 @@ class DataSet(list):
         :param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed.
 
         """
-        raw_data = self.dataset_loader.load(data_path)
+        raw_data = self.data_set_load_func(data_path)
         if infer is True:
             self.convert_for_infer(raw_data, vocabs)
         else:
@@ -145,7 +145,7 @@ class DataSet(list):
             self.convert(raw_data)
 
     def load_raw(self, raw_data, vocabs):
-        """
+        """Load raw data without loader. Used in FastNLP class.
 
         :param raw_data:
         :param vocabs:
@@ -174,8 +174,8 @@ class DataSet(list):
 
 
 class SeqLabelDataSet(DataSet):
-    def __init__(self, instances=None, loader=POSDataSetLoader()):
-        super(SeqLabelDataSet, self).__init__(name="", instances=instances, loader=loader)
+    def __init__(self, instances=None, load_func=POSDataSetLoader().load):
+        super(SeqLabelDataSet, self).__init__(name="", instances=instances, load_func=load_func)
         self.word_vocab = Vocabulary()
         self.label_vocab = Vocabulary()
 
@@ -231,8 +231,8 @@ class SeqLabelDataSet(DataSet):
 
 
 class TextClassifyDataSet(DataSet):
-    def __init__(self, instances=None, loader=ClassDataSetLoader()):
-        super(TextClassifyDataSet, self).__init__(name="", instances=instances, loader=loader)
+    def __init__(self, instances=None, load_func=ClassDataSetLoader().load):
+        super(TextClassifyDataSet, self).__init__(name="", instances=instances, load_func=load_func)
         self.word_vocab = Vocabulary()
         self.label_vocab = Vocabulary(need_default=False)
 
@@ -285,10 +285,3 @@ def change_field_is_target(data_set, field_name, new_target):
     for inst in data_set:
         inst.fields[field_name].is_target = new_target
-
-
-if __name__ == "__main__":
-    data_set = SeqLabelDataSet()
-    data_set.load("../../test/data_for_tests/people.txt")
-    a, b = data_set.split(0.3)
-    print(type(data_set), type(a), type(b))
-    print(len(data_set), len(a), len(b))
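
Per the new load_func docstring, any function mapping a path to multi-level lists can drive a DataSet. A hypothetical custom loader, with the file format and all names invented for illustration:

    def my_load(path):
        # assumed format: one "word<TAB>label" pair per line, blank line between sentences
        examples, words, labels = [], [], []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    if words:
                        examples.append([words, labels])
                        words, labels = [], []
                    continue
                word, label = line.split("\t")
                words.append(word)
                labels.append(label)
        if words:
            examples.append([words, labels])
        return examples  # same [words, labels] shape the new tests below use

    data_set = SeqLabelDataSet(load_func=my_load)
    data_set.load("my_corpus.txt")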

fastNLP/core/preprocess.py (+1 -8)

@@ -78,6 +78,7 @@ class Preprocessor(object):
             is only available when label_is_seq is True. Default: False.
         :param add_char_field: bool, whether to add character representations to all TextFields. Default: False.
         """
+        print("Preprocessor is about to deprecate. Please use DataSet class.")
         self.data_vocab = Vocabulary()
         if label_is_seq is True:
             if share_vocab is True:
@@ -307,11 +308,3 @@ class ClassPreprocess(Preprocessor):
         print("[FastNLP warning] ClassPreprocess is about to deprecate. Please use Preprocess directly.")
         super(ClassPreprocess, self).__init__()
-
-
-if __name__ == "__main__":
-    p = Preprocessor()
-    train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
-                      [["You", "are", "pretty", "."], "1"]
-                      ]
-    training_set = p.run(train_dev_data)
-    print(training_set)

fastNLP/core/trainer.py (+0 -26)

@@ -1,4 +1,3 @@
-import copy
 import os
 import time
 from datetime import timedelta
@@ -178,31 +177,6 @@ class Trainer(object):
                 logger.info(print_output)
                 step += 1
 
-    def cross_validate(self, network, train_data_cv, dev_data_cv):
-        """Training with cross validation.
-
-        :param network: the model
-        :param train_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
-        :param dev_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
-
-        """
-        if len(train_data_cv) != len(dev_data_cv):
-            logger.error("the number of folds in train and dev data unequals {}!={}".format(len(train_data_cv),
-                                                                                             len(dev_data_cv)))
-            raise RuntimeError("the number of folds in train and dev data unequals")
-        if self.validate is False:
-            logger.warn("Cross validation requires self.validate to be True. Please turn it on. ")
-            print("[warning] Cross validation requires self.validate to be True. Please turn it on. ")
-            self.validate = True
-
-        n_fold = len(train_data_cv)
-        logger.info("perform {} folds cross validation.".format(n_fold))
-        for i in range(n_fold):
-            print("CV:", i)
-            logger.info("running the {} of {} folds cross validation".format(i + 1, n_fold))
-            network_copy = copy.deepcopy(network)
-            self.train(network_copy, train_data_cv[i], dev_data_cv[i])
-
     def mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.




fastNLP/fastnlp.py (+4 -5)

@@ -1,11 +1,10 @@
 import os
 
+from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
 from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
 from fastNLP.core.preprocess import load_pickle
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.model_loader import ModelLoader
-from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
 
 """
 mapping from model name to [URL, file_name.class_name, model_pickle_name]
@@ -73,7 +72,7 @@ class FastNLP(object):
         :param model_dir: this directory should contain the following files:
             1. a trained model
             2. a config file, which is a fastNLP's configuration.
-            3. a Vocab file, which is a pickle object of a Vocab instance.
+            3. two Vocab files, which are pickle objects of Vocab instances, representing feature and label vocabs.
         """
         self.model_dir = model_dir
         self.model = None
@@ -192,7 +191,7 @@ class FastNLP(object):
 
     def _load(self, model_dir, model_name):
-        # To do
         return 0
 
     def _download(self, model_name, url):
@@ -202,7 +201,7 @@ class FastNLP(object):
         :param url:
         """
         print("Downloading {} from {}".format(model_name, url))
-        # To do
+        # TODO: download model via url
 
     def model_exist(self, model_dir):
         """


fastNLP/loader/base_loader.py (+4 -2)

@@ -3,12 +3,14 @@ class BaseLoader(object):
     def __init__(self):
         super(BaseLoader, self).__init__()
 
-    def load_lines(self, data_path):
+    @staticmethod
+    def load_lines(data_path):
         with open(data_path, "r", encoding="utf=8") as f:
             text = f.readlines()
         return [line.strip() for line in text]
 
-    def load(self, data_path):
+    @staticmethod
+    def load(data_path):
         with open(data_path, "r", encoding="utf-8") as f:
             text = f.readlines()
         return [[word for word in sent.strip()] for sent in text]
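
Because load is now a @staticmethod, the function itself can be referenced and handed to a DataSet without constructing a loader, which is exactly what the updated call sites below do:

    chars = BaseLoader.load("path/to/file")           # callable with no instance
    infer_data = SeqLabelDataSet(load_func=BaseLoader.load)  # pass the function itself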


fastNLP/loader/dataset_loader.py (+2 -1)

@@ -84,7 +84,8 @@ class TokenizeDataSetLoader(DataSetLoader):
     def __init__(self):
         super(TokenizeDataSetLoader, self).__init__()
 
-    def load(self, data_path, max_seq_len=32):
+    @staticmethod
+    def load(data_path, max_seq_len=32):
         """
         load pku dataset for Chinese word segmentation
         CWS (Chinese Word Segmentation) pku training dataset format:


fastNLP/modules/other_modules.py (+0 -27)

@@ -196,30 +196,3 @@ class BiAffine(nn.Module):
         output = output * mask_d.unsqueeze(1).unsqueeze(3) * mask_e.unsqueeze(1).unsqueeze(2)
 
         return output
-
-
-class Transpose(nn.Module):
-    def __init__(self, x, y):
-        super(Transpose, self).__init__()
-        self.x = x
-        self.y = y
-
-    def forward(self, x):
-        return x.transpose(self.x, self.y)
-
-
-class WordDropout(nn.Module):
-    def __init__(self, dropout_rate, drop_to_token):
-        super(WordDropout, self).__init__()
-        self.dropout_rate = dropout_rate
-        self.drop_to_token = drop_to_token
-
-    def forward(self, word_idx):
-        if not self.training:
-            return word_idx
-        drop_mask = torch.rand(word_idx.shape) < self.dropout_rate
-        if word_idx.device.type == 'cuda':
-            drop_mask = drop_mask.cuda()
-        drop_mask = drop_mask.long()
-        output = drop_mask * self.drop_to_token + (1 - drop_mask) * word_idx
-        return output

fastNLP/saver/config_saver.py (+4 -1)

@@ -104,7 +104,8 @@ class ConfigSaver(object):
         :return:
         """
         section_file = self._get_section(section_name)
-        if len(section_file.__dict__.keys()) == 0:#the section not in file before
+        if len(section_file.__dict__.keys()) == 0:  # the section not in the file before
+            # append this section to config file
             with open(self.file_path, 'a') as f:
                 f.write('[' + section_name + ']\n')
                 for k in section.__dict__.keys():
@@ -114,9 +115,11 @@ class ConfigSaver(object):
                     else:
                         f.write(str(section[k]) + '\n\n')
         else:
+            # the section exists
             change_file = False
             for k in section.__dict__.keys():
                 if k not in section_file:
+                    # find a new key in this section
                     change_file = True
                     break
                 if section_file[k] != section[k]:
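
The two branches (append a brand-new section vs. detect changed or new keys in an existing one) are what the new tests below exercise; in outline:

    saver = ConfigSaver("./test.cfg")
    section = ConfigSection()
    section["doubles"] = 0.8
    saver.save_config_file("section_A", section)
    # if [section_A] is missing it is appended to the file;
    # otherwise the file is only rewritten when a key is new or its value changed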


test/core/test_dataset.py (+243 -0, new file; every line below is an addition)

@@ -0,0 +1,243 @@
import unittest

from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
from fastNLP.core.dataset import create_dataset_from_lists


class TestDataSet(unittest.TestCase):
    labeled_data_list = [
        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
    ]
    unlabeled_data_list = [
        ["a", "b", "e", "d"],
        ["a", "b", "e", "d"],
        ["a", "b", "e", "d"]
    ]
    word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
    label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}

    def test_case_1(self):
        data_set = create_dataset_from_lists(self.labeled_data_list, self.word_vocab, has_target=True,
                                             label_vocab=self.label_vocab)
        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
        self.assertEqual(data_set[0].fields["word_seq"]._index,
                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])

        self.assertTrue("label_seq" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["label_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["label_seq"], "_index"))
        self.assertEqual(data_set[0].fields["label_seq"].text, self.labeled_data_list[0][1])
        self.assertEqual(data_set[0].fields["label_seq"]._index,
                         [self.label_vocab[c] for c in self.labeled_data_list[0][1]])

    def test_case_2(self):
        data_set = create_dataset_from_lists(self.unlabeled_data_list, self.word_vocab, has_target=False)

        self.assertEqual(len(data_set), len(self.unlabeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.unlabeled_data_list[0])
        self.assertEqual(data_set[0].fields["word_seq"]._index,
                         [self.word_vocab[c] for c in self.unlabeled_data_list[0]])


class TestDataSetConvertion(unittest.TestCase):
    labeled_data_list = [
        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
    ]
    unlabeled_data_list = [
        ["a", "b", "e", "d"],
        ["a", "b", "e", "d"],
        ["a", "b", "e", "d"]
    ]
    word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
    label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}

    def test_case_1(self):
        def loader(path):
            labeled_data_list = [
                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
            ]
            return labeled_data_list

        data_set = SeqLabelDataSet(load_func=loader)
        data_set.load("any_path")

        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)

        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])

        self.assertTrue("truth" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["truth"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["truth"], "_index"))
        self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1])

        self.assertTrue("word_seq_origin_len" in data_set[0].fields)

    def test_case_2(self):
        def loader(path):
            unlabeled_data_list = [
                ["a", "b", "e", "d"],
                ["a", "b", "e", "d"],
                ["a", "b", "e", "d"]
            ]
            return unlabeled_data_list

        data_set = SeqLabelDataSet(load_func=loader)
        data_set.load("any_path", vocabs={"word_vocab": self.word_vocab}, infer=True)

        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
        self.assertEqual(data_set[0].fields["word_seq"]._index,
                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])

        self.assertTrue("word_seq_origin_len" in data_set[0].fields)

    def test_case_3(self):
        def loader(path):
            labeled_data_list = [
                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
                [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
            ]
            return labeled_data_list

        data_set = SeqLabelDataSet(load_func=loader)
        data_set.load("any_path", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab})

        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
        self.assertEqual(data_set[0].fields["word_seq"]._index,
                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])

        self.assertTrue("truth" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["truth"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["truth"], "_index"))
        self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1])
        self.assertEqual(data_set[0].fields["truth"]._index,
                         [self.label_vocab[c] for c in self.labeled_data_list[0][1]])

        self.assertTrue("word_seq_origin_len" in data_set[0].fields)


class TestDataSetConvertionHHH(unittest.TestCase):
    labeled_data_list = [
        [["a", "b", "e", "d"], "A"],
        [["a", "b", "e", "d"], "C"],
        [["a", "b", "e", "d"], "B"],
    ]
    unlabeled_data_list = [
        ["a", "b", "e", "d"],
        ["a", "b", "e", "d"],
        ["a", "b", "e", "d"]
    ]
    word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
    label_vocab = {"A": 1, "B": 2, "C": 3}

    def test_case_1(self):
        def loader(path):
            labeled_data_list = [
                [["a", "b", "e", "d"], "A"],
                [["a", "b", "e", "d"], "C"],
                [["a", "b", "e", "d"], "B"],
            ]
            return labeled_data_list

        data_set = TextClassifyDataSet(load_func=loader)
        data_set.load("xxx")

        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)

        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])

        self.assertTrue("label" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["label"], "label"))
        self.assertTrue(hasattr(data_set[0].fields["label"], "_index"))
        self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1])

    def test_case_2(self):
        def loader(path):
            labeled_data_list = [
                [["a", "b", "e", "d"], "A"],
                [["a", "b", "e", "d"], "C"],
                [["a", "b", "e", "d"], "B"],
            ]
            return labeled_data_list

        data_set = TextClassifyDataSet(load_func=loader)
        data_set.load("xxx", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab})

        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)

        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
        self.assertEqual(data_set[0].fields["word_seq"]._index,
                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])

        self.assertTrue("label" in data_set[0].fields)
        self.assertTrue(hasattr(data_set[0].fields["label"], "label"))
        self.assertTrue(hasattr(data_set[0].fields["label"], "_index"))
        self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1])
        self.assertEqual(data_set[0].fields["label"]._index, self.label_vocab[self.labeled_data_list[0][1]])

    def test_case_3(self):
        def loader(path):
            unlabeled_data_list = [
                ["a", "b", "e", "d"],
                ["a", "b", "e", "d"],
                ["a", "b", "e", "d"]
            ]
            return unlabeled_data_list

        data_set = TextClassifyDataSet(load_func=loader)
        data_set.load("xxx", vocabs={"word_vocab": self.word_vocab}, infer=True)

        self.assertEqual(len(data_set), len(self.labeled_data_list))
        self.assertTrue(len(data_set) > 0)
        self.assertTrue(hasattr(data_set[0], "fields"))
        self.assertTrue("word_seq" in data_set[0].fields)

        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
        self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
        self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
        self.assertEqual(data_set[0].fields["word_seq"]._index,
                         [self.word_vocab[c] for c in self.labeled_data_list[0][0]])

test/core/test_predictor.py (+4 -4)

@@ -1,13 +1,13 @@
 import os
 import unittest
 
-from fastNLP.core.predictor import Predictor
 from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet
+from fastNLP.core.predictor import Predictor
 from fastNLP.core.preprocess import save_pickle
 from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.loader.base_loader import BaseLoader
-from fastNLP.models.sequence_modeling import SeqLabeling
 from fastNLP.models.cnn_text_classification import CNNText
+from fastNLP.models.sequence_modeling import SeqLabeling
 
 
 class TestPredictor(unittest.TestCase):
@@ -42,7 +42,7 @@ class TestPredictor(unittest.TestCase):
         predictor = Predictor("./save/", pre.text_classify_post_processor)
 
         # Load infer data
-        infer_data_set = TextClassifyDataSet(loader=BaseLoader())
+        infer_data_set = TextClassifyDataSet(load_func=BaseLoader.load)
         infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
 
         results = predictor.predict(network=model, data=infer_data_set)
@@ -59,7 +59,7 @@ class TestPredictor(unittest.TestCase):
         model = SeqLabeling(model_args)
         predictor = Predictor("./save/", pre.seq_label_post_processor)
 
-        infer_data_set = SeqLabelDataSet(loader=BaseLoader())
+        infer_data_set = SeqLabelDataSet(load_func=BaseLoader.load)
         infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
 
         results = predictor.predict(network=model, data=infer_data_set)


test/data_for_tests/cws_test (+0 -8286)

File diff suppressed because it is too large.

test/data_for_tests/cws_train (+0 -9483)

File diff suppressed because it is too large.


test/model/seq_labeling.py (+1 -1)

@@ -53,7 +53,7 @@ def infer():
     print("model loaded!")
 
     # Data Loader
-    infer_data = SeqLabelDataSet(loader=BaseLoader())
+    infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
     infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True)
     print("data set prepared")



test/model/test_cws.py (+2 -2)

@@ -37,7 +37,7 @@ def infer():
     print("model loaded!")
 
     # Load infer data
-    infer_data = SeqLabelDataSet(loader=BaseLoader())
+    infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
     infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True)
 
     # inference
@@ -52,7 +52,7 @@ def train_test():
     ConfigLoader().load_config(config_path, {"POS_infer": train_args})
 
     # define dataset
-    data_train = SeqLabelDataSet(loader=TokenizeDataSetLoader())
+    data_train = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load)
     data_train.load(cws_data_path)
     train_args["vocab_size"] = len(data_train.word_vocab)
     train_args["num_classes"] = len(data_train.label_vocab)


test/model/text_classify.py (+2 -2)

@@ -40,7 +40,7 @@ def infer():
     print("vocabulary size:", len(word_vocab))
     print("number of classes:", len(label_vocab))
 
-    infer_data = TextClassifyDataSet(loader=ClassDataSetLoader())
+    infer_data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
     infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab})
 
     model_args = ConfigSection()
@@ -67,7 +67,7 @@ def train():
 
     # load dataset
     print("Loading data...")
-    data = TextClassifyDataSet(loader=ClassDataSetLoader())
+    data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
     data.load(train_data_dir)
 
     print("vocabulary size:", len(data.word_vocab))


test/modules/test_other_modules.py (+23 -1)

@@ -2,7 +2,7 @@ import unittest
 
 import torch
 
-from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear
+from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine
 
 
 class TestGroupNorm(unittest.TestCase):
@@ -27,3 +27,25 @@ class TestBiLinear(unittest.TestCase):
         y = bl(x_left, x_right)
         print(bl)
         bl2 = BiLinear(n_left=15, n_right=15, n_out=10, bias=True)
+
+
+class TestBiAffine(unittest.TestCase):
+    def test_case_1(self):
+        batch_size = 16
+        encoder_length = 21
+        decoder_length = 32
+        layer = BiAffine(10, 10, 25, biaffine=True)
+        decoder_input = torch.randn((batch_size, encoder_length, 10))
+        encoder_input = torch.randn((batch_size, decoder_length, 10))
+        y = layer(decoder_input, encoder_input)
+        self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, decoder_length))
+
+    def test_case_2(self):
+        batch_size = 16
+        encoder_length = 21
+        decoder_length = 32
+        layer = BiAffine(10, 10, 25, biaffine=False)
+        decoder_input = torch.randn((batch_size, encoder_length, 10))
+        encoder_input = torch.randn((batch_size, decoder_length, 10))
+        y = layer(decoder_input, encoder_input)
+        self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, 1))

test/saver/test_config_saver.py (+35 -4)

@@ -1,8 +1,5 @@
 import os
 import unittest
-import configparser
-import json
 
 from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
 from fastNLP.saver.config_saver import ConfigSaver
@@ -10,7 +7,7 @@ from fastNLP.saver.config_saver import ConfigSaver
 
 class TestConfigSaver(unittest.TestCase):
     def test_case_1(self):
-        config_file_dir = "./test/loader/"
+        config_file_dir = "test/loader/"
         config_file_name = "config"
         config_file_path = os.path.join(config_file_dir, config_file_name)
@@ -80,3 +77,37 @@ class TestConfigSaver(unittest.TestCase):
             tmp_config_saver = ConfigSaver("file-NOT-exist")
         except Exception as e:
             pass
+
+    def test_case_2(self):
+        config = "[section_A]\n[section_B]\n"
+
+        with open("./test.cfg", "w", encoding="utf-8") as f:
+            f.write(config)
+        saver = ConfigSaver("./test.cfg")
+
+        section = ConfigSection()
+        section["doubles"] = 0.8
+        section["tt"] = [1, 2, 3]
+        section["test"] = 105
+        section["str"] = "this is a str"
+
+        saver.save_config_file("section_A", section)
+
+        os.system("rm ./test.cfg")
+
+    def test_case_3(self):
+        config = "[section_A]\ndoubles = 0.9\ntt = [1, 2, 3]\n[section_B]\n"
+
+        with open("./test.cfg", "w", encoding="utf-8") as f:
+            f.write(config)
+        saver = ConfigSaver("./test.cfg")
+
+        section = ConfigSection()
+        section["doubles"] = 0.8
+        section["tt"] = [1, 2, 3]
+        section["test"] = 105
+        section["str"] = "this is a str"
+
+        saver.save_config_file("section_A", section)
+
+        os.system("rm ./test.cfg")
