From 671975a22306b2a0db652eb22aa15596e15c17bb Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Sun, 7 Oct 2018 11:52:56 +0800 Subject: [PATCH] add model.fit() method --- fastNLP/models/base_model.py | 6 ++++++ test/model/seq_labeling.py | 18 +++++++++++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/fastNLP/models/base_model.py b/fastNLP/models/base_model.py index 0fcc14e1..b1ae828f 100644 --- a/fastNLP/models/base_model.py +++ b/fastNLP/models/base_model.py @@ -1,5 +1,7 @@ import torch +from fastNLP.core.trainer import Trainer + class BaseModel(torch.nn.Module): """Base PyTorch model for all models. @@ -8,6 +10,10 @@ class BaseModel(torch.nn.Module): def __init__(self): super(BaseModel, self).__init__() + def fit(self, train_data, dev_data=None, **train_args): + trainer = Trainer(**train_args) + trainer.train(self, train_data, dev_data) + class Vocabulary(object): """A look-up table that allows you to access `Lexeme` objects. The `Vocab` diff --git a/test/model/seq_labeling.py b/test/model/seq_labeling.py index 06c67fa7..64561a4b 100644 --- a/test/model/seq_labeling.py +++ b/test/model/seq_labeling.py @@ -1,9 +1,9 @@ import os import sys + sys.path.append("..") import argparse from fastNLP.loader.config_loader import ConfigLoader, ConfigSection -from fastNLP.core.trainer import SeqLabelTrainer from fastNLP.loader.dataset_loader import BaseLoader from fastNLP.saver.model_saver import ModelSaver from fastNLP.loader.model_loader import ModelLoader @@ -82,6 +82,7 @@ def train_and_test(): save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl") save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl") + """ trainer = SeqLabelTrainer( epochs=trainer_args["epochs"], batch_size=trainer_args["batch_size"], @@ -92,12 +93,23 @@ def train_and_test(): model_name=model_name, optimizer=Optimizer("SGD", lr=0.01, momentum=0.9), ) + """ # Model model = SeqLabeling(model_args) + model.fit(train_set, dev_set, + epochs=trainer_args["epochs"], + batch_size=trainer_args["batch_size"], + validate=False, + use_cuda=trainer_args["use_cuda"], + pickle_path=pickle_path, + save_best_dev=trainer_args["save_best_dev"], + model_name=model_name, + optimizer=Optimizer("SGD", lr=0.01, momentum=0.9)) + # Start training - trainer.train(model, train_set, dev_set) + # trainer.train(model, train_set, dev_set) print("Training finished!") # Saver @@ -105,7 +117,7 @@ def train_and_test(): saver.save_pytorch(model) print("Model saved!") - del model, trainer + del model change_field_is_target(dev_set, "truth", True)