* refine code style * set up unit tests for Batch, DataSet, FieldArray * remove a lot of out-of-date unit tests, to get testing passedtags/v0.2.0
| @@ -64,6 +64,7 @@ class DataSet(object): | |||
| """ | |||
| :param data: a dict or a list. If it is a dict, the key is the name of a field and the value is the field. | |||
| All values must be of the same length. | |||
| If it is a list, it must be a list of Instance objects. | |||
| """ | |||
| self.field_arrays = {} | |||
| @@ -23,8 +23,7 @@ class FieldArray(object): | |||
| self.dtype = None | |||
| def __repr__(self): | |||
| # TODO | |||
| return '{}: {}'.format(self.name, self.content.__repr__()) | |||
| return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | |||
| def append(self, val): | |||
| self.content.append(val) | |||
| @@ -11,7 +11,7 @@ class Instance(object): | |||
| def __init__(self, **fields): | |||
| """ | |||
| :param fields: a dict of (field name: field) | |||
| :param fields: a dict of (str: list). | |||
| """ | |||
| self.fields = fields | |||
| @@ -1,5 +1,6 @@ | |||
| import os | |||
| import _pickle as pickle | |||
| import os | |||
| class BaseLoader(object): | |||
| @@ -1,7 +1,6 @@ | |||
| import os | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.field import * | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.io.base_loader import BaseLoader | |||
| @@ -87,6 +86,7 @@ class DataSetLoader(BaseLoader): | |||
| """ | |||
| raise NotImplementedError | |||
| @DataSet.set_reader('read_raw') | |||
| class RawDataSetLoader(DataSetLoader): | |||
| def __init__(self): | |||
| @@ -102,6 +102,7 @@ class RawDataSetLoader(DataSetLoader): | |||
| def convert(self, data): | |||
| return convert_seq_dataset(data) | |||
| @DataSet.set_reader('read_pos') | |||
| class POSDataSetLoader(DataSetLoader): | |||
| """Dataset Loader for POS Tag datasets. | |||
| @@ -171,6 +172,7 @@ class POSDataSetLoader(DataSetLoader): | |||
| """ | |||
| return convert_seq2seq_dataset(data) | |||
| @DataSet.set_reader('read_tokenize') | |||
| class TokenizeDataSetLoader(DataSetLoader): | |||
| """ | |||
| @@ -230,6 +232,7 @@ class TokenizeDataSetLoader(DataSetLoader): | |||
| def convert(self, data): | |||
| return convert_seq2seq_dataset(data) | |||
| @DataSet.set_reader('read_class') | |||
| class ClassDataSetLoader(DataSetLoader): | |||
| """Loader for classification data sets""" | |||
| @@ -268,6 +271,7 @@ class ClassDataSetLoader(DataSetLoader): | |||
| def convert(self, data): | |||
| return convert_seq2tag_dataset(data) | |||
| @DataSet.set_reader('read_conll') | |||
| class ConllLoader(DataSetLoader): | |||
| """loader for conll format files""" | |||
| @@ -309,6 +313,7 @@ class ConllLoader(DataSetLoader): | |||
| def convert(self, data): | |||
| pass | |||
| @DataSet.set_reader('read_lm') | |||
| class LMDataSetLoader(DataSetLoader): | |||
| """Language Model Dataset Loader | |||
| @@ -345,6 +350,7 @@ class LMDataSetLoader(DataSetLoader): | |||
| def convert(self, data): | |||
| pass | |||
| @DataSet.set_reader('read_people_daily') | |||
| class PeopleDailyCorpusLoader(DataSetLoader): | |||
| """ | |||
| @@ -1,6 +1,9 @@ | |||
| import unittest | |||
| import numpy as np | |||
| from fastNLP.core.batch import Batch | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.dataset import construct_dataset | |||
| from fastNLP.core.sampler import SequentialSampler | |||
| @@ -10,9 +13,21 @@ class TestCase1(unittest.TestCase): | |||
| dataset = construct_dataset( | |||
| [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | |||
| dataset.set_target() | |||
| batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), use_cuda=False) | |||
| batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
| cnt = 0 | |||
| for _, _ in batch: | |||
| cnt += 1 | |||
| self.assertEqual(cnt, 10) | |||
| def test_dataset_batching(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| ds.set_input(x=True) | |||
| ds.set_target(y=True) | |||
| iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
| for x, y in iter: | |||
| self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray)) | |||
| self.assertEqual(len(x["x"]), 4) | |||
| self.assertEqual(len(y["y"]), 4) | |||
| self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4]) | |||
| self.assertListEqual(list(y["y"][-1]), [5, 6]) | |||
| @@ -1,20 +1,75 @@ | |||
| import unittest | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.instance import Instance | |||
| class TestDataSet(unittest.TestCase): | |||
| def test_case_1(self): | |||
| ds = DataSet() | |||
| ds.add_field(name="xx", fields=["a", "b", "e", "d"]) | |||
| def test_init_v1(self): | |||
| ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) | |||
| self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||
| self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||
| self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||
| self.assertTrue("xx" in ds.field_arrays) | |||
| self.assertEqual(len(ds.field_arrays["xx"]), 4) | |||
| self.assertEqual(ds.get_length(), 4) | |||
| self.assertEqual(ds.get_fields(), ds.field_arrays) | |||
| def test_init_v2(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||
| self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||
| self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||
| try: | |||
| ds.add_field(name="yy", fields=["x", "y", "z", "w", "f"]) | |||
| except BaseException as e: | |||
| self.assertTrue(isinstance(e, AssertionError)) | |||
| def test_init_assert(self): | |||
| with self.assertRaises(AssertionError): | |||
| _ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100}) | |||
| with self.assertRaises(AssertionError): | |||
| _ = DataSet([[1, 2, 3, 4]] * 10) | |||
| with self.assertRaises(ValueError): | |||
| _ = DataSet(0.00001) | |||
| def test_append(self): | |||
| dd = DataSet() | |||
| for _ in range(3): | |||
| dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6])) | |||
| self.assertEqual(len(dd), 3) | |||
| self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) | |||
| self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) | |||
| def test_add_append(self): | |||
| dd = DataSet() | |||
| dd.add_field("x", [[1, 2, 3]] * 10) | |||
| dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||
| dd.add_field("z", [[5, 6]] * 10) | |||
| self.assertEqual(len(dd), 10) | |||
| self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10) | |||
| self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) | |||
| self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) | |||
| def test_delete_field(self): | |||
| dd = DataSet() | |||
| dd.add_field("x", [[1, 2, 3]] * 10) | |||
| dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||
| dd.delete_field("x") | |||
| self.assertFalse("x" in dd.field_arrays) | |||
| self.assertTrue("y" in dd.field_arrays) | |||
| def test_getitem(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| ins_1, ins_0 = ds[0], ds[1] | |||
| self.assertTrue(isinstance(ins_1, DataSet.Instance) and isinstance(ins_0, DataSet.Instance)) | |||
| self.assertEqual(ins_1["x"], [1, 2, 3, 4]) | |||
| self.assertEqual(ins_1["y"], [5, 6]) | |||
| self.assertEqual(ins_0["x"], [1, 2, 3, 4]) | |||
| self.assertEqual(ins_0["y"], [5, 6]) | |||
| sub_ds = ds[:10] | |||
| self.assertTrue(isinstance(sub_ds, DataSet)) | |||
| self.assertEqual(len(sub_ds), 10) | |||
| field = ds["x"] | |||
| self.assertEqual(field, ds.field_arrays["x"]) | |||
| def test_apply(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx") | |||
| self.assertTrue("rx" in ds.field_arrays) | |||
| self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) | |||
| @@ -1,6 +1,22 @@ | |||
| import unittest | |||
| import numpy as np | |||
| from fastNLP.core.fieldarray import FieldArray | |||
| class TestFieldArray(unittest.TestCase): | |||
| def test(self): | |||
| pass | |||
| fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | |||
| self.assertEqual(len(fa), 5) | |||
| fa.append(6) | |||
| self.assertEqual(len(fa), 6) | |||
| self.assertEqual(fa[-1], 6) | |||
| self.assertEqual(fa[0], 1) | |||
| fa[-1] = 60 | |||
| self.assertEqual(fa[-1], 60) | |||
| self.assertEqual(fa.get(0), 1) | |||
| self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray)) | |||
| self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) | |||
| @@ -1,100 +0,0 @@ | |||
| import os | |||
| import sys | |||
| sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path | |||
| from fastNLP.core import metrics | |||
| # from sklearn import metrics as skmetrics | |||
| import unittest | |||
| from numpy import random | |||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||
| import torch | |||
| def generate_fake_label(low, high, size): | |||
| return random.randint(low, high, size), random.randint(low, high, size) | |||
| class TestEvaluator(unittest.TestCase): | |||
| def test_a(self): | |||
| evaluator = SeqLabelEvaluator() | |||
| pred = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]] | |||
| truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4])}] | |||
| ans = evaluator(pred, truth) | |||
| print(ans) | |||
| def test_b(self): | |||
| evaluator = SeqLabelEvaluator() | |||
| pred = [[1, 2, 3, 4, 5, 0, 0], [1, 2, 3, 4, 5, 0, 0]] | |||
| truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3, 0, 0])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4, 0, 0])}] | |||
| ans = evaluator(pred, truth) | |||
| print(ans) | |||
| class TestMetrics(unittest.TestCase): | |||
| delta = 1e-5 | |||
| # test for binary, multiclass, multilabel | |||
| data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)] | |||
| fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types] | |||
| def test_accuracy_score(self): | |||
| for y_true, y_pred in self.fake_data: | |||
| for normalize in [True, False]: | |||
| for sample_weight in [None, random.rand(y_true.shape[0])]: | |||
| test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | |||
| # ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | |||
| # self.assertAlmostEqual(test, ans, delta=self.delta) | |||
| def test_recall_score(self): | |||
| for y_true, y_pred in self.fake_data: | |||
| # print(y_true.shape) | |||
| labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||
| test = metrics.recall_score(y_true, y_pred, labels=labels, average=None) | |||
| if not isinstance(test, list): | |||
| test = list(test) | |||
| # ans = skmetrics.recall_score(y_true, y_pred,labels=labels, average=None) | |||
| # ans = list(ans) | |||
| # for a, b in zip(test, ans): | |||
| # # print('{}, {}'.format(a, b)) | |||
| # self.assertAlmostEqual(a, b, delta=self.delta) | |||
| # test binary | |||
| y_true, y_pred = generate_fake_label(0, 2, 1000) | |||
| test = metrics.recall_score(y_true, y_pred) | |||
| # ans = skmetrics.recall_score(y_true, y_pred) | |||
| # self.assertAlmostEqual(ans, test, delta=self.delta) | |||
| def test_precision_score(self): | |||
| for y_true, y_pred in self.fake_data: | |||
| # print(y_true.shape) | |||
| labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||
| test = metrics.precision_score(y_true, y_pred, labels=labels, average=None) | |||
| # ans = skmetrics.precision_score(y_true, y_pred,labels=labels, average=None) | |||
| # ans, test = list(ans), list(test) | |||
| # for a, b in zip(test, ans): | |||
| # # print('{}, {}'.format(a, b)) | |||
| # self.assertAlmostEqual(a, b, delta=self.delta) | |||
| # test binary | |||
| y_true, y_pred = generate_fake_label(0, 2, 1000) | |||
| test = metrics.precision_score(y_true, y_pred) | |||
| # ans = skmetrics.precision_score(y_true, y_pred) | |||
| # self.assertAlmostEqual(ans, test, delta=self.delta) | |||
| def test_f1_score(self): | |||
| for y_true, y_pred in self.fake_data: | |||
| # print(y_true.shape) | |||
| labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||
| test = metrics.f1_score(y_true, y_pred, labels=labels, average=None) | |||
| # ans = skmetrics.f1_score(y_true, y_pred,labels=labels, average=None) | |||
| # ans, test = list(ans), list(test) | |||
| # for a, b in zip(test, ans): | |||
| # # print('{}, {}'.format(a, b)) | |||
| # self.assertAlmostEqual(a, b, delta=self.delta) | |||
| # test binary | |||
| y_true, y_pred = generate_fake_label(0, 2, 1000) | |||
| test = metrics.f1_score(y_true, y_pred) | |||
| # ans = skmetrics.f1_score(y_true, y_pred) | |||
| # self.assertAlmostEqual(ans, test, delta=self.delta) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -1,77 +1,6 @@ | |||
| import os | |||
| import unittest | |||
| from fastNLP.core.predictor import Predictor | |||
| from fastNLP.core.utils import save_pickle | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.io.dataset_loader import convert_seq_dataset | |||
| from fastNLP.models.cnn_text_classification import CNNText | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| class TestPredictor(unittest.TestCase): | |||
| def test_seq_label(self): | |||
| model_args = { | |||
| "vocab_size": 10, | |||
| "word_emb_dim": 100, | |||
| "rnn_hidden_units": 100, | |||
| "num_classes": 5 | |||
| } | |||
| infer_data = [ | |||
| ['a', 'b', 'c', 'd', 'e'], | |||
| ['a', '@', 'c', 'd', 'e'], | |||
| ['a', 'b', '#', 'd', 'e'], | |||
| ['a', 'b', 'c', '?', 'e'], | |||
| ['a', 'b', 'c', 'd', '$'], | |||
| ['!', 'b', 'c', 'd', 'e'] | |||
| ] | |||
| vocab = Vocabulary() | |||
| vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||
| class_vocab = Vocabulary() | |||
| class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} | |||
| os.system("mkdir save") | |||
| save_pickle(class_vocab, "./save/", "label2id.pkl") | |||
| save_pickle(vocab, "./save/", "word2id.pkl") | |||
| model = CNNText(model_args) | |||
| import fastNLP.core.predictor as pre | |||
| predictor = Predictor("./save/", pre.text_classify_post_processor) | |||
| # Load infer data | |||
| infer_data_set = convert_seq_dataset(infer_data) | |||
| infer_data_set.index_field("word_seq", vocab) | |||
| results = predictor.predict(network=model, data=infer_data_set) | |||
| self.assertTrue(isinstance(results, list)) | |||
| self.assertGreater(len(results), 0) | |||
| self.assertEqual(len(results), len(infer_data)) | |||
| for res in results: | |||
| self.assertTrue(isinstance(res, str)) | |||
| self.assertTrue(res in class_vocab.word2idx) | |||
| del model, predictor | |||
| infer_data_set.set_origin_len("word_seq") | |||
| model = SeqLabeling(model_args) | |||
| predictor = Predictor("./save/", pre.seq_label_post_processor) | |||
| results = predictor.predict(network=model, data=infer_data_set) | |||
| self.assertTrue(isinstance(results, list)) | |||
| self.assertEqual(len(results), len(infer_data)) | |||
| for i in range(len(infer_data)): | |||
| res = results[i] | |||
| self.assertTrue(isinstance(res, list)) | |||
| self.assertEqual(len(res), len(infer_data[i])) | |||
| os.system("rm -rf save") | |||
| print("pickle path deleted") | |||
| class TestPredictor2(unittest.TestCase): | |||
| def test_text_classify(self): | |||
| # TODO | |||
| def test(self): | |||
| pass | |||
| @@ -1,57 +1,9 @@ | |||
| import os | |||
| import unittest | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.field import TextField, LabelField | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||
| from fastNLP.core.tester import Tester | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| data_name = "pku_training.utf8" | |||
| pickle_path = "data_for_tests" | |||
| class TestTester(unittest.TestCase): | |||
| def test_case_1(self): | |||
| model_args = { | |||
| "vocab_size": 10, | |||
| "word_emb_dim": 100, | |||
| "rnn_hidden_units": 100, | |||
| "num_classes": 5 | |||
| } | |||
| valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | |||
| "save_loss": True, "batch_size": 2, "pickle_path": "./save/", | |||
| "use_cuda": False, "print_every_step": 1, "evaluator": SeqLabelEvaluator()} | |||
| train_data = [ | |||
| [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||
| [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||
| [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||
| [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']], | |||
| [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']], | |||
| [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||
| ] | |||
| vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||
| label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | |||
| data_set = DataSet() | |||
| for example in train_data: | |||
| text, label = example[0], example[1] | |||
| x = TextField(text, False) | |||
| x_len = LabelField(len(text), is_target=False) | |||
| y = TextField(label, is_target=True) | |||
| ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||
| data_set.append(ins) | |||
| data_set.index_field("word_seq", vocab) | |||
| data_set.index_field("truth", label_vocab) | |||
| model = SeqLabeling(model_args) | |||
| tester = Tester(**valid_args) | |||
| tester.test(network=model, dev_data=data_set) | |||
| # If this can run, everything is OK. | |||
| os.system("rm -rf save") | |||
| print("pickle path deleted") | |||
| pass | |||
| @@ -1,57 +1,6 @@ | |||
| import os | |||
| import unittest | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.field import TextField, LabelField | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.core.loss import Loss | |||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||
| from fastNLP.core.optimizer import Optimizer | |||
| from fastNLP.core.trainer import Trainer | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| class TestTrainer(unittest.TestCase): | |||
| def test_case_1(self): | |||
| args = {"epochs": 3, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | |||
| "save_best_dev": True, "model_name": "default_model_name.pkl", | |||
| "loss": Loss("cross_entropy"), | |||
| "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | |||
| "vocab_size": 10, | |||
| "word_emb_dim": 100, | |||
| "rnn_hidden_units": 100, | |||
| "num_classes": 5, | |||
| "evaluator": SeqLabelEvaluator() | |||
| } | |||
| trainer = Trainer(**args) | |||
| train_data = [ | |||
| [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||
| [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||
| [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||
| [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']], | |||
| [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']], | |||
| [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||
| ] | |||
| vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||
| label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | |||
| data_set = DataSet() | |||
| for example in train_data: | |||
| text, label = example[0], example[1] | |||
| x = TextField(text, False) | |||
| x_len = LabelField(len(text), is_target=False) | |||
| y = TextField(label, is_target=False) | |||
| ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||
| data_set.append(ins) | |||
| data_set.index_field("word_seq", vocab) | |||
| data_set.index_field("truth", label_vocab) | |||
| model = SeqLabeling(args) | |||
| trainer.train(network=model, train_data=data_set, dev_data=data_set) | |||
| # If this can run, everything is OK. | |||
| os.system("rm -rf save") | |||
| print("pickle path deleted") | |||
| pass | |||
| @@ -1,53 +0,0 @@ | |||
| import configparser | |||
| import json | |||
| import os | |||
| import unittest | |||
| from fastNLP.io.config_loader import ConfigSection, ConfigLoader | |||
| class TestConfigLoader(unittest.TestCase): | |||
| def test_case_ConfigLoader(self): | |||
| def read_section_from_config(config_path, section_name): | |||
| dict = {} | |||
| if not os.path.exists(config_path): | |||
| raise FileNotFoundError("config file {} NOT found.".format(config_path)) | |||
| cfg = configparser.ConfigParser() | |||
| cfg.read(config_path) | |||
| if section_name not in cfg: | |||
| raise AttributeError("config file {} do NOT have section {}".format( | |||
| config_path, section_name | |||
| )) | |||
| gen_sec = cfg[section_name] | |||
| for s in gen_sec.keys(): | |||
| try: | |||
| val = json.loads(gen_sec[s]) | |||
| dict[s] = val | |||
| except Exception as e: | |||
| raise AttributeError("json can NOT load {} in section {}, config file {}".format( | |||
| s, section_name, config_path | |||
| )) | |||
| return dict | |||
| test_arg = ConfigSection() | |||
| ConfigLoader().load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) | |||
| section = read_section_from_config(os.path.join("./test/loader", "config"), "test") | |||
| for sec in section: | |||
| if (sec not in test_arg) or (section[sec] != test_arg[sec]): | |||
| raise AttributeError("ERROR") | |||
| for sec in test_arg.__dict__.keys(): | |||
| if (sec not in section) or (section[sec] != test_arg[sec]): | |||
| raise AttributeError("ERROR") | |||
| try: | |||
| not_exist = test_arg["NOT EXIST"] | |||
| except Exception as e: | |||
| pass | |||
| print("pass config test!") | |||
| @@ -7,7 +7,7 @@ from fastNLP.io.config_saver import ConfigSaver | |||
| class TestConfigSaver(unittest.TestCase): | |||
| def test_case_1(self): | |||
| config_file_dir = "test/loader/" | |||
| config_file_dir = "test/io/" | |||
| config_file_name = "config" | |||
| config_file_path = os.path.join(config_file_dir, config_file_name) | |||
| @@ -1,53 +0,0 @@ | |||
| import unittest | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.io.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \ | |||
| PeopleDailyCorpusLoader, ConllLoader | |||
| class TestDatasetLoader(unittest.TestCase): | |||
| def test_case_1(self): | |||
| data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF""" | |||
| lines = data.split("\n") | |||
| answer = POSDataSetLoader.parse(lines) | |||
| truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]] | |||
| self.assertListEqual(answer, truth, "POS Dataset Loader") | |||
| def test_case_TokenizeDatasetLoader(self): | |||
| loader = TokenizeDataSetLoader() | |||
| filepath = "./test/data_for_tests/cws_pku_utf_8" | |||
| data = loader.load(filepath, max_seq_len=32) | |||
| assert len(data) > 0 | |||
| data1 = DataSet() | |||
| data1.read_tokenize(filepath, max_seq_len=32) | |||
| assert len(data1) > 0 | |||
| print("pass TokenizeDataSetLoader test!") | |||
| def test_case_POSDatasetLoader(self): | |||
| loader = POSDataSetLoader() | |||
| filepath = "./test/data_for_tests/people.txt" | |||
| data = loader.load("./test/data_for_tests/people.txt") | |||
| datas = loader.load_lines("./test/data_for_tests/people.txt") | |||
| data1 = DataSet().read_pos(filepath) | |||
| assert len(data1) > 0 | |||
| print("pass POSDataSetLoader test!") | |||
| def test_case_LMDatasetLoader(self): | |||
| loader = LMDataSetLoader() | |||
| data = loader.load("./test/data_for_tests/charlm.txt") | |||
| datas = loader.load_lines("./test/data_for_tests/charlm.txt") | |||
| print("pass TokenizeDataSetLoader test!") | |||
| def test_PeopleDailyCorpusLoader(self): | |||
| loader = PeopleDailyCorpusLoader() | |||
| _, _ = loader.load("./test/data_for_tests/people_daily_raw.txt") | |||
| def test_ConllLoader(self): | |||
| loader = ConllLoader() | |||
| _ = loader.load("./test/data_for_tests/conll_example.txt") | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -1,31 +0,0 @@ | |||
| import os | |||
| import unittest | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.io.embed_loader import EmbedLoader | |||
| class TestEmbedLoader(unittest.TestCase): | |||
| glove_path = './test/data_for_tests/glove.6B.50d_test.txt' | |||
| pkl_path = './save' | |||
| raw_texts = ["i am a cat", | |||
| "this is a test of new batch", | |||
| "ha ha", | |||
| "I am a good boy .", | |||
| "This is the most beautiful girl ." | |||
| ] | |||
| texts = [text.strip().split() for text in raw_texts] | |||
| vocab = Vocabulary() | |||
| vocab.update(texts) | |||
| def test1(self): | |||
| emb, _ = EmbedLoader.load_embedding(50, self.glove_path, 'glove', self.vocab, self.pkl_path) | |||
| self.assertTrue(emb.shape[0] == (len(self.vocab))) | |||
| self.assertTrue(emb.shape[1] == 50) | |||
| os.remove(self.pkl_path) | |||
| def test2(self): | |||
| try: | |||
| _ = EmbedLoader.load_embedding(100, self.glove_path, 'glove', self.vocab, self.pkl_path) | |||
| self.fail(msg="load dismatch embedding") | |||
| except ValueError: | |||
| pass | |||
| @@ -1,150 +0,0 @@ | |||
| import os | |||
| import sys | |||
| sys.path.append("..") | |||
| import argparse | |||
| from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||
| from fastNLP.io.dataset_loader import BaseLoader | |||
| from fastNLP.io.model_saver import ModelSaver | |||
| from fastNLP.io.model_loader import ModelLoader | |||
| from fastNLP.core.tester import SeqLabelTester | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| from fastNLP.core.predictor import SeqLabelInfer | |||
| from fastNLP.core.optimizer import Optimizer | |||
| from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||
| from fastNLP.core.utils import save_pickle, load_pickle | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") | |||
| parser.add_argument("-t", "--train", type=str, default="../data_for_tests/people.txt", | |||
| help="path to the training data") | |||
| parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file") | |||
| parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model") | |||
| parser.add_argument("-i", "--infer", type=str, default="../data_for_tests/people_infer.txt", | |||
| help="data used for inference") | |||
| args = parser.parse_args() | |||
| pickle_path = args.save | |||
| model_name = args.model_name | |||
| config_dir = args.config | |||
| data_path = args.train | |||
| data_infer_path = args.infer | |||
| def infer(): | |||
| # Load infer configuration, the same as test | |||
| test_args = ConfigSection() | |||
| ConfigLoader().load_config(config_dir, {"POS_infer": test_args}) | |||
| # fetch dictionary size and number of labels from pickle files | |||
| word_vocab = load_pickle(pickle_path, "word2id.pkl") | |||
| label_vocab = load_pickle(pickle_path, "label2id.pkl") | |||
| test_args["vocab_size"] = len(word_vocab) | |||
| test_args["num_classes"] = len(label_vocab) | |||
| print("vocabularies loaded") | |||
| # Define the same model | |||
| model = SeqLabeling(test_args) | |||
| print("model defined") | |||
| # Dump trained parameters into the model | |||
| ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | |||
| print("model loaded!") | |||
| # Data Loader | |||
| infer_data = SeqLabelDataSet(load_func=BaseLoader.load) | |||
| infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True) | |||
| print("data set prepared") | |||
| # Inference interface | |||
| infer = SeqLabelInfer(pickle_path) | |||
| results = infer.predict(model, infer_data) | |||
| for res in results: | |||
| print(res) | |||
| print("Inference finished!") | |||
| def train_and_test(): | |||
| # Config Loader | |||
| trainer_args = ConfigSection() | |||
| model_args = ConfigSection() | |||
| ConfigLoader().load_config(config_dir, { | |||
| "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | |||
| data_set = SeqLabelDataSet() | |||
| data_set.load(data_path) | |||
| train_set, dev_set = data_set.split(0.3, shuffle=True) | |||
| model_args["vocab_size"] = len(data_set.word_vocab) | |||
| model_args["num_classes"] = len(data_set.label_vocab) | |||
| save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl") | |||
| save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl") | |||
| """ | |||
| trainer = SeqLabelTrainer( | |||
| epochs=trainer_args["epochs"], | |||
| batch_size=trainer_args["batch_size"], | |||
| validate=False, | |||
| use_cuda=trainer_args["use_cuda"], | |||
| pickle_path=pickle_path, | |||
| save_best_dev=trainer_args["save_best_dev"], | |||
| model_name=model_name, | |||
| optimizer=Optimizer("SGD", lr=0.01, momentum=0.9), | |||
| ) | |||
| """ | |||
| # Model | |||
| model = SeqLabeling(model_args) | |||
| model.fit(train_set, dev_set, | |||
| epochs=trainer_args["epochs"], | |||
| batch_size=trainer_args["batch_size"], | |||
| validate=False, | |||
| use_cuda=trainer_args["use_cuda"], | |||
| pickle_path=pickle_path, | |||
| save_best_dev=trainer_args["save_best_dev"], | |||
| model_name=model_name, | |||
| optimizer=Optimizer("SGD", lr=0.01, momentum=0.9)) | |||
| # Start training | |||
| # trainer.train(model, train_set, dev_set) | |||
| print("Training finished!") | |||
| # Saver | |||
| saver = ModelSaver(os.path.join(pickle_path, model_name)) | |||
| saver.save_pytorch(model) | |||
| print("Model saved!") | |||
| del model | |||
| change_field_is_target(dev_set, "truth", True) | |||
| # Define the same model | |||
| model = SeqLabeling(model_args) | |||
| # Dump trained parameters into the model | |||
| ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | |||
| print("model loaded!") | |||
| # Load test configuration | |||
| tester_args = ConfigSection() | |||
| ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||
| # Tester | |||
| tester = SeqLabelTester(batch_size=4, | |||
| use_cuda=False, | |||
| pickle_path=pickle_path, | |||
| model_name="seq_label_in_test.pkl", | |||
| evaluator=SeqLabelEvaluator() | |||
| ) | |||
| # Start testing with validation data | |||
| tester.test(model, dev_set) | |||
| print("model tested!") | |||
| if __name__ == "__main__": | |||
| train_and_test() | |||
| infer() | |||
| @@ -1,25 +0,0 @@ | |||
| import unittest | |||
| import numpy as np | |||
| import torch | |||
| from fastNLP.models.char_language_model import CharLM | |||
| class TestCharLM(unittest.TestCase): | |||
| def test_case_1(self): | |||
| char_emb_dim = 50 | |||
| word_emb_dim = 50 | |||
| vocab_size = 1000 | |||
| num_char = 24 | |||
| max_word_len = 21 | |||
| num_seq = 64 | |||
| seq_len = 32 | |||
| model = CharLM(char_emb_dim, word_emb_dim, vocab_size, num_char) | |||
| x = torch.from_numpy(np.random.randint(0, num_char, size=(num_seq, seq_len, max_word_len + 2))) | |||
| self.assertEqual(tuple(x.shape), (num_seq, seq_len, max_word_len + 2)) | |||
| y = model(x) | |||
| self.assertEqual(tuple(y.shape), (num_seq * seq_len, vocab_size)) | |||
| @@ -1,111 +0,0 @@ | |||
| import os | |||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||
| from fastNLP.core.predictor import Predictor | |||
| from fastNLP.core.tester import Tester | |||
| from fastNLP.core.trainer import Trainer | |||
| from fastNLP.core.utils import save_pickle, load_pickle | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||
| from fastNLP.io.dataset_loader import TokenizeDataSetLoader, RawDataSetLoader | |||
| from fastNLP.io.model_loader import ModelLoader | |||
| from fastNLP.io.model_saver import ModelSaver | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| data_name = "pku_training.utf8" | |||
| cws_data_path = "./test/data_for_tests/cws_pku_utf_8" | |||
| pickle_path = "./save/" | |||
| data_infer_path = "./test/data_for_tests/people_infer.txt" | |||
| config_path = "./test/data_for_tests/config" | |||
| def infer(): | |||
| # Load infer configuration, the same as test | |||
| test_args = ConfigSection() | |||
| ConfigLoader().load_config(config_path, {"POS_infer": test_args}) | |||
| # fetch dictionary size and number of labels from pickle files | |||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||
| test_args["vocab_size"] = len(word2index) | |||
| index2label = load_pickle(pickle_path, "label2id.pkl") | |||
| test_args["num_classes"] = len(index2label) | |||
| # Define the same model | |||
| model = SeqLabeling(test_args) | |||
| # Dump trained parameters into the model | |||
| ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||
| print("model loaded!") | |||
| # Load infer data | |||
| infer_data = RawDataSetLoader().load(data_infer_path) | |||
| infer_data.index_field("word_seq", word2index) | |||
| infer_data.set_origin_len("word_seq") | |||
| # inference | |||
| infer = Predictor(pickle_path) | |||
| results = infer.predict(model, infer_data) | |||
| print(results) | |||
| def train_test(): | |||
| # Config Loader | |||
| train_args = ConfigSection() | |||
| ConfigLoader().load_config(config_path, {"POS_infer": train_args}) | |||
| # define dataset | |||
| data_train = TokenizeDataSetLoader().load(cws_data_path) | |||
| word_vocab = Vocabulary() | |||
| label_vocab = Vocabulary() | |||
| data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab) | |||
| data_train.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab) | |||
| data_train.set_origin_len("word_seq") | |||
| data_train.rename_field("label_seq", "truth").set_target(truth=False) | |||
| train_args["vocab_size"] = len(word_vocab) | |||
| train_args["num_classes"] = len(label_vocab) | |||
| save_pickle(word_vocab, pickle_path, "word2id.pkl") | |||
| save_pickle(label_vocab, pickle_path, "label2id.pkl") | |||
| # Trainer | |||
| trainer = Trainer(**train_args.data) | |||
| # Model | |||
| model = SeqLabeling(train_args) | |||
| # Start training | |||
| trainer.train(model, data_train) | |||
| # Saver | |||
| saver = ModelSaver("./save/saved_model.pkl") | |||
| saver.save_pytorch(model) | |||
| del model, trainer | |||
| # Define the same model | |||
| model = SeqLabeling(train_args) | |||
| # Dump trained parameters into the model | |||
| ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||
| # Load test configuration | |||
| test_args = ConfigSection() | |||
| ConfigLoader().load_config(config_path, {"POS_infer": test_args}) | |||
| test_args["evaluator"] = SeqLabelEvaluator() | |||
| # Tester | |||
| tester = Tester(**test_args.data) | |||
| # Start testing | |||
| data_train.set_target(truth=True) | |||
| tester.test(model, data_train) | |||
| def test(): | |||
| os.makedirs("save", exist_ok=True) | |||
| train_test() | |||
| infer() | |||
| os.system("rm -rf save") | |||
| if __name__ == "__main__": | |||
| train_test() | |||
| infer() | |||
| @@ -1,90 +0,0 @@ | |||
| import os | |||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||
| from fastNLP.core.optimizer import Optimizer | |||
| from fastNLP.core.tester import Tester | |||
| from fastNLP.core.trainer import Trainer | |||
| from fastNLP.core.utils import save_pickle | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||
| from fastNLP.io.dataset_loader import TokenizeDataSetLoader | |||
| from fastNLP.io.model_loader import ModelLoader | |||
| from fastNLP.io.model_saver import ModelSaver | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| pickle_path = "./seq_label/" | |||
| model_name = "seq_label_model.pkl" | |||
| config_dir = "../data_for_tests/config" | |||
| data_path = "../data_for_tests/people.txt" | |||
| data_infer_path = "../data_for_tests/people_infer.txt" | |||
| def test_training(): | |||
| # Config Loader | |||
| trainer_args = ConfigSection() | |||
| model_args = ConfigSection() | |||
| ConfigLoader().load_config(config_dir, { | |||
| "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | |||
| data_set = TokenizeDataSetLoader().load(data_path) | |||
| word_vocab = Vocabulary() | |||
| label_vocab = Vocabulary() | |||
| data_set.update_vocab(word_seq=word_vocab, label_seq=label_vocab) | |||
| data_set.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab) | |||
| data_set.set_origin_len("word_seq") | |||
| data_set.rename_field("label_seq", "truth").set_target(truth=False) | |||
| data_train, data_dev = data_set.split(0.3, shuffle=True) | |||
| model_args["vocab_size"] = len(word_vocab) | |||
| model_args["num_classes"] = len(label_vocab) | |||
| save_pickle(word_vocab, pickle_path, "word2id.pkl") | |||
| save_pickle(label_vocab, pickle_path, "label2id.pkl") | |||
| trainer = Trainer( | |||
| epochs=trainer_args["epochs"], | |||
| batch_size=trainer_args["batch_size"], | |||
| validate=False, | |||
| use_cuda=False, | |||
| pickle_path=pickle_path, | |||
| save_best_dev=trainer_args["save_best_dev"], | |||
| model_name=model_name, | |||
| optimizer=Optimizer("SGD", lr=0.01, momentum=0.9), | |||
| ) | |||
| # Model | |||
| model = SeqLabeling(model_args) | |||
| # Start training | |||
| trainer.train(model, data_train, data_dev) | |||
| # Saver | |||
| saver = ModelSaver(os.path.join(pickle_path, model_name)) | |||
| saver.save_pytorch(model) | |||
| del model, trainer | |||
| # Define the same model | |||
| model = SeqLabeling(model_args) | |||
| # Dump trained parameters into the model | |||
| ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | |||
| # Load test configuration | |||
| tester_args = ConfigSection() | |||
| ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||
| # Tester | |||
| tester = Tester(batch_size=4, | |||
| use_cuda=False, | |||
| pickle_path=pickle_path, | |||
| model_name="seq_label_in_test.pkl", | |||
| evaluator=SeqLabelEvaluator() | |||
| ) | |||
| # Start testing with validation data | |||
| data_dev.set_target(truth=True) | |||
| tester.test(model, data_dev) | |||
| if __name__ == "__main__": | |||
| test_training() | |||
| @@ -1,107 +0,0 @@ | |||
| # Python: 3.5 | |||
| # encoding: utf-8 | |||
| import argparse | |||
| import os | |||
| import sys | |||
| sys.path.append("..") | |||
| from fastNLP.core.predictor import ClassificationInfer | |||
| from fastNLP.core.trainer import ClassificationTrainer | |||
| from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||
| from fastNLP.io.dataset_loader import ClassDataSetLoader | |||
| from fastNLP.io.model_loader import ModelLoader | |||
| from fastNLP.models.cnn_text_classification import CNNText | |||
| from fastNLP.io.model_saver import ModelSaver | |||
| from fastNLP.core.optimizer import Optimizer | |||
| from fastNLP.core.loss import Loss | |||
| from fastNLP.core.dataset import TextClassifyDataSet | |||
| from fastNLP.core.utils import save_pickle, load_pickle | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files") | |||
| parser.add_argument("-t", "--train", type=str, default="../data_for_tests/text_classify.txt", | |||
| help="path to the training data") | |||
| parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file") | |||
| parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model") | |||
| args = parser.parse_args() | |||
| save_dir = args.save | |||
| train_data_dir = args.train | |||
| model_name = args.model_name | |||
| config_dir = args.config | |||
| def infer(): | |||
| # load dataset | |||
| print("Loading data...") | |||
| word_vocab = load_pickle(save_dir, "word2id.pkl") | |||
| label_vocab = load_pickle(save_dir, "label2id.pkl") | |||
| print("vocabulary size:", len(word_vocab)) | |||
| print("number of classes:", len(label_vocab)) | |||
| infer_data = TextClassifyDataSet(load_func=ClassDataSetLoader.load) | |||
| infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}) | |||
| model_args = ConfigSection() | |||
| model_args["vocab_size"] = len(word_vocab) | |||
| model_args["num_classes"] = len(label_vocab) | |||
| ConfigLoader.load_config(config_dir, {"text_class_model": model_args}) | |||
| # construct model | |||
| print("Building model...") | |||
| cnn = CNNText(model_args) | |||
| # Dump trained parameters into the model | |||
| ModelLoader.load_pytorch(cnn, os.path.join(save_dir, model_name)) | |||
| print("model loaded!") | |||
| infer = ClassificationInfer(pickle_path=save_dir) | |||
| results = infer.predict(cnn, infer_data) | |||
| print(results) | |||
| def train(): | |||
| train_args, model_args = ConfigSection(), ConfigSection() | |||
| ConfigLoader.load_config(config_dir, {"text_class": train_args}) | |||
| # load dataset | |||
| print("Loading data...") | |||
| data = TextClassifyDataSet(load_func=ClassDataSetLoader.load) | |||
| data.load(train_data_dir) | |||
| print("vocabulary size:", len(data.word_vocab)) | |||
| print("number of classes:", len(data.label_vocab)) | |||
| save_pickle(data.word_vocab, save_dir, "word2id.pkl") | |||
| save_pickle(data.label_vocab, save_dir, "label2id.pkl") | |||
| model_args["num_classes"] = len(data.label_vocab) | |||
| model_args["vocab_size"] = len(data.word_vocab) | |||
| # construct model | |||
| print("Building model...") | |||
| model = CNNText(model_args) | |||
| # train | |||
| print("Training...") | |||
| trainer = ClassificationTrainer(epochs=train_args["epochs"], | |||
| batch_size=train_args["batch_size"], | |||
| validate=train_args["validate"], | |||
| use_cuda=train_args["use_cuda"], | |||
| pickle_path=save_dir, | |||
| save_best_dev=train_args["save_best_dev"], | |||
| model_name=model_name, | |||
| loss=Loss("cross_entropy"), | |||
| optimizer=Optimizer("SGD", lr=0.001, momentum=0.9)) | |||
| trainer.train(model, data) | |||
| print("Training finished!") | |||
| saver = ModelSaver(os.path.join(save_dir, model_name)) | |||
| saver.save_pytorch(model) | |||
| print("Model saved!") | |||
| if __name__ == "__main__": | |||
| train() | |||
| infer() | |||
| @@ -14,7 +14,7 @@ class TestGroupNorm(unittest.TestCase): | |||
| class TestLayerNormalization(unittest.TestCase): | |||
| def test_case_1(self): | |||
| ln = LayerNormalization(d_hid=5, eps=2e-3) | |||
| ln = LayerNormalization(layer_size=5, eps=2e-3) | |||
| x = torch.randn((20, 50, 5)) | |||
| y = ln(x) | |||