Browse Source

add model.fit() method

tags/v0.2.0
FengZiYjun 6 years ago
parent
commit
671975a223
2 changed files with 21 additions and 3 deletions
  1. +6
    -0
      fastNLP/models/base_model.py
  2. +15
    -3
      test/model/seq_labeling.py

+ 6
- 0
fastNLP/models/base_model.py View File

@@ -1,5 +1,7 @@
import torch

from fastNLP.core.trainer import Trainer


class BaseModel(torch.nn.Module):
"""Base PyTorch model for all models.
@@ -8,6 +10,10 @@ class BaseModel(torch.nn.Module):
def __init__(self):
super(BaseModel, self).__init__()

def fit(self, train_data, dev_data=None, **train_args):
trainer = Trainer(**train_args)
trainer.train(self, train_data, dev_data)


class Vocabulary(object):
"""A look-up table that allows you to access `Lexeme` objects. The `Vocab`


+ 15
- 3
test/model/seq_labeling.py View File

@@ -1,9 +1,9 @@
import os
import sys

sys.path.append("..")
import argparse
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import BaseLoader
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
@@ -82,6 +82,7 @@ def train_and_test():
save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl")
save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl")

"""
trainer = SeqLabelTrainer(
epochs=trainer_args["epochs"],
batch_size=trainer_args["batch_size"],
@@ -92,12 +93,23 @@ def train_and_test():
model_name=model_name,
optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
)
"""

# Model
model = SeqLabeling(model_args)

model.fit(train_set, dev_set,
epochs=trainer_args["epochs"],
batch_size=trainer_args["batch_size"],
validate=False,
use_cuda=trainer_args["use_cuda"],
pickle_path=pickle_path,
save_best_dev=trainer_args["save_best_dev"],
model_name=model_name,
optimizer=Optimizer("SGD", lr=0.01, momentum=0.9))

# Start training
trainer.train(model, train_set, dev_set)
# trainer.train(model, train_set, dev_set)
print("Training finished!")

# Saver
@@ -105,7 +117,7 @@ def train_and_test():
saver.save_pytorch(model)
print("Model saved!")

del model, trainer
del model

change_field_is_target(dev_set, "truth", True)



Loading…
Cancel
Save