- add Loss and Optimizer wrappers
- change the Trainer & Tester initialization interface: two styles of definition are provided
- handle Optimizer construction and loss function definition directly in the base Trainer
- add argparse to the task-specific scripts (seq_labeling.py & text_classify.py)
fastNLP/core/loss.py (new file)
@@ -0,0 +1,27 @@
+import torch
+
+
+class Loss(object):
+    """Loss function of the algorithm:
+    either a wrapper around a loss function from the framework, or a user-defined loss
+    (which needs PyTorch autograd support).
+    """
+
+    def __init__(self, args):
+        if args is None:
+            # this is useful when the model defines its own loss
+            self._loss = None
+        elif isinstance(args, str):
+            self._loss = self._borrow_from_pytorch(args)
+        else:
+            raise NotImplementedError
+
+    def get(self):
+        return self._loss
+
+    @staticmethod
+    def _borrow_from_pytorch(loss_name):
+        if loss_name == "cross_entropy":
+            return torch.nn.CrossEntropyLoss()
+        else:
+            raise NotImplementedError
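A minimal usage sketch (the tensors below are illustrative data, not part of the change): the wrapper simply hands back the underlying PyTorch loss callable, or None when no loss name is given.

    import torch
    from fastNLP.core.loss import Loss

    loss_func = Loss("cross_entropy").get()      # torch.nn.CrossEntropyLoss()
    logits = torch.randn(4, 3)                   # 4 samples, 3 classes
    labels = torch.tensor([0, 2, 1, 0])
    print(loss_func(logits, labels))             # scalar loss tensor

    assert Loss(None).get() is None              # the Trainer then falls back to the model's own loss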
fastNLP/core/optimizer.py
@@ -1,3 +1,54 @@
 """
 use optimizer from PyTorch
 """
+import torch
+
+
+class Optimizer(object):
+    """Wrapper of an optimizer from the framework.
+
+    names: arguments (type)
+    1. Adam: lr (float), weight_decay (float)
+    2. AdaGrad
+    3. RMSProp
+    4. SGD: lr (float), momentum (float)
+    """
+
+    def __init__(self, optimizer_name, **kwargs):
+        """
+        :param optimizer_name: str, the name of the optimizer
+        :param kwargs: the arguments
+        """
+        self.optim_name = optimizer_name
+        self.kwargs = kwargs
+
+    @property
+    def name(self):
+        return self.optim_name
+
+    @property
+    def params(self):
+        return self.kwargs
+
+    def construct_from_pytorch(self, model_params):
+        """Construct an optimizer from the framework over the given model parameters."""
+        if self.optim_name in ["SGD", "sgd"]:
+            if "lr" in self.kwargs:
+                if "momentum" not in self.kwargs:
+                    self.kwargs["momentum"] = 0
+                optimizer = torch.optim.SGD(model_params, lr=self.kwargs["lr"], momentum=self.kwargs["momentum"])
+            else:
+                raise ValueError("requires a learning rate for the SGD optimizer")
+        elif self.optim_name in ["adam", "Adam"]:
+            if "lr" in self.kwargs:
+                if "weight_decay" not in self.kwargs:
+                    self.kwargs["weight_decay"] = 0
+                optimizer = torch.optim.Adam(model_params, lr=self.kwargs["lr"],
+                                             weight_decay=self.kwargs["weight_decay"])
+            else:
+                raise ValueError("requires a learning rate for the Adam optimizer")
+        else:
+            raise NotImplementedError
+        return optimizer
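A sketch of the intended flow (the Linear model is a stand-in for any torch.nn.Module): the Optimizer object only records the name and hyper-parameters; the real torch.optim instance is built later, once the model parameters are available.

    import torch
    from fastNLP.core.optimizer import Optimizer

    model = torch.nn.Linear(10, 2)                         # placeholder model
    optim_proto = Optimizer("SGD", lr=0.01, momentum=0.9)
    optimizer = optim_proto.construct_from_pytorch(model.parameters())
    optimizer.step()                                       # behaves like any torch.optim optimizer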
fastNLP/core/tester.py
@@ -1,5 +1,3 @@
-import _pickle
-
 import numpy as np
 import torch

@@ -14,43 +12,78 @@ logger = create_logger(__name__, "./train_test.log")
 class BaseTester(object):
     """A collection of model inference and evaluation of performance, used over validation/dev set and test set."""

-    def __init__(self, test_args):
+    def __init__(self, **kwargs):
         """
-        :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
+        :param kwargs: a dict-like object that has __getitem__ method, can be accessed by "kwargs["key_str"]"
         """
         super(BaseTester, self).__init__()
-        self.validate_in_training = test_args["validate_in_training"]
-        self.save_dev_data = None
-        self.save_output = test_args["save_output"]
-        self.output = None
-        self.save_loss = test_args["save_loss"]
-        self.mean_loss = None
-        self.batch_size = test_args["batch_size"]
-        self.pickle_path = test_args["pickle_path"]
-        self.iterator = None
-        self.use_cuda = test_args["use_cuda"]
-        self.model = None
+        """
+        "default_args" provides default values for important settings.
+        The initialization arguments in "kwargs" with the same key (name) will override the default value.
+        "kwargs" must have the same type as "default_args" on corresponding keys.
+        Otherwise, an error will be raised.
+        """
+        default_args = {"save_output": False,  # collect outputs of the validation set
+                        "save_loss": False,  # collect losses in validation
+                        "save_best_dev": False,  # save the best model during validation
+                        "batch_size": 8,
+                        "use_cuda": True,
+                        "pickle_path": "./save/",
+                        "model_name": "dev_best_model.pkl",
+                        "print_every_step": 1,
+                        }
+        """
+        "required_args" is the collection of arguments that users must pass to the Tester explicitly.
+        This is used to warn users of essential settings in testing.
+        Obviously, "required_args" is a subset of "default_args".
+        The values in "default_args" for the keys in "required_args" are simply for type checking.
+        """
+        # TODO: required arguments
+        required_args = {}
+
+        for req_key in required_args:
+            if req_key not in kwargs:
+                logger.error("Tester lacks argument {}".format(req_key))
+                raise ValueError("Tester lacks argument {}".format(req_key))
+
+        for key in default_args:
+            if key in kwargs:
+                if isinstance(kwargs[key], type(default_args[key])):
+                    default_args[key] = kwargs[key]
+                else:
+                    msg = "Argument %s type mismatch: expected %s while got %s" % (
+                        key, type(default_args[key]), type(kwargs[key]))
+                    logger.error(msg)
+                    raise ValueError(msg)
+            else:
+                # BaseTester doesn't care about extra arguments
+                pass
+        print(default_args)
+
+        self.save_output = default_args["save_output"]
+        self.save_best_dev = default_args["save_best_dev"]
+        self.save_loss = default_args["save_loss"]
+        self.batch_size = default_args["batch_size"]
+        self.pickle_path = default_args["pickle_path"]
+        self.use_cuda = default_args["use_cuda"]
+        self.print_every_step = default_args["print_every_step"]
+
+        self._model = None
         self.eval_history = []
         self.batch_output = []
     def test(self, network, dev_data):
         if torch.cuda.is_available() and self.use_cuda:
-            self.model = network.cuda()
+            self._model = network.cuda()
         else:
-            self.model = network
+            self._model = network

         # turn on the testing mode; clean up the history
         self.mode(network, test=True)
         self.eval_history.clear()
         self.batch_output.clear()
-        # dev_data = self.prepare_input(self.pickle_path)
-        # logger.info("validation data loaded")
         iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))
-        n_batches = len(dev_data) // self.batch_size
-        print_every_step = 1
         step = 0

         for batch_x, batch_y in self.make_batch(iterator, dev_data):

@@ -65,21 +98,10 @@ class BaseTester(object):
             print_output = "[test step {}] {}".format(step, eval_results)
             logger.info(print_output)
-            if step % print_every_step == 0:
+            if step % self.print_every_step == 0:
                 print(print_output)
             step += 1

-    def prepare_input(self, data_path):
-        """Save the dev data once it is loaded. Can return directly next time.
-
-        :param data_path: str, the path to the pickle data for dev
-        :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s).
-        """
-        if self.save_dev_data is None:
-            data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb"))
-            self.save_dev_data = data_dev
-        return self.save_dev_data
-
     def mode(self, model, test):
         """Train mode or Test mode. This is for PyTorch currently.
@@ -117,15 +139,14 @@ class SeqLabelTester(BaseTester):
     Tester for sequence labeling.
     """

-    def __init__(self, test_args):
+    def __init__(self, **test_args):
         """
         :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
         """
-        super(SeqLabelTester, self).__init__(test_args)
+        super(SeqLabelTester, self).__init__(**test_args)
         self.max_len = None
         self.mask = None
         self.seq_len = None
-        self.batch_result = None

     def data_forward(self, network, inputs):
         """This is only for sequence labeling with CRF decoder.

@@ -159,10 +180,10 @@ class SeqLabelTester(BaseTester):
         :return:
         """
         batch_size, max_len = predict.size(0), predict.size(1)
-        loss = self.model.loss(predict, truth, self.mask) / batch_size
+        loss = self._model.loss(predict, truth, self.mask) / batch_size

-        prediction = self.model.prediction(predict, self.mask)
-        results = torch.Tensor(prediction).view(-1,)
+        prediction = self._model.prediction(predict, self.mask)
+        results = torch.Tensor(prediction).view(-1, )
         # make sure "results" is in the same device as "truth"
         results = results.to(truth)
         accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]

@@ -184,21 +205,16 @@ class SeqLabelTester(BaseTester):
     def make_batch(self, iterator, data):
         return Action.make_batch(iterator, use_cuda=self.use_cuda, output_length=True)


 class ClassificationTester(BaseTester):
     """Tester for classification."""

-    def __init__(self, test_args):
+    def __init__(self, **test_args):
         """
         :param test_args: a dict-like object that has __getitem__ method, \
            can be accessed by "test_args["key_str"]"
         """
-        super(ClassificationTester, self).__init__(test_args)
-        self.pickle_path = test_args["pickle_path"]
-        self.save_dev_data = None
-        self.output = None
-        self.mean_loss = None
-        self.iterator = None
+        super(ClassificationTester, self).__init__(**test_args)

     def make_batch(self, iterator, data, max_len=None):
         return Action.make_batch(iterator, use_cuda=self.use_cuda, max_len=max_len)

@@ -221,4 +237,3 @@ class ClassificationTester(BaseTester):
         y_true = torch.cat(y_true, dim=0)
         acc = float(torch.sum(y_pred == y_true)) / len(y_true)
         return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc
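A quick sketch of how the new keyword-style construction behaves (the argument values here are arbitrary, not part of the change): keys matching "default_args" override the defaults, unknown keys are silently ignored, and a type mismatch raises ValueError.

    from fastNLP.core.tester import SeqLabelTester

    tester = SeqLabelTester(batch_size=16, use_cuda=False, pickle_path="./save/")
    # default_args now carries batch_size=16, use_cuda=False; the rest keep their defaults

    SeqLabelTester(batch_size="16")
    # raises ValueError: Argument batch_size type mismatch: expected <class 'int'> while got <class 'str'>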
fastNLP/core/trainer.py
@@ -6,10 +6,11 @@ from datetime import timedelta
 import numpy as np
 import torch
-import torch.nn as nn

 from fastNLP.core.action import Action
 from fastNLP.core.action import RandomSampler, Batchifier
+from fastNLP.core.loss import Loss
+from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.tester import SeqLabelTester, ClassificationTester
 from fastNLP.modules import utils
 from fastNLP.saver.logger import create_logger

@@ -23,14 +24,13 @@ class BaseTrainer(object):
     """Operations to train a model, including data loading, SGD, and validation.

     Subclasses must implement the following abstract methods:
-        - define_optimizer
         - grad_backward
         - get_loss
     """

-    def __init__(self, train_args):
+    def __init__(self, **kwargs):
         """
-        :param train_args: dict of (key, value), or dict-like object. key is str.
+        :param kwargs: dict of (key, value), or dict-like object. key is str.

         The base trainer requires the following keys:
             - epochs: int, the number of epochs in training
@@ -39,19 +39,58 @@ class BaseTrainer(object):
             - pickle_path: str, the path to pickle files for pre-processing
         """
         super(BaseTrainer, self).__init__()
-        self.n_epochs = train_args["epochs"]
-        self.batch_size = train_args["batch_size"]
-        self.pickle_path = train_args["pickle_path"]
-        self.validate = train_args["validate"]
-        self.save_best_dev = train_args["save_best_dev"]
-        self.model_saved_path = train_args["model_saved_path"]
-        self.use_cuda = train_args["use_cuda"]
-        self.model = None
-        self.iterator = None
-        self.loss_func = None
-        self.optimizer = None
+        """
+        "default_args" provides default values for important settings.
+        The initialization arguments in "kwargs" with the same key (name) will override the default value.
+        "kwargs" must have the same type as "default_args" on corresponding keys.
+        Otherwise, an error will be raised.
+        """
+        default_args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/",
+                        "save_best_dev": True, "model_name": "default_model_name.pkl",
+                        "loss": Loss(None),
+                        "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0)
+                        }
+        """
+        "required_args" is the collection of arguments that users must pass to the Trainer explicitly.
+        This is used to warn users of essential settings in training.
+        Obviously, "required_args" is a subset of "default_args".
+        The values in "default_args" for the keys in "required_args" are simply for type checking.
+        """
+        # TODO: required arguments
+        required_args = {}
+
+        for req_key in required_args:
+            if req_key not in kwargs:
+                logger.error("Trainer lacks argument {}".format(req_key))
+                raise ValueError("Trainer lacks argument {}".format(req_key))
+
+        for key in default_args:
+            if key in kwargs:
+                if isinstance(kwargs[key], type(default_args[key])):
+                    default_args[key] = kwargs[key]
+                else:
+                    msg = "Argument %s type mismatch: expected %s while got %s" % (
+                        key, type(default_args[key]), type(kwargs[key]))
+                    logger.error(msg)
+                    raise ValueError(msg)
+            else:
+                # BaseTrainer doesn't care about extra arguments
+                pass
+        print(default_args)
+
+        self.n_epochs = default_args["epochs"]
+        self.batch_size = default_args["batch_size"]
+        self.pickle_path = default_args["pickle_path"]
+        self.validate = default_args["validate"]
+        self.save_best_dev = default_args["save_best_dev"]
+        self.use_cuda = default_args["use_cuda"]
+        self.model_name = default_args["model_name"]
+
+        self._model = None
+        self._loss_func = default_args["loss"].get()  # returns a PyTorch loss function or None
+        self._optimizer = None
+        self._optimizer_proto = default_args["optimizer"]
     def train(self, network, train_data, dev_data=None):
         """General Training Steps

@@ -72,12 +111,9 @@ class BaseTrainer(object):
         """
         # prepare model and data, transfer model to gpu if available
         if torch.cuda.is_available() and self.use_cuda:
-            self.model = network.cuda()
+            self._model = network.cuda()
         else:
-            self.model = network
-
-        # train_data = self.load_train_data(self.pickle_path)
-        # logger.info("training data loaded")
+            self._model = network

         # define tester over dev data
         if self.validate:

@@ -88,7 +124,9 @@ class BaseTrainer(object):
             logger.info("validator defined as {}".format(str(validator)))

         self.define_optimizer()
-        logger.info("optimizer defined as {}".format(str(self.optimizer)))
+        logger.info("optimizer defined as {}".format(str(self._optimizer)))
+        self.define_loss()
+        logger.info("loss function defined as {}".format(str(self._loss_func)))

         # main training epochs
         n_samples = len(train_data)

@@ -113,7 +151,7 @@ class BaseTrainer(object):
                 validator.test(network, dev_data)

                 if self.save_best_dev and self.best_eval_result(validator):
-                    self.save_model(network)
+                    self.save_model(network, self.model_name)
                     print("saved better model selected by dev")
                     logger.info("saved better model selected by dev")

@@ -153,6 +191,11 @@ class BaseTrainer(object):
            logger.error("the number of folds in train and dev data unequals {}!={}".format(len(train_data_cv),
                                                                                             len(dev_data_cv)))
            raise RuntimeError("the number of folds in train and dev data unequals")
+        if self.validate is False:
+            logger.warn("Cross validation requires self.validate to be True. Please turn it on.")
+            print("[warning] Cross validation requires self.validate to be True. Please turn it on.")
+            self.validate = True
+
         n_fold = len(train_data_cv)
         logger.info("perform {} folds cross validation.".format(n_fold))
         for i in range(n_fold):
@@ -186,7 +229,7 @@ class BaseTrainer(object):
         """
         Define framework-specific optimizer specified by the models.
         """
-        raise NotImplementedError
+        self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters())

     def update(self):
         """

@@ -194,7 +237,7 @@ class BaseTrainer(object):
         For PyTorch, just call optimizer to update.
         """
-        raise NotImplementedError
+        self._optimizer.step()

     def data_forward(self, network, x):
         raise NotImplementedError

@@ -206,7 +249,8 @@ class BaseTrainer(object):
         For PyTorch, just do "loss.backward()"
         """
-        raise NotImplementedError
+        self._model.zero_grad()
+        loss.backward()

     def get_loss(self, predict, truth):
         """

@@ -215,21 +259,25 @@ class BaseTrainer(object):
         :param truth: ground truth label vector
         :return: a scalar
         """
-        if self.loss_func is None:
-            if hasattr(self.model, "loss"):
-                self.loss_func = self.model.loss
-                logger.info("The model has a loss function, use it.")
-            else:
-                logger.info("The model didn't define loss, use Trainer's loss.")
-                self.define_loss()
-        return self.loss_func(predict, truth)
+        return self._loss_func(predict, truth)

     def define_loss(self):
         """
-        Assign an instance of loss function to self.loss_func
-        E.g. self.loss_func = nn.CrossEntropyLoss()
+        If the model defines a loss, use the model's loss.
+        Otherwise, the Trainer must have received a loss argument; use that as the loss.
+        These two losses cannot be defined at the same time.
+        The Trainer does not handle loss definition or choose default losses.
         """
-        raise NotImplementedError
+        if hasattr(self._model, "loss") and self._loss_func is not None:
+            raise ValueError("Both the model and Trainer define loss. Please take out your loss.")
+
+        if hasattr(self._model, "loss"):
+            self._loss_func = self._model.loss
+            logger.info("The model has a loss function, use it.")
+        else:
+            if self._loss_func is None:
+                raise ValueError("Please specify a loss function.")
+            logger.info("The model didn't define loss, use Trainer's loss.")

     def best_eval_result(self, validator):
         """

@@ -238,12 +286,15 @@ class BaseTrainer(object):
         """
         raise NotImplementedError

-    def save_model(self, network):
+    def save_model(self, network, model_name):
         """
         :param network: the PyTorch model
+        :param model_name: str
         model_best_dev.pkl may be overwritten by a better model in future epochs.
         """
-        ModelSaver(self.model_saved_path + "model_best_dev.pkl").save_pytorch(network)
+        if model_name[-4:] != ".pkl":
+            model_name += ".pkl"
+        ModelSaver(self.pickle_path + model_name).save_pytorch(network)

     def _create_validator(self, valid_args):
         raise NotImplementedError
@@ -266,18 +317,12 @@ class ToyTrainer(BaseTrainer):
         return network(x)

     def grad_backward(self, loss):
-        self.model.zero_grad()
+        self._model.zero_grad()
         loss.backward()

     def get_loss(self, pred, truth):
         return np.mean(np.square(pred - truth))

-    def define_optimizer(self):
-        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
-
-    def update(self):
-        self.optimizer.step()
-

 class SeqLabelTrainer(BaseTrainer):
     """

@@ -285,24 +330,14 @@ class SeqLabelTrainer(BaseTrainer):
     """

-    def __init__(self, train_args):
-        super(SeqLabelTrainer, self).__init__(train_args)
-        self.vocab_size = train_args["vocab_size"]
-        self.num_classes = train_args["num_classes"]
+    def __init__(self, **kwargs):
+        super(SeqLabelTrainer, self).__init__(**kwargs)
+        # self.vocab_size = kwargs["vocab_size"]
+        # self.num_classes = kwargs["num_classes"]
         self.max_len = None
         self.mask = None
         self.best_accuracy = 0.0

-    def define_optimizer(self):
-        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
-
-    def grad_backward(self, loss):
-        self.model.zero_grad()
-        loss.backward()
-
-    def update(self):
-        self.optimizer.step()
-
     def data_forward(self, network, inputs):
         if not isinstance(inputs, tuple):
             raise RuntimeError("output_length must be true for sequence modeling. Receive {}".format(type(inputs[0])))

@@ -330,7 +365,7 @@ class SeqLabelTrainer(BaseTrainer):
         batch_size, max_len = predict.size(0), predict.size(1)
         assert truth.shape == (batch_size, max_len)

-        loss = self.model.loss(predict, truth, self.mask)
+        loss = self._model.loss(predict, truth, self.mask)
         return loss

     def best_eval_result(self, validator):

@@ -345,48 +380,25 @@ class SeqLabelTrainer(BaseTrainer):
         return Action.make_batch(iterator, output_length=True, use_cuda=self.use_cuda)

     def _create_validator(self, valid_args):
-        return SeqLabelTester(valid_args)
+        return SeqLabelTester(**valid_args)


 class ClassificationTrainer(BaseTrainer):
     """Trainer for classification."""

-    def __init__(self, train_args):
-        super(ClassificationTrainer, self).__init__(train_args)
-        self.learn_rate = train_args["learn_rate"]
-        self.momentum = train_args["momentum"]
+    def __init__(self, **train_args):
+        super(ClassificationTrainer, self).__init__(**train_args)

         self.iterator = None
         self.loss_func = None
         self.optimizer = None
         self.best_accuracy = 0

-    def define_loss(self):
-        self.loss_func = nn.CrossEntropyLoss()
-
-    def define_optimizer(self):
-        """
-        Define framework-specific optimizer specified by the models.
-        """
-        self.optimizer = torch.optim.SGD(
-            self.model.parameters(),
-            lr=self.learn_rate,
-            momentum=self.momentum)
-
     def data_forward(self, network, x):
         """Forward through network."""
         logits = network(x)
         return logits

-    def grad_backward(self, loss):
-        """Compute gradient backward."""
-        self.model.zero_grad()
-        loss.backward()
-
-    def update(self):
-        """Apply gradient."""
-        self.optimizer.step()
-
     def make_batch(self, iterator):
         return Action.make_batch(iterator, output_length=False, use_cuda=self.use_cuda)

@@ -404,4 +416,4 @@ class ClassificationTrainer(BaseTrainer):
         return False

     def _create_validator(self, valid_args):
-        return ClassificationTester(valid_args)
+        return ClassificationTester(**valid_args)
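The loss-resolution rule above can be summarised in a small sketch (the trainer calls are illustrative, with all other settings left at their defaults):

    from fastNLP.core.loss import Loss
    from fastNLP.core.trainer import SeqLabelTrainer, ClassificationTrainer

    # Model defines its own loss (e.g. a CRF-based SeqLabeling model): pass no loss;
    # the default Loss(None) yields None and define_loss() picks up model.loss.
    trainer = SeqLabelTrainer(epochs=1)

    # Model has no "loss" attribute: the Trainer must receive one explicitly.
    trainer = ClassificationTrainer(epochs=1, loss=Loss("cross_entropy"))

    # If the model defines a loss AND a Loss was passed to the Trainer, define_loss() raises
    # ValueError("Both the model and Trainer define loss. Please take out your loss.")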
fastNLP/loader/config_loader.py
@@ -94,6 +94,10 @@ class ConfigSection(object):
     def __contains__(self, item):
         return item in self.__dict__.keys()

+    @property
+    def data(self):
+        return self.__dict__
+

 if __name__ == "__main__":
     config = ConfigLoader('configLoader', 'there is no data')
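The new "data" property is what enables the first Trainer/Tester definition style: a ConfigSection read from a config file can be expanded directly into keyword arguments. A sketch, assuming the [test_seq_label_trainer] section introduced later in this change:

    from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
    from fastNLP.core.trainer import SeqLabelTrainer

    trainer_args = ConfigSection()
    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config",
                                               {"test_seq_label_trainer": trainer_args})
    trainer = SeqLabelTrainer(**trainer_args.data)   # style 1: config keys become keyword arguments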
@@ -18,7 +18,6 @@ MLP_HIDDEN = 2000
 CLASSES_NUM = 5

 from fastNLP.models.base_model import BaseModel
-from fastNLP.core.trainer import BaseTrainer


 class MyNet(BaseModel):

@@ -60,18 +59,6 @@ class Net(nn.Module):
         return x, penalty


-class MyTrainer(BaseTrainer):
-    def __init__(self, args):
-        super(MyTrainer, self).__init__(args)
-        self.optimizer = None
-
-    def define_optimizer(self):
-        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
-
-    def define_loss(self):
-        self.loss_func = nn.CrossEntropyLoss()
-
-
 def train(model_dict=None, using_cuda=True, learning_rate=0.06,\
           momentum=0.3, batch_size=32, epochs=5, coef=1.0, interval=10):
     """
data_for_tests/config
@@ -1,65 +1,11 @@
-[General]
-revision = "first"
-datapath = "./data/smallset/imdb/"
-embed_path = "./data/smallset/imdb/embedding.txt"
-optimizer = "adam"
-attn_mode = "rout"
-seq_encoder = "bilstm"
-out_caps_num = 5
-rout_iter = 3
-max_snt_num = 40
-max_wd_num = 40
-max_epochs = 50
-pre_trained = true
-batch_sz = 32
-batch_sz_min = 32
-bucket_sz = 5000
-partial_update_until_epoch = 2
-embed_size = 300
-hidden_size = 200
-dense_hidden = [300, 10]
-lr = 0.0002
-decay_steps = 1000
-decay_rate = 0.9
-dropout = 0.2
-early_stopping = 7
-reg = 1e-06
-
-[My]
-datapath = "./data/smallset/imdb/"
-embed_path = "./data/smallset/imdb/embedding.txt"
-optimizer = "adam"
-attn_mode = "rout"
-seq_encoder = "bilstm"
-out_caps_num = 5
-rout_iter = 3
-max_snt_num = 40
-max_wd_num = 40
-max_epochs = 50
-pre_trained = true
-batch_sz = 32
-batch_sz_min = 32
-bucket_sz = 5000
-partial_update_until_epoch = 2
-embed_size = 300
-hidden_size = 200
-dense_hidden = [300, 10]
-lr = 0.0002
-decay_steps = 1000
-decay_rate = 0.9
-dropout = 0.2
-early_stopping = 70
-reg = 1e-05
-test = 5
-new_attr = 40
-
-[POS]
+[test_seq_label_trainer]
 epochs = 1
 batch_size = 32
-pickle_path = "./data_for_tests/"
 validate = true
 save_best_dev = true
-model_saved_path = "./"
 use_cuda = true
+
+[test_seq_label_model]
 rnn_hidden_units = 100
 rnn_layers = 1
 rnn_bi_direction = true

@@ -68,13 +14,12 @@ dropout = 0.5
 use_crf = true
 use_cuda = true

-[POS_test]
+[test_seq_label_tester]
 save_output = true
 validate_in_training = true
 save_dev_input = false
 save_loss = true
 batch_size = 1
-pickle_path = "./data_for_tests/"
 rnn_hidden_units = 100
 rnn_layers = 1
 rnn_bi_direction = true

@@ -84,7 +29,6 @@ use_crf = true
 use_cuda = true

 [POS_infer]
-pickle_path = "./data_for_tests/"
 rnn_hidden_units = 100
 rnn_layers = 1
 rnn_bi_direction = true

@@ -95,14 +39,9 @@ num_classes = 27
 [text_class]
 epochs = 1
 batch_size = 10
-pickle_path = "./save_path/"
 validate = false
 save_best_dev = false
-model_saved_path = "./save_path/"
 use_cuda = true
 learn_rate = 1e-3
 momentum = 0.9
-
-[text_class_model]
-vocab_size = 867
-num_classes = 18
+model_name = "class_model.pkl"
@@ -20,7 +20,7 @@ class MyNERTrainer(SeqLabelTrainer):
         override
         :return:
         """
-        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
+        self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.001)
         self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=3000, gamma=0.5)

     def update(self):
seq_labeling.py
@@ -1,7 +1,7 @@
+import os
 import sys

 sys.path.append("..")

+import argparse
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.core.trainer import SeqLabelTrainer
 from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader

@@ -11,17 +11,29 @@ from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.core.tester import SeqLabelTester
 from fastNLP.models.sequence_modeling import SeqLabeling
 from fastNLP.core.predictor import SeqLabelInfer
+from fastNLP.core.optimizer import Optimizer
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files")
+parser.add_argument("-t", "--train", type=str, default="./data_for_tests/people.txt",
+                    help="path to the training data")
+parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file")
+parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model")
+parser.add_argument("-i", "--infer", type=str, default="data_for_tests/people_infer.txt",
+                    help="data used for inference")

-data_name = "people.txt"
-data_path = "data_for_tests/people.txt"
-pickle_path = "seq_label/"
-data_infer_path = "data_for_tests/people_infer.txt"
+args = parser.parse_args()
+pickle_path = args.save
+model_name = args.model_name
+config_dir = args.config
+data_path = args.train
+data_infer_path = args.infer


 def infer():
     # Load infer configuration, the same as test
     test_args = ConfigSection()
-    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
+    ConfigLoader("config.cfg", "").load_config(config_dir, {"POS_infer": test_args})

     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")

@@ -33,11 +45,11 @@ def infer():
     model = SeqLabeling(test_args)

     # Dump trained parameters into the model
-    ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl")
+    ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
     print("model loaded!")

     # Data Loader
-    raw_data_loader = BaseLoader(data_name, data_infer_path)
+    raw_data_loader = BaseLoader("xxx", data_infer_path)
     infer_data = raw_data_loader.load_lines()

     # Inference interface
@@ -51,49 +63,72 @@ def infer():

 def train_and_test():
     # Config Loader
-    train_args = ConfigSection()
-    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})
+    trainer_args = ConfigSection()
+    model_args = ConfigSection()
+    ConfigLoader("config.cfg", "").load_config(config_dir, {
+        "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

     # Data Loader
-    pos_loader = POSDatasetLoader(data_name, data_path)
+    pos_loader = POSDatasetLoader("xxx", data_path)
     train_data = pos_loader.load_lines()

     # Preprocessor
     p = SeqLabelPreprocess()
     data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
-    train_args["vocab_size"] = p.vocab_size
-    train_args["num_classes"] = p.num_classes
-
-    # Trainer
-    trainer = SeqLabelTrainer(train_args)
+    model_args["vocab_size"] = p.vocab_size
+    model_args["num_classes"] = p.num_classes
+
+    # Trainer: two definition styles
+    # 1
+    # trainer = SeqLabelTrainer(**trainer_args.data)
+    # 2
+    trainer = SeqLabelTrainer(
+        epochs=trainer_args["epochs"],
+        batch_size=trainer_args["batch_size"],
+        validate=trainer_args["validate"],
+        use_cuda=trainer_args["use_cuda"],
+        pickle_path=pickle_path,
+        save_best_dev=trainer_args["save_best_dev"],
+        model_name=model_name,
+        optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
+    )

     # Model
-    model = SeqLabeling(train_args)
+    model = SeqLabeling(model_args)

     # Start training
     trainer.train(model, data_train, data_dev)
     print("Training finished!")

     # Saver
-    saver = ModelSaver(pickle_path + "saved_model.pkl")
+    saver = ModelSaver(os.path.join(pickle_path, model_name))
     saver.save_pytorch(model)
     print("Model saved!")

     del model, trainer, pos_loader

     # Define the same model
-    model = SeqLabeling(train_args)
+    model = SeqLabeling(model_args)

     # Dump trained parameters into the model
-    ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl")
+    ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
     print("model loaded!")

     # Load test configuration
-    test_args = ConfigSection()
-    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
+    tester_args = ConfigSection()
+    ConfigLoader("config.cfg", "").load_config(config_dir, {"test_seq_label_tester": tester_args})

     # Tester
-    tester = SeqLabelTester(test_args)
+    tester = SeqLabelTester(save_output=False,
+                            save_loss=False,
+                            save_best_dev=False,
+                            batch_size=8,
+                            use_cuda=False,
+                            pickle_path=pickle_path,
+                            model_name="seq_label_in_test.pkl",
+                            print_every_step=1
+                            )

     # Start testing with validation data
     tester.test(model, data_dev)

@@ -105,4 +140,4 @@ def train_and_test():

 if __name__ == "__main__":
     train_and_test()
-    # infer()
+    infer()
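With argparse in place, the script can be pointed at arbitrary paths from the command line; a typical invocation (using the defaults shown above, assuming it is run from the directory containing data_for_tests/) would be:

    python seq_labeling.py --save ./seq_label/ --train ./data_for_tests/people.txt \
        --config ./data_for_tests/config --model_name seq_label_model.pkl \
        --infer data_for_tests/people_infer.txt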
text_classify.py
@@ -1,6 +1,7 @@
 # Python: 3.5
 # encoding: utf-8

+import argparse
 import os
 import sys

@@ -13,75 +14,105 @@ from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.core.preprocess import ClassPreprocess
 from fastNLP.models.cnn_text_classification import CNNText
 from fastNLP.saver.model_saver import ModelSaver
+from fastNLP.core.optimizer import Optimizer
+from fastNLP.core.loss import Loss

-save_path = "./test_classification/"
-data_dir = "./data_for_tests/"
-train_file = 'text_classify.txt'
-model_name = "model_class.pkl"
+parser = argparse.ArgumentParser()
+parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files")
+parser.add_argument("-t", "--train", type=str, default="./data_for_tests/text_classify.txt",
+                    help="path to the training data")
+parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file")
+parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model")
+
+args = parser.parse_args()
+save_dir = args.save
+train_data_dir = args.train
+model_name = args.model_name
+config_dir = args.config


 def infer():
     # load dataset
     print("Loading data...")
-    ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
+    ds_loader = ClassDatasetLoader("train", train_data_dir)
     data = ds_loader.load()
     unlabeled_data = [x[0] for x in data]

     # pre-process data
     pre = ClassPreprocess()
-    vocab_size, n_classes = pre.run(data, pickle_path=save_path)
-    print("vocabulary size:", vocab_size)
-    print("number of classes:", n_classes)
+    data = pre.run(data, pickle_path=save_dir)
+    print("vocabulary size:", pre.vocab_size)
+    print("number of classes:", pre.num_classes)

     model_args = ConfigSection()
-    ConfigLoader.load_config("data_for_tests/config", {"text_class_model": model_args})
+    # TODO: load from config file
+    model_args["vocab_size"] = pre.vocab_size
+    model_args["num_classes"] = pre.num_classes
+    # ConfigLoader.load_config(config_dir, {"text_class_model": model_args})

     # construct model
     print("Building model...")
     cnn = CNNText(model_args)

     # Dump trained parameters into the model
-    ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl")
+    ModelLoader.load_pytorch(cnn, os.path.join(save_dir, model_name))
     print("model loaded!")

-    infer = ClassificationInfer(data_dir)
+    infer = ClassificationInfer(pickle_path=save_dir)
     results = infer.predict(cnn, unlabeled_data)
     print(results)


 def train():
     train_args, model_args = ConfigSection(), ConfigSection()
-    ConfigLoader.load_config("data_for_tests/config", {"text_class": train_args, "text_class_model": model_args})
+    ConfigLoader.load_config(config_dir, {"text_class": train_args})

     # load dataset
     print("Loading data...")
-    ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
+    ds_loader = ClassDatasetLoader("train", train_data_dir)
     data = ds_loader.load()
     print(data[0])

     # pre-process data
     pre = ClassPreprocess()
-    data_train = pre.run(data, pickle_path=save_path)
+    data_train = pre.run(data, pickle_path=save_dir)
     print("vocabulary size:", pre.vocab_size)
     print("number of classes:", pre.num_classes)

+    model_args["num_classes"] = pre.num_classes
+    model_args["vocab_size"] = pre.vocab_size
+
     # construct model
     print("Building model...")
     model = CNNText(model_args)

+    # ConfigSaver().save_config(config_dir, {"text_class_model": model_args})
+
     # train
     print("Training...")
-    trainer = ClassificationTrainer(train_args)
+    # 1
+    # trainer = ClassificationTrainer(**train_args.data)
+    # 2
+    trainer = ClassificationTrainer(epochs=train_args["epochs"],
+                                    batch_size=train_args["batch_size"],
+                                    validate=train_args["validate"],
+                                    use_cuda=train_args["use_cuda"],
+                                    pickle_path=save_dir,
+                                    save_best_dev=train_args["save_best_dev"],
+                                    model_name=model_name,
+                                    loss=Loss("cross_entropy"),
+                                    optimizer=Optimizer("SGD", lr=0.001, momentum=0.9))
     trainer.train(model, data_train)

     print("Training finished!")

-    saver = ModelSaver("./data_for_tests/saved_model.pkl")
+    saver = ModelSaver(os.path.join(save_dir, model_name))
     saver.save_pytorch(model)
     print("Model saved!")


 if __name__ == "__main__":
     train()
-    # infer()
+    infer()
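As with seq_labeling.py, the classification script now takes its paths from the command line; an example run (using the short option forms and the defaults listed above) might look like:

    python text_classify.py -s ./test_classification/ -t ./data_for_tests/text_classify.txt \
        -c ./data_for_tests/config -m classify_model.pkl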