Browse Source

add model.fit() method

tags/v0.2.0
FengZiYjun 6 years ago
parent
commit
671975a223
2 changed files with 21 additions and 3 deletions
  1. +6
    -0
      fastNLP/models/base_model.py
  2. +15
    -3
      test/model/seq_labeling.py

+ 6
- 0
fastNLP/models/base_model.py View File

@@ -1,5 +1,7 @@
import torch import torch


from fastNLP.core.trainer import Trainer



class BaseModel(torch.nn.Module): class BaseModel(torch.nn.Module):
"""Base PyTorch model for all models. """Base PyTorch model for all models.
@@ -8,6 +10,10 @@ class BaseModel(torch.nn.Module):
def __init__(self): def __init__(self):
super(BaseModel, self).__init__() super(BaseModel, self).__init__()


def fit(self, train_data, dev_data=None, **train_args):
trainer = Trainer(**train_args)
trainer.train(self, train_data, dev_data)



class Vocabulary(object): class Vocabulary(object):
"""A look-up table that allows you to access `Lexeme` objects. The `Vocab` """A look-up table that allows you to access `Lexeme` objects. The `Vocab`


+ 15
- 3
test/model/seq_labeling.py View File

@@ -1,9 +1,9 @@
import os import os
import sys import sys

sys.path.append("..") sys.path.append("..")
import argparse import argparse
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import BaseLoader from fastNLP.loader.dataset_loader import BaseLoader
from fastNLP.saver.model_saver import ModelSaver from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader from fastNLP.loader.model_loader import ModelLoader
@@ -82,6 +82,7 @@ def train_and_test():
save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl") save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl")
save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl") save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl")


"""
trainer = SeqLabelTrainer( trainer = SeqLabelTrainer(
epochs=trainer_args["epochs"], epochs=trainer_args["epochs"],
batch_size=trainer_args["batch_size"], batch_size=trainer_args["batch_size"],
@@ -92,12 +93,23 @@ def train_and_test():
model_name=model_name, model_name=model_name,
optimizer=Optimizer("SGD", lr=0.01, momentum=0.9), optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
) )
"""


# Model # Model
model = SeqLabeling(model_args) model = SeqLabeling(model_args)


model.fit(train_set, dev_set,
epochs=trainer_args["epochs"],
batch_size=trainer_args["batch_size"],
validate=False,
use_cuda=trainer_args["use_cuda"],
pickle_path=pickle_path,
save_best_dev=trainer_args["save_best_dev"],
model_name=model_name,
optimizer=Optimizer("SGD", lr=0.01, momentum=0.9))

# Start training # Start training
trainer.train(model, train_set, dev_set)
# trainer.train(model, train_set, dev_set)
print("Training finished!") print("Training finished!")


# Saver # Saver
@@ -105,7 +117,7 @@ def train_and_test():
saver.save_pytorch(model) saver.save_pytorch(model)
print("Model saved!") print("Model saved!")


del model, trainer
del model


change_field_is_target(dev_set, "truth", True) change_field_is_target(dev_set, "truth", True)




Loading…
Cancel
Save