From a4c9786ca4f69b037521fd63e9a7ed8c5fb75e4f Mon Sep 17 00:00:00 2001 From: yunfan Date: Wed, 17 Oct 2018 09:59:56 +0800 Subject: [PATCH] update dataset & loader --- fastNLP/core/dataset.py | 263 +++-------------------- fastNLP/core/field.py | 8 + fastNLP/core/instance.py | 30 ++- fastNLP/core/predictor.py | 5 +- fastNLP/fastnlp.py | 14 +- fastNLP/loader/dataset_loader.py | 102 ++++++++- reproduction/chinese_word_segment/run.py | 2 +- test/core/test_batch.py | 13 +- test/core/test_dataset.py | 201 +---------------- test/core/test_predictor.py | 12 +- test/core/test_tester.py | 4 +- test/core/test_trainer.py | 4 +- test/model/test_cws.py | 23 +- test/model/test_seq_label.py | 8 +- 14 files changed, 208 insertions(+), 481 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 13370969..a10a24d2 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -6,91 +6,33 @@ from copy import deepcopy from fastNLP.core.field import TextField, LabelField from fastNLP.core.instance import Instance from fastNLP.core.vocabulary import Vocabulary -from fastNLP.loader.dataset_loader import POSDataSetLoader, ClassDataSetLoader - - -def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None): - if has_target is True: - if label_vocab is None: - raise RuntimeError("Must provide label vocabulary to transform labels.") - return create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab) - else: - return create_unlabeled_dataset_from_lists(str_lists, word_vocab) - - -def create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab): - """Create an DataSet instance that contains labels. - - :param str_lists: list of list of strings, [num_examples, 2, *]. - :: - [ - [[word_11, word_12, ...], [label_11, label_12, ...]], - ... - ] - - :param word_vocab: dict of (str: int), which means (word: index). - :param label_vocab: dict of (str: int), which means (word: index). - :return data_set: a DataSet instance. - - """ - data_set = DataSet() - for example in str_lists: - word_seq, label_seq = example[0], example[1] - x = TextField(word_seq, is_target=False) - y = TextField(label_seq, is_target=True) - data_set.append(Instance(word_seq=x, label_seq=y)) - data_set.index_field("word_seq", word_vocab) - data_set.index_field("label_seq", label_vocab) - return data_set - - -def create_unlabeled_dataset_from_lists(str_lists, word_vocab): - """Create an DataSet instance that contains no labels. - - :param str_lists: list of list of strings, [num_examples, *]. - :: - [ - [word_11, word_12, ...], - ... - ] - - :param word_vocab: dict of (str: int), which means (word: index). - :return data_set: a DataSet instance. - - """ - data_set = DataSet() - for word_seq in str_lists: - x = TextField(word_seq, is_target=False) - data_set.append(Instance(word_seq=x)) - data_set.index_field("word_seq", word_vocab) - return data_set - class DataSet(list): """A DataSet object is a list of Instance objects. """ - def __init__(self, name="", instances=None, load_func=None): + def __init__(self, name="", instances=None): """ :param name: str, the name of the dataset. (default: "") :param instances: list of Instance objects. (default: None) - :param load_func: a function that takes the dataset path (string) as input and returns multi-level lists. """ list.__init__([]) self.name = name + self.origin_len = None if instances is not None: self.extend(instances) - self.data_set_load_func = load_func def index_all(self, vocab): for ins in self: ins.index_all(vocab) + return self def index_field(self, field_name, vocab): for ins in self: ins.index_field(field_name, vocab) + return self def to_tensor(self, idx: int, padding_length: dict): """Convert an instance in a dataset to tensor. @@ -102,7 +44,7 @@ class DataSet(list): """ ins = self[idx] - return ins.to_tensor(padding_length) + return ins.to_tensor(padding_length, self.origin_len) def get_length(self): """Fetch lengths of all fields in all instances in a dataset. @@ -117,42 +59,9 @@ class DataSet(list): lengths[field_name].append(field_length) return lengths - def convert(self, data): - """Convert lists of strings into Instances with Fields, creating Vocabulary for labeled data. Used in Training.""" - raise NotImplementedError - - def convert_with_vocabs(self, data, vocabs): - """Convert lists of strings into Instances with Fields, using existing Vocabulary, with labels. Used in Testing.""" - raise NotImplementedError - - def convert_for_infer(self, data, vocabs): - """Convert lists of strings into Instances with Fields, using existing Vocabulary, without labels. Used in predicting.""" - - def load(self, data_path, vocabs=None, infer=False): - """Load data from the given files. - - :param data_path: str, the path to the data - :param infer: bool. If True, there is no label information in the data. Default: False. - :param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed. - - """ - raw_data = self.data_set_load_func(data_path) - if infer is True: - self.convert_for_infer(raw_data, vocabs) - else: - if vocabs is not None: - self.convert_with_vocabs(raw_data, vocabs) - else: - self.convert(raw_data) - - def load_raw(self, raw_data, vocabs): - """Load raw data without loader. Used in FastNLP class. - - :param raw_data: - :param vocabs: - :return: - """ - self.convert_for_infer(raw_data, vocabs) + def shuffle(self): + random.shuffle(self) + return self def split(self, ratio, shuffle=True): """Train/dev splitting @@ -165,7 +74,7 @@ class DataSet(list): """ assert 0 < ratio < 1 if shuffle: - random.shuffle(self) + self.shuffle() split_idx = int(len(self) * ratio) dev_set = deepcopy(self) train_set = deepcopy(self) @@ -173,134 +82,32 @@ class DataSet(list): del dev_set[split_idx:] return train_set, dev_set - -class SeqLabelDataSet(DataSet): - def __init__(self, instances=None, load_func=POSDataSetLoader().load): - super(SeqLabelDataSet, self).__init__(name="", instances=instances, load_func=load_func) - self.word_vocab = Vocabulary() - self.label_vocab = Vocabulary() - - def convert(self, data): - """Convert lists of strings into Instances with Fields. - - :param data: 3-level lists. Entries are strings. + def rename_field(self, old_name, new_name): + """rename a field """ - bar = ProgressBar(total=len(data)) - for example in data: - word_seq, label_seq = example[0], example[1] - # list, list - self.word_vocab.update(word_seq) - self.label_vocab.update(label_seq) - x = TextField(word_seq, is_target=False) - x_len = LabelField(len(word_seq), is_target=False) - y = TextField(label_seq, is_target=False) - instance = Instance() - instance.add_field("word_seq", x) - instance.add_field("truth", y) - instance.add_field("word_seq_origin_len", x_len) - self.append(instance) - bar.move() - self.index_field("word_seq", self.word_vocab) - self.index_field("truth", self.label_vocab) - # no need to index "word_seq_origin_len" - - def convert_with_vocabs(self, data, vocabs): - for example in data: - word_seq, label_seq = example[0], example[1] - # list, list - x = TextField(word_seq, is_target=False) - x_len = LabelField(len(word_seq), is_target=False) - y = TextField(label_seq, is_target=False) - instance = Instance() - instance.add_field("word_seq", x) - instance.add_field("truth", y) - instance.add_field("word_seq_origin_len", x_len) - self.append(instance) - self.index_field("word_seq", vocabs["word_vocab"]) - self.index_field("truth", vocabs["label_vocab"]) - # no need to index "word_seq_origin_len" - - def convert_for_infer(self, data, vocabs): - for word_seq in data: - # list - x = TextField(word_seq, is_target=False) - x_len = LabelField(len(word_seq), is_target=False) - instance = Instance() - instance.add_field("word_seq", x) - instance.add_field("word_seq_origin_len", x_len) - self.append(instance) - self.index_field("word_seq", vocabs["word_vocab"]) - # no need to index "word_seq_origin_len" - - -class TextClassifyDataSet(DataSet): - def __init__(self, instances=None, load_func=ClassDataSetLoader().load): - super(TextClassifyDataSet, self).__init__(name="", instances=instances, load_func=load_func) - self.word_vocab = Vocabulary() - self.label_vocab = Vocabulary(need_default=False) - - def convert(self, data): - for example in data: - word_seq, label = example[0], example[1] - # list, str - self.word_vocab.update(word_seq) - self.label_vocab.update(label) - x = TextField(word_seq, is_target=False) - y = LabelField(label, is_target=True) - instance = Instance() - instance.add_field("word_seq", x) - instance.add_field("label", y) - self.append(instance) - self.index_field("word_seq", self.word_vocab) - self.index_field("label", self.label_vocab) - - def convert_with_vocabs(self, data, vocabs): - for example in data: - word_seq, label = example[0], example[1] - # list, str - x = TextField(word_seq, is_target=False) - y = LabelField(label, is_target=True) - instance = Instance() - instance.add_field("word_seq", x) - instance.add_field("label", y) - self.append(instance) - self.index_field("word_seq", vocabs["word_vocab"]) - self.index_field("label", vocabs["label_vocab"]) - - def convert_for_infer(self, data, vocabs): - for word_seq in data: - # list - x = TextField(word_seq, is_target=False) - instance = Instance() - instance.add_field("word_seq", x) - self.append(instance) - self.index_field("word_seq", vocabs["word_vocab"]) - - -def change_field_is_target(data_set, field_name, new_target): - """Change the flag of is_target in a field. - - :param data_set: a DataSet object - :param field_name: str, the name of the field - :param new_target: one of (True, False, None), representing this field is batch_x / is batch_y / neither. - - """ - for inst in data_set: - inst.fields[field_name].is_target = new_target - - -class ProgressBar: + for ins in self: + ins.rename_field(old_name, new_name) + return self - def __init__(self, count=0, total=0, width=100): - self.count = count - self.total = total - self.width = width + def set_target(self, **fields): + """Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged. - def move(self): - self.count += 1 - progress = self.width * self.count // self.total - sys.stdout.write('{0:3}/{1:3}: '.format(self.count, self.total)) - sys.stdout.write('#' * progress + '-' * (self.width - progress) + '\r') - if progress == self.width: - sys.stdout.write('\n') - sys.stdout.flush() + :param key-value pairs for field-name and `is_target` value(True, False or None). + """ + for ins in self: + ins.set_target(**fields) + return self + + def update_vocab(self, **name_vocab): + for field_name, vocab in name_vocab.items(): + for ins in self: + vocab.update(ins[field_name].contents()) + return self + + def set_origin_len(self, origin_field, origin_len_name=None): + if origin_field is None: + self.origin_len = None + else: + self.origin_len = (origin_field + "_origin_len", origin_field) \ + if origin_len_name is None else (origin_len_name, origin_field) + return self diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index 770482ea..64aafdd3 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -18,6 +18,8 @@ class Field(object): def to_tensor(self, padding_length): raise NotImplementedError + def contents(self): + raise NotImplementedError class TextField(Field): def __init__(self, text, is_target): @@ -57,6 +59,8 @@ class TextField(Field): pads = [0] * (padding_length - self.get_length()) return torch.LongTensor(self._index + pads) + def contents(self): + return self.text.copy() class LabelField(Field): """The Field representing a single label. Can be a string or integer. @@ -92,6 +96,8 @@ class LabelField(Field): else: return torch.LongTensor([self._index]) + def contents(self): + return [self.label] class SeqLabelField(Field): def __init__(self, label_seq, is_target=True): @@ -122,6 +128,8 @@ class SeqLabelField(Field): else: return torch.LongTensor(self._index + pads) + def contents(self): + return self.label_seq.copy() if __name__ == "__main__": tf = TextField("test the code".split(), is_target=False) diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py index ebf01912..b01c336b 100644 --- a/fastNLP/core/instance.py +++ b/fastNLP/core/instance.py @@ -1,3 +1,5 @@ +import torch + class Instance(object): """An instance which consists of Fields is an example in the DataSet. @@ -10,6 +12,28 @@ class Instance(object): def add_field(self, field_name, field): self.fields[field_name] = field + return self + + def rename_field(self, old_name, new_name): + if old_name in self.fields: + self.fields[new_name] = self.fields.pop(old_name) + if old_name in self.indexes: + self.indexes[new_name] = self.indexes.pop(old_name) + else: + print("error, no such field: {}".format(old_name)) + return self + + def set_target(self, **fields): + for name, val in fields.items(): + if name in self.fields: + self.fields[name].is_target = val + return self + + def __getitem__(self, name): + if name in self.fields: + return self.fields[name] + else: + raise KeyError("{} not found".format(name)) def get_length(self): """Fetch the length of all fields in the instance. @@ -24,6 +48,7 @@ class Instance(object): """use `vocab` to index certain field """ self.indexes[field_name] = self.fields[field_name].index(vocab) + return self def index_all(self, vocab): """use `vocab` to index all fields @@ -35,7 +60,7 @@ class Instance(object): self.indexes = indexes return indexes - def to_tensor(self, padding_length: dict): + def to_tensor(self, padding_length: dict, origin_len=None): """Convert instance to tensor. :param padding_length: dict of (str: int), which means (field name: padding_length of this field) @@ -53,4 +78,7 @@ class Instance(object): else: # is_target is None continue + if origin_len is not None: + name, field_name = origin_len + tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()]) return tensor_x, tensor_y diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py index 14c4e8c1..c5d22df4 100644 --- a/fastNLP/core/predictor.py +++ b/fastNLP/core/predictor.py @@ -2,9 +2,9 @@ import numpy as np import torch from fastNLP.core.batch import Batch -from fastNLP.core.dataset import create_dataset_from_lists from fastNLP.core.preprocess import load_pickle from fastNLP.core.sampler import SequentialSampler +from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq2tag_dataset, convert_seq_dataset class Predictor(object): @@ -79,7 +79,8 @@ class Predictor(object): :return data_set: a DataSet instance. """ assert isinstance(data, list) - return create_dataset_from_lists(data, self.word_vocab, has_target=False) + data = convert_seq_dataset(data) + data.index_field("word_seq", self.word_vocab) class SeqLabelInfer(Predictor): diff --git a/fastNLP/fastnlp.py b/fastNLP/fastnlp.py index 0bd56d18..816db82d 100644 --- a/fastNLP/fastnlp.py +++ b/fastNLP/fastnlp.py @@ -1,6 +1,7 @@ import os -from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet +from fastNLP.core.dataset import DataSet +from fastNLP.loader.dataset_loader import convert_seq_dataset from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer from fastNLP.core.preprocess import load_pickle from fastNLP.loader.config_loader import ConfigLoader, ConfigSection @@ -178,13 +179,10 @@ class FastNLP(object): :param infer_input: 2-D lists of strings :return data_set: a DataSet object """ - if self.infer_type == "seq_label": - data_set = SeqLabelDataSet() - data_set.load_raw(infer_input, {"word_vocab": self.word_vocab}) - return data_set - elif self.infer_type == "text_class": - data_set = TextClassifyDataSet() - data_set.load_raw(infer_input, {"word_vocab": self.word_vocab}) + if self.infer_type in ["seq_label", "text_class"]: + data_set = convert_seq_dataset(infer_input) + data_set.index_field("word_seq", self.word_vocab) + data_set.set_origin_len("word_seq") return data_set else: raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type)) diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py index a6a0fb77..4d3674e2 100644 --- a/fastNLP/loader/dataset_loader.py +++ b/fastNLP/loader/dataset_loader.py @@ -1,6 +1,71 @@ import os from fastNLP.loader.base_loader import BaseLoader +from fastNLP.core.dataset import DataSet +from fastNLP.core.instance import Instance +from fastNLP.core.field import * + +def convert_seq_dataset(data): + """Create an DataSet instance that contains no labels. + + :param data: list of list of strings, [num_examples, *]. + :: + [ + [word_11, word_12, ...], + ... + ] + + :return: a DataSet. + """ + dataset = DataSet() + for word_seq in data: + x = TextField(word_seq, is_target=False) + dataset.append(Instance(word_seq=x)) + return dataset + +def convert_seq2tag_dataset(data): + """Convert list of data into DataSet + + :param data: list of list of strings, [num_examples, *]. + :: + [ + [ [word_11, word_12, ...], label_1 ], + [ [word_21, word_22, ...], label_2 ], + ... + ] + + :return: a DataSet. + """ + dataset = DataSet() + for sample in data: + word_seq, label = sample[0], sample[1] + ins = Instance() + ins.add_field("word_seq", TextField(word_seq, is_target=False)) \ + .add_field("label", LabelField(label, is_target=True)) + dataset.append(ins) + return dataset + +def convert_seq2seq_dataset(data): + """Convert list of data into DataSet + + :param data: list of list of strings, [num_examples, *]. + :: + [ + [ [word_11, word_12, ...], [label_1, label_1, ...] ], + [ [word_21, word_22, ...], [label_2, label_1, ...] ], + ... + ] + + :return: a DataSet. + """ + dataset = DataSet() + for sample in data: + word_seq, label_seq = sample[0], sample[1] + ins = Instance() + ins.add_field("word_seq", TextField(word_seq, is_target=False)) \ + .add_field("label_seq", TextField(label_seq, is_target=True)) + dataset.append(ins) + return dataset class DataSetLoader(BaseLoader): @@ -48,7 +113,8 @@ class POSDataSetLoader(DataSetLoader): """ with open(data_path, "r", encoding="utf-8") as f: lines = f.readlines() - return self.parse(lines) + data = self.parse(lines) + return self.convert(data) @staticmethod def parse(lines): @@ -75,6 +141,10 @@ class POSDataSetLoader(DataSetLoader): data.append([words, labels]) return data + def convert(self, data): + """Convert lists of strings into Instances with Fields. + """ + return convert_seq2seq_dataset(data) class TokenizeDataSetLoader(DataSetLoader): """ @@ -84,8 +154,7 @@ class TokenizeDataSetLoader(DataSetLoader): def __init__(self): super(TokenizeDataSetLoader, self).__init__() - @staticmethod - def load(data_path, max_seq_len=32): + def load(self, data_path, max_seq_len=32): """ load pku dataset for Chinese word segmentation CWS (Chinese Word Segmentation) pku training dataset format: @@ -130,7 +199,10 @@ class TokenizeDataSetLoader(DataSetLoader): seq_words = words[start:end] seq_labels = labels[start:end] data.append([seq_words, seq_labels]) - return data + return self.convert(data) + + def convert(self, data): + return convert_seq2seq_dataset(data) class ClassDataSetLoader(DataSetLoader): @@ -143,7 +215,8 @@ class ClassDataSetLoader(DataSetLoader): assert os.path.exists(data_path) with open(data_path, "r", encoding="utf-8") as f: lines = f.readlines() - return self.parse(lines) + data = self.parse(lines) + return self.convert(data) @staticmethod def parse(lines): @@ -166,16 +239,18 @@ class ClassDataSetLoader(DataSetLoader): dataset.append(sentence) return dataset + def convert(self, data): + return convert_seq2tag_dataset(data) + class ConllLoader(DataSetLoader): """loader for conll format files""" - def __int__(self, data_path): + def __init__(self): """ :param str data_path: the path to the conll data set """ super(ConllLoader, self).__init__() - self.data_set = self.parse(self.load(data_path)) def load(self, data_path): """ @@ -183,7 +258,8 @@ class ConllLoader(DataSetLoader): """ with open(data_path, "r", encoding="utf-8") as f: lines = f.readlines() - return lines + data = self.parse(lines) + return self.convert(data) @staticmethod def parse(lines): @@ -204,6 +280,9 @@ class ConllLoader(DataSetLoader): tokens.append(line.split()) return sentences + def convert(self, data): + pass + class LMDataSetLoader(DataSetLoader): """Language Model Dataset Loader @@ -222,7 +301,8 @@ class LMDataSetLoader(DataSetLoader): with open(data_path, "r", encoding="utf=8") as f: text = " ".join(f.readlines()) tokens = text.strip().split() - return self.sentence_cut(tokens) + data = self.sentence_cut(tokens) + return self.convert(data) def sentence_cut(self, tokens, sentence_length=15): start_idx = 0 @@ -236,6 +316,8 @@ class LMDataSetLoader(DataSetLoader): data_set.append([x, y]) return data_set + def convert(self, data): + pass class PeopleDailyCorpusLoader(DataSetLoader): """ @@ -286,3 +368,5 @@ class PeopleDailyCorpusLoader(DataSetLoader): ner_examples.append([sent_words, sent_ner]) return pos_tag_examples, ner_examples + def convert(self, data): + pass diff --git a/reproduction/chinese_word_segment/run.py b/reproduction/chinese_word_segment/run.py index f940c5b8..df597942 100644 --- a/reproduction/chinese_word_segment/run.py +++ b/reproduction/chinese_word_segment/run.py @@ -12,7 +12,7 @@ from fastNLP.loader.model_loader import ModelLoader from fastNLP.core.tester import SeqLabelTester from fastNLP.models.sequence_modeling import AdvSeqLabel from fastNLP.core.predictor import SeqLabelInfer -from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target +from fastNLP.core.dataset import DataSet from fastNLP.core.preprocess import save_pickle from fastNLP.core.metrics import SeqLabelEvaluator diff --git a/test/core/test_batch.py b/test/core/test_batch.py index 5de91da8..826167ac 100644 --- a/test/core/test_batch.py +++ b/test/core/test_batch.py @@ -3,7 +3,7 @@ import unittest import torch from fastNLP.core.batch import Batch -from fastNLP.core.dataset import DataSet, create_dataset_from_lists +from fastNLP.core.dataset import DataSet from fastNLP.core.field import TextField, LabelField from fastNLP.core.instance import Instance @@ -51,14 +51,3 @@ class TestCase1(unittest.TestCase): self.assertTrue(isinstance(batch_x["text"], torch.LongTensor)) self.assertTrue(isinstance(batch_y, dict)) self.assertTrue(isinstance(batch_y["label"], torch.LongTensor)) - - -class TestCase2(unittest.TestCase): - def test(self): - data = DataSet() - for text in texts: - x = TextField(text, is_target=False) - ins = Instance(text=x) - data.append(ins) - data_set = create_dataset_from_lists(texts, vocab, has_target=False) - self.assertTrue(type(data) == type(data_set)) diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index 9b79c840..c30cd37f 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -1,7 +1,6 @@ import unittest -from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet -from fastNLP.core.dataset import create_dataset_from_lists +from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq_dataset class TestDataSet(unittest.TestCase): @@ -19,8 +18,9 @@ class TestDataSet(unittest.TestCase): label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} def test_case_1(self): - data_set = create_dataset_from_lists(self.labeled_data_list, self.word_vocab, has_target=True, - label_vocab=self.label_vocab) + data_set = convert_seq2seq_dataset(self.labeled_data_list) + data_set.index_field("word_seq", self.word_vocab) + data_set.index_field("label_seq", self.label_vocab) self.assertEqual(len(data_set), len(self.labeled_data_list)) self.assertTrue(len(data_set) > 0) self.assertTrue(hasattr(data_set[0], "fields")) @@ -39,7 +39,8 @@ class TestDataSet(unittest.TestCase): [self.label_vocab[c] for c in self.labeled_data_list[0][1]]) def test_case_2(self): - data_set = create_dataset_from_lists(self.unlabeled_data_list, self.word_vocab, has_target=False) + data_set = convert_seq_dataset(self.unlabeled_data_list) + data_set.index_field("word_seq", self.word_vocab) self.assertEqual(len(data_set), len(self.unlabeled_data_list)) self.assertTrue(len(data_set) > 0) @@ -51,193 +52,3 @@ class TestDataSet(unittest.TestCase): self.assertEqual(data_set[0].fields["word_seq"]._index, [self.word_vocab[c] for c in self.unlabeled_data_list[0]]) - -class TestDataSetConvertion(unittest.TestCase): - labeled_data_list = [ - [["a", "b", "e", "d"], ["1", "2", "3", "4"]], - [["a", "b", "e", "d"], ["1", "2", "3", "4"]], - [["a", "b", "e", "d"], ["1", "2", "3", "4"]], - ] - unlabeled_data_list = [ - ["a", "b", "e", "d"], - ["a", "b", "e", "d"], - ["a", "b", "e", "d"] - ] - word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} - label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} - - def test_case_1(self): - def loader(path): - labeled_data_list = [ - [["a", "b", "e", "d"], ["1", "2", "3", "4"]], - [["a", "b", "e", "d"], ["1", "2", "3", "4"]], - [["a", "b", "e", "d"], ["1", "2", "3", "4"]], - ] - return labeled_data_list - - data_set = SeqLabelDataSet(load_func=loader) - data_set.load("any_path") - - self.assertEqual(len(data_set), len(self.labeled_data_list)) - self.assertTrue(len(data_set) > 0) - self.assertTrue(hasattr(data_set[0], "fields")) - self.assertTrue("word_seq" in data_set[0].fields) - - self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) - self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) - self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) - - self.assertTrue("truth" in data_set[0].fields) - self.assertTrue(hasattr(data_set[0].fields["truth"], "text")) - self.assertTrue(hasattr(data_set[0].fields["truth"], "_index")) - self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1]) - - self.assertTrue("word_seq_origin_len" in data_set[0].fields) - - def test_case_2(self): - def loader(path): - unlabeled_data_list = [ - ["a", "b", "e", "d"], - ["a", "b", "e", "d"], - ["a", "b", "e", "d"] - ] - return unlabeled_data_list - - data_set = SeqLabelDataSet(load_func=loader) - data_set.load("any_path", vocabs={"word_vocab": self.word_vocab}, infer=True) - - self.assertEqual(len(data_set), len(self.labeled_data_list)) - self.assertTrue(len(data_set) > 0) - self.assertTrue(hasattr(data_set[0], "fields")) - self.assertTrue("word_seq" in data_set[0].fields) - self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) - self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) - self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) - self.assertEqual(data_set[0].fields["word_seq"]._index, - [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) - - self.assertTrue("word_seq_origin_len" in data_set[0].fields) - - def test_case_3(self): - def loader(path): - labeled_data_list = [ - [["a", "b", "e", "d"], ["1", "2", "3", "4"]], - [["a", "b", "e", "d"], ["1", "2", "3", "4"]], - [["a", "b", "e", "d"], ["1", "2", "3", "4"]], - ] - return labeled_data_list - - data_set = SeqLabelDataSet(load_func=loader) - data_set.load("any_path", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab}) - - self.assertEqual(len(data_set), len(self.labeled_data_list)) - self.assertTrue(len(data_set) > 0) - self.assertTrue(hasattr(data_set[0], "fields")) - self.assertTrue("word_seq" in data_set[0].fields) - self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) - self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) - self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) - self.assertEqual(data_set[0].fields["word_seq"]._index, - [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) - - self.assertTrue("truth" in data_set[0].fields) - self.assertTrue(hasattr(data_set[0].fields["truth"], "text")) - self.assertTrue(hasattr(data_set[0].fields["truth"], "_index")) - self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1]) - self.assertEqual(data_set[0].fields["truth"]._index, - [self.label_vocab[c] for c in self.labeled_data_list[0][1]]) - - self.assertTrue("word_seq_origin_len" in data_set[0].fields) - - -class TestDataSetConvertionHHH(unittest.TestCase): - labeled_data_list = [ - [["a", "b", "e", "d"], "A"], - [["a", "b", "e", "d"], "C"], - [["a", "b", "e", "d"], "B"], - ] - unlabeled_data_list = [ - ["a", "b", "e", "d"], - ["a", "b", "e", "d"], - ["a", "b", "e", "d"] - ] - word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} - label_vocab = {"A": 1, "B": 2, "C": 3} - - def test_case_1(self): - def loader(path): - labeled_data_list = [ - [["a", "b", "e", "d"], "A"], - [["a", "b", "e", "d"], "C"], - [["a", "b", "e", "d"], "B"], - ] - return labeled_data_list - - data_set = TextClassifyDataSet(load_func=loader) - data_set.load("xxx") - - self.assertEqual(len(data_set), len(self.labeled_data_list)) - self.assertTrue(len(data_set) > 0) - self.assertTrue(hasattr(data_set[0], "fields")) - self.assertTrue("word_seq" in data_set[0].fields) - - self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) - self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) - self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) - - self.assertTrue("label" in data_set[0].fields) - self.assertTrue(hasattr(data_set[0].fields["label"], "label")) - self.assertTrue(hasattr(data_set[0].fields["label"], "_index")) - self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1]) - - def test_case_2(self): - def loader(path): - labeled_data_list = [ - [["a", "b", "e", "d"], "A"], - [["a", "b", "e", "d"], "C"], - [["a", "b", "e", "d"], "B"], - ] - return labeled_data_list - - data_set = TextClassifyDataSet(load_func=loader) - data_set.load("xxx", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab}) - - self.assertEqual(len(data_set), len(self.labeled_data_list)) - self.assertTrue(len(data_set) > 0) - self.assertTrue(hasattr(data_set[0], "fields")) - self.assertTrue("word_seq" in data_set[0].fields) - - self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) - self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) - self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) - self.assertEqual(data_set[0].fields["word_seq"]._index, - [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) - - self.assertTrue("label" in data_set[0].fields) - self.assertTrue(hasattr(data_set[0].fields["label"], "label")) - self.assertTrue(hasattr(data_set[0].fields["label"], "_index")) - self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1]) - self.assertEqual(data_set[0].fields["label"]._index, self.label_vocab[self.labeled_data_list[0][1]]) - - def test_case_3(self): - def loader(path): - unlabeled_data_list = [ - ["a", "b", "e", "d"], - ["a", "b", "e", "d"], - ["a", "b", "e", "d"] - ] - return unlabeled_data_list - - data_set = TextClassifyDataSet(load_func=loader) - data_set.load("xxx", vocabs={"word_vocab": self.word_vocab}, infer=True) - - self.assertEqual(len(data_set), len(self.labeled_data_list)) - self.assertTrue(len(data_set) > 0) - self.assertTrue(hasattr(data_set[0], "fields")) - self.assertTrue("word_seq" in data_set[0].fields) - - self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) - self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) - self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) - self.assertEqual(data_set[0].fields["word_seq"]._index, - [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) diff --git a/test/core/test_predictor.py b/test/core/test_predictor.py index 8bd5a7ab..2fb2c090 100644 --- a/test/core/test_predictor.py +++ b/test/core/test_predictor.py @@ -1,11 +1,12 @@ import os import unittest -from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet +from fastNLP.core.dataset import DataSet from fastNLP.core.predictor import Predictor from fastNLP.core.preprocess import save_pickle from fastNLP.core.vocabulary import Vocabulary from fastNLP.loader.base_loader import BaseLoader +from fastNLP.loader.dataset_loader import convert_seq_dataset from fastNLP.models.cnn_text_classification import CNNText from fastNLP.models.sequence_modeling import SeqLabeling @@ -42,8 +43,8 @@ class TestPredictor(unittest.TestCase): predictor = Predictor("./save/", pre.text_classify_post_processor) # Load infer data - infer_data_set = TextClassifyDataSet(load_func=BaseLoader.load) - infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx}) + infer_data_set = convert_seq_dataset(infer_data) + infer_data_set.index_field("word_seq", vocab) results = predictor.predict(network=model, data=infer_data_set) @@ -54,14 +55,11 @@ class TestPredictor(unittest.TestCase): self.assertTrue(isinstance(res, str)) self.assertTrue(res in class_vocab.word2idx) - del model, predictor, infer_data_set + del model, predictor model = SeqLabeling(model_args) predictor = Predictor("./save/", pre.seq_label_post_processor) - infer_data_set = SeqLabelDataSet(load_func=BaseLoader.load) - infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx}) - results = predictor.predict(network=model, data=infer_data_set) self.assertTrue(isinstance(results, list)) self.assertEqual(len(results), len(infer_data)) diff --git a/test/core/test_tester.py b/test/core/test_tester.py index 1118f284..5ae67e3f 100644 --- a/test/core/test_tester.py +++ b/test/core/test_tester.py @@ -1,7 +1,7 @@ import os import unittest -from fastNLP.core.dataset import SeqLabelDataSet +from fastNLP.core.dataset import DataSet from fastNLP.core.metrics import SeqLabelEvaluator from fastNLP.core.field import TextField, LabelField from fastNLP.core.instance import Instance @@ -35,7 +35,7 @@ class TestTester(unittest.TestCase): vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} - data_set = SeqLabelDataSet() + data_set = DataSet() for example in train_data: text, label = example[0], example[1] x = TextField(text, False) diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py index b4a9178f..98ef879f 100644 --- a/test/core/test_trainer.py +++ b/test/core/test_trainer.py @@ -1,7 +1,7 @@ import os import unittest -from fastNLP.core.dataset import SeqLabelDataSet +from fastNLP.core.dataset import DataSet from fastNLP.core.metrics import SeqLabelEvaluator from fastNLP.core.field import TextField, LabelField from fastNLP.core.instance import Instance @@ -36,7 +36,7 @@ class TestTrainer(unittest.TestCase): vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} - data_set = SeqLabelDataSet() + data_set = DataSet() for example in train_data: text, label = example[0], example[1] x = TextField(text, False) diff --git a/test/model/test_cws.py b/test/model/test_cws.py index 0c43bbff..aaadce2d 100644 --- a/test/model/test_cws.py +++ b/test/model/test_cws.py @@ -1,6 +1,7 @@ import os -from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target +from fastNLP.core.dataset import DataSet +from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.metrics import SeqLabelEvaluator from fastNLP.core.predictor import SeqLabelInfer from fastNLP.core.preprocess import save_pickle, load_pickle @@ -37,8 +38,8 @@ def infer(): print("model loaded!") # Load infer data - infer_data = SeqLabelDataSet(load_func=BaseLoader.load) - infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True) + infer_data = TokenizeDataSetLoader().load(data_infer_path) + infer_data.index_field("word_seq", word2index) # inference infer = SeqLabelInfer(pickle_path) @@ -52,13 +53,15 @@ def train_test(): ConfigLoader().load_config(config_path, {"POS_infer": train_args}) # define dataset - data_train = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load) - data_train.load(cws_data_path) - train_args["vocab_size"] = len(data_train.word_vocab) - train_args["num_classes"] = len(data_train.label_vocab) + data_train = TokenizeDataSetLoader().load(cws_data_path) + word_vocab = Vocabulary() + label_vocab = Vocabulary() + data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab) + train_args["vocab_size"] = len(word_vocab) + train_args["num_classes"] = len(label_vocab) - save_pickle(data_train.word_vocab, pickle_path, "word2id.pkl") - save_pickle(data_train.label_vocab, pickle_path, "label2id.pkl") + save_pickle(word_vocab, pickle_path, "word2id.pkl") + save_pickle(label_vocab, pickle_path, "label2id.pkl") # Trainer trainer = SeqLabelTrainer(**train_args.data) @@ -90,7 +93,7 @@ def train_test(): tester = SeqLabelTester(**test_args.data) # Start testing - change_field_is_target(data_train, "truth", True) + data_train.set_target(truth=True) tester.test(model, data_train) diff --git a/test/model/test_seq_label.py b/test/model/test_seq_label.py index ebb62f99..ba62b25b 100644 --- a/test/model/test_seq_label.py +++ b/test/model/test_seq_label.py @@ -1,6 +1,6 @@ import os -from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target +from fastNLP.core.dataset import DataSet from fastNLP.core.metrics import SeqLabelEvaluator from fastNLP.core.optimizer import Optimizer from fastNLP.core.preprocess import save_pickle @@ -25,8 +25,8 @@ def test_training(): ConfigLoader().load_config(config_dir, { "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) - data_set = SeqLabelDataSet() - data_set.load(data_path) + data_set = DataSet() + word_vocab = V data_train, data_dev = data_set.split(0.3, shuffle=True) model_args["vocab_size"] = len(data_set.word_vocab) model_args["num_classes"] = len(data_set.label_vocab) @@ -76,5 +76,5 @@ def test_training(): ) # Start testing with validation data - change_field_is_target(data_dev, "truth", True) + data_dev.set_target(truth=True) tester.test(model, data_dev)