|
- import os
- import unittest
-
- from fastNLP.core.dataset import DataSet
- from fastNLP.core.metrics import SeqLabelEvaluator
- from fastNLP.core.field import TextField, LabelField
- from fastNLP.core.instance import Instance
- from fastNLP.core.loss import Loss
- from fastNLP.core.optimizer import Optimizer
- from fastNLP.core.trainer import SeqLabelTrainer
- from fastNLP.models.sequence_modeling import SeqLabeling
-
-
class TestTrainer(unittest.TestCase):
    """Smoke test for ``SeqLabelTrainer`` on a tiny in-memory data set."""

    # Directory handed to the trainer as ``pickle_path``; removed after the test.
    PICKLE_PATH = "./save/"

    @staticmethod
    def _remove_pickle_dir(path):
        """Delete the directory tree at *path*, ignoring a missing directory."""
        # Function-scope import: this helper is the only user of shutil.
        import shutil

        # Portable replacement for ``rm -rf``: no shell is spawned and a
        # missing directory is not an error.
        shutil.rmtree(path, ignore_errors=True)
        print("pickle path deleted")

    def test_case_1(self):
        """Train ``SeqLabeling`` on six toy examples; pass if nothing raises."""
        # Register the cleanup first so the directory is removed even when
        # training (or any set-up step below) fails.
        self.addCleanup(self._remove_pickle_dir, self.PICKLE_PATH)

        args = {
            "epochs": 3,
            "batch_size": 2,
            "validate": False,
            "use_cuda": False,
            "pickle_path": self.PICKLE_PATH,
            "save_best_dev": True,
            "model_name": "default_model_name.pkl",
            "loss": Loss("cross_entropy"),
            "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
            "vocab_size": 10,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 5,
            "evaluator": SeqLabelEvaluator(),
        }
        trainer = SeqLabelTrainer(**args)

        # Each example is a pair: [input tokens, label tokens].
        train_data = [
            [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
            [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
        ]
        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = DataSet()
        for text, label in train_data:
            x = TextField(text, False)
            x_len = LabelField(len(text), is_target=False)
            y = TextField(label, is_target=False)
            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
            data_set.append(ins)

        # Map tokens and labels to integer ids using the vocabularies above.
        data_set.index_field("word_seq", vocab)
        data_set.index_field("truth", label_vocab)

        model = SeqLabeling(args)

        trainer.train(network=model, train_data=data_set, dev_data=data_set)
        # If this can run, everything is OK.
|