diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py new file mode 100644 index 00000000..7db861af --- /dev/null +++ b/test/core/test_trainer.py @@ -0,0 +1,33 @@ +import os + +import torch.nn as nn +import unittest + +from fastNLP.core.trainer import SeqLabelTrainer +from fastNLP.core.loss import Loss +from fastNLP.core.optimizer import Optimizer +from fastNLP.models.sequence_modeling import SeqLabeling + +class TestTrainer(unittest.TestCase): + def test_case_1(self): + args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/", + "save_best_dev": True, "model_name": "default_model_name.pkl", + "loss": Loss(None), + "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), + "vocab_size": 20, + "word_emb_dim": 100, + "rnn_hidden_units": 100, + "num_classes": 3 + } + trainer = SeqLabelTrainer() + train_data = [ + [[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]], + [[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]], + [[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]], + [[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]], + [[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]], + [[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]], + ] + dev_data = train_data + model = SeqLabeling(args) + trainer.train(network=model, train_data=train_data, dev_data=dev_data) \ No newline at end of file