diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 684bd18d..db0ebc53 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -1,3 +1,5 @@
+import numpy as np
+
 from fastNLP.core.fieldarray import FieldArray
 
 _READERS = {}
@@ -6,7 +8,7 @@ _READERS = {}
 def construct_dataset(sentences):
     """Construct a data set from a list of sentences.
 
-    :param sentences: list of str
+    :param sentences: list of list of str
     :return dataset: a DataSet object
     """
     dataset = DataSet()
@@ -18,7 +20,9 @@ def construct_dataset(sentences):
 
 
 class DataSet(object):
-    """A DataSet object is a list of Instance objects.
+    """DataSet is a collection of examples.
+    DataSet provides an instance-level interface: you can append instances and access them by index.
+    However, it stores the data in a different way: Field-first, Instance-second.
 
     """
 
@@ -47,6 +51,11 @@ class DataSet(object):
                              in self.dataset.get_fields().keys()])
 
     def __init__(self, data=None):
+        """
+
+        :param data: a dict or a list. If it is a dict, the keys are field names and the values are lists of
+            field values. If it is a list, it must be a list of Instance objects.
+        """
        self.field_arrays = {}
         if data is not None:
             if isinstance(data, dict):
@@ -78,8 +87,14 @@ class DataSet(object):
             self.append(ins_list)
 
     def append(self, ins):
-        # no field
+        """Add an instance to the DataSet.
+        If the DataSet is not empty, the instance must have the same field names as the other instances in the DataSet.
+
+        :param ins: an Instance object
+
+        """
         if len(self.field_arrays) == 0:
+            # DataSet has no field yet
             for name, field in ins.fields.items():
                 self.field_arrays[name] = FieldArray(name, [field])
         else:
@@ -89,6 +104,15 @@ class DataSet(object):
                 self.field_arrays[name].append(field)
 
     def add_field(self, name, fields, padding_val=0, need_tensor=False, is_target=False):
+        """Add a field to the DataSet.
+
+        :param name: str, the name of the field
+        :param fields: list, a list of values for this field; must be as long as the DataSet if it is not empty
+        :param padding_val: int, the value used to pad this field when batching
+        :param need_tensor: bool, whether this field is converted to a tensor when batching
+        :param is_target: bool, whether this field is a prediction target
+
+        """
         if len(self.field_arrays) != 0:
             assert len(self) == len(fields)
         self.field_arrays[name] = FieldArray(name, fields,
@@ -210,6 +234,26 @@ class DataSet(object):
         else:
             return results
 
+    def split(self, test_ratio):
+        """Split the dataset into a training set and a test set.
+
+        :param test_ratio: float, the proportion of instances assigned to the test set
+        :return (train_set, test_set): two DataSet objects
+        """
+        assert isinstance(test_ratio, float)
+        all_indices = list(range(len(self)))
+        np.random.shuffle(all_indices)
+        split_point = int(test_ratio * len(self))
+        test_indices = all_indices[:split_point]
+        train_indices = all_indices[split_point:]
+        test_set = DataSet()
+        train_set = DataSet()
+        for idx in test_indices:
+            test_set.append(self[idx])
+        for idx in train_indices:
+            train_set.append(self[idx])
+        return train_set, test_set
+
 
 if __name__ == '__main__':
     from fastNLP.core.instance import Instance
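A minimal usage sketch of the reworked DataSet API (the keyword-argument Instance constructor is assumed from fastNLP.core.instance; field names are illustrative):

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.instance import Instance

    ds = DataSet()
    ds.append(Instance(word_seq=["hello", "world"], label="greeting"))
    ds.append(Instance(word_seq=["good", "bye"], label="farewell"))

    # indices are shuffled before the cut, so both subsets are random samples
    train_set, test_set = ds.split(0.5)
    assert len(train_set) + len(test_set) == len(ds)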
diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py
deleted file mode 100644
index 0df103b2..00000000
--- a/fastNLP/core/field.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import torch
-
-
-class Field(object):
-    """A field defines a data type.
-
-    """
-
-    def __init__(self, content, is_target: bool):
-        self.is_target = is_target
-        self.content = content
-
-    def index(self, vocab):
-        """create index field
-        """
-        raise NotImplementedError
-
-    def __len__(self):
-        """number of samples
-        """
-        assert self.content is not None
-        return len(self.content)
-
-    def to_tensor(self, id_list):
-        """convert batch of index to tensor
-        """
-        raise NotImplementedError
-
-    def __repr__(self):
-        return self.content.__repr__()
-
-
-class TextField(Field):
-    def __init__(self, text, is_target):
-        """
-        :param text: list of strings
-        :param is_target: bool
-        """
-        super(TextField, self).__init__(text, is_target)
-
-
-class LabelField(Field):
-    """The Field representing a single label. Can be a string or integer.
-
-    """
-
-    def __init__(self, label, is_target=True):
-        super(LabelField, self).__init__(label, is_target)
-
-
-class SeqLabelField(Field):
-    def __init__(self, label_seq, is_target=True):
-        super(SeqLabelField, self).__init__(label_seq, is_target)
-
-
-class CharTextField(Field):
-    def __init__(self, text, max_word_len, is_target=False):
-        super(CharTextField, self).__init__(is_target)
-        # TODO
-        raise NotImplementedError
-        self.max_word_len = max_word_len
-        self._index = []
-
-    def get_length(self):
-        return len(self.text)
-
-    def contents(self):
-        return self.text.copy()
-
-    def index(self, char_vocab):
-        if len(self._index) == 0:
-            for word in self.text:
-                char_index = [char_vocab[ch] for ch in word]
-                if self.max_word_len >= len(char_index):
-                    char_index += [0] * (self.max_word_len - len(char_index))
-                else:
-                    self._index.clear()
-                    raise RuntimeError("Word {} has more than {} characters. ".format(word, self.max_word_len))
-                self._index.append(char_index)
-        return self._index
-
-    def to_tensor(self, padding_length):
-        """
-
-        :param padding_length: int, the padding length of the word sequence.
-        :return : tensor of shape (padding_length, max_word_len)
-        """
-        pads = [[0] * self.max_word_len] * (padding_length - self.get_length())
-        return torch.LongTensor(self._index + pads)
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index deba6a07..2a0d33e0 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -5,9 +5,9 @@ import torch
 from fastNLP.core.batch import Batch
 from fastNLP.core.metrics import Evaluator
 from fastNLP.core.sampler import RandomSampler
-from fastNLP.io.logger import create_logger
 
-logger = create_logger(__name__, "./train_test.log")
+
+# logger = create_logger(__name__, "./train_test.log")
 
 
 class Tester(object):
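With field.py gone, every column of a DataSet lives in a FieldArray (see fastNLP/core/fieldarray.py). A rough sketch of the Field-first layout, assuming the dict form of the DataSet constructor maps each field name to its list of values as documented above:

    from fastNLP.core.dataset import DataSet

    # one FieldArray per field; each array holds one entry per instance
    ds = DataSet({"word_seq": [["a", "b"], ["c", "d", "e"]],
                  "label": ["x", "y"]})
    print(ds.field_arrays.keys())  # dict_keys(['word_seq', 'label'])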
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 0fd27f14..b879ad11 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -1,4 +1,3 @@
-import os
 import time
 from datetime import timedelta, datetime
 
@@ -11,157 +10,76 @@ from fastNLP.core.metrics import Evaluator
 from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.tester import Tester
-from fastNLP.io.logger import create_logger
-from fastNLP.io.model_saver import ModelSaver
-
-logger = create_logger(__name__, "./train_test.log")
-logger.disabled = True
 
 
 class Trainer(object):
-    """Operations of training a model, including data loading, gradient descent, and validation.
+    """Main training loop: batching, forward and backward passes, validation, and model saving.
 
     """
 
-    def __init__(self, **kwargs):
-        """
-        :param kwargs: dict of (key, value), or dict-like object. key is str.
-
-                       The base trainer requires the following keys:
-                       - epochs: int, the number of epochs in training
-                       - validate: bool, whether or not to validate on dev set
-                       - batch_size: int
-                       - pickle_path: str, the path to pickle files for pre-processing
-        """
+    def __init__(self, train_data, model, n_epochs, batch_size, n_print,
+                 dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save",
+                 optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
+                 evaluator=Evaluator(),
+                 **kwargs):
         super(Trainer, self).__init__()
-
-        """
-            "default_args" provides default value for important settings.
-            The initialization arguments "kwargs" with the same key (name) will override the default value.
-            "kwargs" must have the same type as "default_args" on corresponding keys.
-            Otherwise, error will raise.
-        """
-        default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
-                        "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1,
-                        "valid_step": 500, "eval_sort_key": 'acc',
-                        "loss": Loss(None),  # used to pass type check
-                        "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
-                        "eval_batch_size": 64,
-                        "evaluator": Evaluator(),
-                        }
-        """
-            "required_args" is the collection of arguments that users must pass to Trainer explicitly.
-            This is used to warn users of essential settings in the training.
-            Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
-        """
-        required_args = {}
-
-        for req_key in required_args:
-            if req_key not in kwargs:
-                logger.error("Trainer lacks argument {}".format(req_key))
-                raise ValueError("Trainer lacks argument {}".format(req_key))
-
-        for key in default_args:
-            if key in kwargs:
-                if isinstance(kwargs[key], type(default_args[key])):
-                    default_args[key] = kwargs[key]
-                else:
-                    msg = "Argument %s type mismatch: expected %s while get %s" % (
-                        key, type(default_args[key]), type(kwargs[key]))
-                    logger.error(msg)
-                    raise ValueError(msg)
-            else:
-                # Trainer doesn't care about extra arguments
-                pass
-        print("Training Args {}".format(default_args))
-        logger.info("Training Args {}".format(default_args))
-
-        self.n_epochs = int(default_args["epochs"])
-        self.batch_size = int(default_args["batch_size"])
-        self.eval_batch_size = int(default_args['eval_batch_size'])
-        self.pickle_path = default_args["pickle_path"]
-        self.validate = default_args["validate"]
-        self.save_best_dev = default_args["save_best_dev"]
-        self.use_cuda = default_args["use_cuda"]
-        self.model_name = default_args["model_name"]
-        self.print_every_step = int(default_args["print_every_step"])
-        self.valid_step = int(default_args["valid_step"])
-        if self.validate is not None:
-            assert self.valid_step > 0
-
-        self._model = None
-        self._loss_func = default_args["loss"].get()  # return a pytorch loss function or None
-        self._optimizer = None
-        self._optimizer_proto = default_args["optimizer"]
-        self._evaluator = default_args["evaluator"]
-        self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
+        self.train_data = train_data
+        self.dev_data = dev_data  # if None, no validation is performed
+        self.model = model
+        self.n_epochs = int(n_epochs)
+        self.batch_size = int(batch_size)
+        self.use_cuda = bool(use_cuda)
+        self.save_path = str(save_path)
+        self.n_print = int(n_print)
+
+        self.loss_func = self.model.loss if hasattr(self.model, "loss") else loss.get()
+        self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
+        self.evaluator = evaluator
+
+        if self.dev_data is not None:
+            valid_args = {"batch_size": self.batch_size, "save_path": self.save_path,
+                          "use_cuda": self.use_cuda, "evaluator": self.evaluator}
+            self.tester = Tester(**valid_args)
+
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+        self._summary_writer = SummaryWriter(self.save_path + '/tensorboard_logs')
         self._graph_summaried = False
-        self._best_accuracy = 0.0
-        self.eval_sort_key = default_args['eval_sort_key']
-        self.validator = None
-        self.epoch = 0
         self.step = 0
+        self.start_time = None  # start timestamp
+
+        print(self.__dict__)
 
-    def train(self, network, train_data, dev_data=None):
-        """General Training Procedure
-
-        :param network: a model
-        :param train_data: a DataSet instance, the training data
-        :param dev_data: a DataSet instance, the validation data (optional)
+    def train(self):
+        """Run the training loop.
+
+        Validation runs on dev_data after each epoch when dev_data is provided.
         """
-        # transfer model to gpu if available
         if torch.cuda.is_available() and self.use_cuda:
-            self._model = network.cuda()
-            # self._model is used to access model-specific loss
-        else:
-            self._model = network
-
-        print(self._model)
-
-        # define Tester over dev data
-        self.dev_data = None
-        if self.validate:
-            default_valid_args = {"batch_size": self.eval_batch_size, "pickle_path": self.pickle_path,
-                                  "use_cuda": self.use_cuda, "evaluator": self._evaluator}
-            if self.validator is None:
-                self.validator = self._create_validator(default_valid_args)
-            logger.info("validator defined as {}".format(str(self.validator)))
-            self.dev_data = dev_data
-
-        # optimizer and loss
-        self.define_optimizer()
-        logger.info("optimizer defined as {}".format(str(self._optimizer)))
-        self.define_loss()
-        logger.info("loss function defined as {}".format(str(self._loss_func)))
-
-        # turn on network training mode
-        self.mode(network, is_test=False)
-
-        # main training procedure
+            self.model = self.model.cuda()
+
+        self.mode(self.model, is_test=False)
+
         start = time.time()
         self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
         print("training epochs started " + self.start_time)
-        logger.info("training epochs started " + self.start_time)
-        self.epoch, self.step = 1, 0
-        while self.epoch <= self.n_epochs:
-            logger.info("training epoch {}".format(self.epoch))
-
-            # prepare mini-batch iterator
-            data_iterator = Batch(train_data, batch_size=self.batch_size,
-                                  sampler=BucketSampler(10, self.batch_size, "word_seq_origin_len"),
+
+        epoch = 1
+        while epoch <= self.n_epochs:
+
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
                                   use_cuda=self.use_cuda)
-            logger.info("prepared data iterator")
 
-            # one forward and backward pass
-            self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, dev_data=dev_data)
+            self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start, self.n_print)
 
-            # validation
-            if self.validate:
-                self.valid_model()
-            self.save_model(self._model, 'training_model_' + self.start_time)
-            self.epoch += 1
+            if self.dev_data:
+                self.do_validation()
+            self.save_model(self.model, 'training_model_' + self.start_time)
+            epoch += 1
 
-    def _train_step(self, data_iterator, network, **kwargs):
+    def _train_epoch(self, data_iterator, model, epoch, dev_data, start, n_print, **kwargs):
         """Training process in one epoch.
 
             kwargs should contain:
@@ -170,7 +88,7 @@ class Trainer(object):
                 - epoch: int,
         """
         for batch_x, batch_y in data_iterator:
-            prediction = self.data_forward(network, batch_x)
+            prediction = self.data_forward(model, batch_x)
 
             # TODO: refactor self.get_loss
             loss = prediction["loss"] if "loss" in prediction else self.get_loss(prediction, batch_y)
@@ -179,35 +97,25 @@ class Trainer(object):
             self.grad_backward(loss)
             self.update()
             self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
-            for name, param in self._model.named_parameters():
+            for name, param in self.model.named_parameters():
                 if param.requires_grad:
                     self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
-                    # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
-                    # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
-            if kwargs["n_print"] > 0 and self.step % kwargs["n_print"] == 0:
+                    self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
+                    self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
+            if n_print > 0 and self.step % n_print == 0:
                 end = time.time()
-                diff = timedelta(seconds=round(end - kwargs["start"]))
+                diff = timedelta(seconds=round(end - start))
                 print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
-                    self.epoch, self.step, loss.data, diff)
+                    epoch, self.step, loss.data, diff)
                 print(print_output)
-                logger.info(print_output)
-            if self.validate and self.valid_step > 0 and self.step > 0 and self.step % self.valid_step == 0:
-                self.valid_model()
+
             self.step += 1
 
-    def valid_model(self):
-        if self.dev_data is None:
-            raise RuntimeError(
-                "self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
-        logger.info("validation started")
-        res = self.validator.test(self._model, self.dev_data)
+    def do_validation(self):
+        res = self.tester.test(self.model, self.dev_data)
         for name, num in res.items():
             self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
-        if self.save_best_dev and self.best_eval_result(res):
-            logger.info('save best result! {}'.format(res))
-            print('save best result! {}'.format(res))
-            self.save_model(self._model, 'best_model_' + self.start_time)
-        return res
+        self.save_model(self.model, 'best_model_' + self.start_time)
 
     def mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.
@@ -221,23 +129,11 @@ class Trainer(object):
         else:
             model.train()
 
-    def define_optimizer(self, optim=None):
-        """Define framework-specific optimizer specified by the models.
-
-        """
-        if optim is not None:
-            # optimizer constructed by user
-            self._optimizer = optim
-        elif self._optimizer is None:
-            # optimizer constructed by proto
-            self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters())
-        return self._optimizer
-
     def update(self):
         """Perform weight update on a model.
""" - self._optimizer.step() + self.optimizer.step() def data_forward(self, network, x): y = network(**x) @@ -253,7 +149,7 @@ class Trainer(object): For PyTorch, just do "loss.backward()" """ - self._model.zero_grad() + self.model.zero_grad() loss.backward() def get_loss(self, predict, truth): @@ -264,68 +160,37 @@ class Trainer(object): :return: a scalar """ if isinstance(predict, dict) and isinstance(truth, dict): - return self._loss_func(**predict, **truth) + return self.loss_func(**predict, **truth) if len(truth) > 1: raise NotImplementedError("Not ready to handle multi-labels.") truth = list(truth.values())[0] if len(truth) > 0 else None - return self._loss_func(predict, truth) - - def define_loss(self): - """Define a loss for the trainer. + return self.loss_func(predict, truth) - If the model defines a loss, use model's loss. - Otherwise, Trainer must has a loss argument, use it as loss. - These two losses cannot be defined at the same time. - Trainer does not handle loss definition or choose default losses. - """ - # if hasattr(self._model, "loss") and self._loss_func is not None: - # raise ValueError("Both the model and Trainer define loss. Please take out your loss.") - - if hasattr(self._model, "loss"): - self._loss_func = self._model.loss - logger.info("The model has a loss function, use it.") + def save_model(self, model, model_name, only_param=False): + if only_param: + torch.save(model.state_dict(), model_name) else: - if self._loss_func is None: - raise ValueError("Please specify a loss function.") - logger.info("The model didn't define loss, use Trainer's loss.") + torch.save(model, model_name) - def best_eval_result(self, metrics): - """Check if the current epoch yields better validation results. - :param validator: a Tester instance - :return: bool, True means current results on dev set is the best. - """ - if isinstance(metrics, tuple): - loss, metrics = metrics - - if isinstance(metrics, dict): - if len(metrics) == 1: - accuracy = list(metrics.values())[0] - else: - accuracy = metrics[self.eval_sort_key] - else: - accuracy = metrics +def best_eval_result(self, metrics): + """Check if the current epoch yields better validation results. - if accuracy > self._best_accuracy: - self._best_accuracy = accuracy - return True - else: - return False - - def save_model(self, network, model_name): - """Save this model with such a name. - This method may be called multiple times by Trainer to overwritten a better model. - - :param network: the PyTorch model - :param model_name: str - """ - if model_name[-4:] != ".pkl": - model_name += ".pkl" - ModelSaver(os.path.join(self.pickle_path, model_name)).save_pytorch(network) - - def _create_validator(self, valid_args): - return Tester(**valid_args) - - def set_validator(self, validor): - self.validator = validor + :return: bool, True means current results on dev set is the best. 
+ """ + if isinstance(metrics, tuple): + loss, metrics = metrics + if isinstance(metrics, dict): + if len(metrics) == 1: + accuracy = list(metrics.values())[0] + else: + accuracy = metrics[self.eval_sort_key] + else: + accuracy = metrics + + if accuracy > self._best_accuracy: + self._best_accuracy = accuracy + return True + else: + return False diff --git a/fastNLP/io/config_saver.py b/fastNLP/io/config_saver.py index bee49b51..49d6804d 100644 --- a/fastNLP/io/config_saver.py +++ b/fastNLP/io/config_saver.py @@ -1,7 +1,6 @@ import os from fastNLP.io.config_loader import ConfigSection, ConfigLoader -from fastNLP.io.logger import create_logger class ConfigSaver(object): @@ -61,8 +60,8 @@ class ConfigSaver(object): continue if '=' not in line: - log = create_logger(__name__, './config_saver.log') - log.error("can NOT load config file [%s]" % self.file_path) + # log = create_logger(__name__, './config_saver.log') + # log.error("can NOT load config file [%s]" % self.file_path) raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) key = line.split('=', maxsplit=1)[0].strip() @@ -123,10 +122,10 @@ class ConfigSaver(object): change_file = True break if section_file[k] != section[k]: - logger = create_logger(__name__, "./config_loader.log") - logger.warning("section [%s] in config file [%s] has been changed" % ( - section_name, self.file_path - )) + # logger = create_logger(__name__, "./config_loader.log") + # logger.warning("section [%s] in config file [%s] has been changed" % ( + # section_name, self.file_path + #)) change_file = True break if not change_file: