| @@ -2,10 +2,10 @@ from .batch import Batch | |||||
| from .dataset import DataSet | from .dataset import DataSet | ||||
| from .fieldarray import FieldArray | from .fieldarray import FieldArray | ||||
| from .instance import Instance | from .instance import Instance | ||||
| from .losses import Loss | |||||
| from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator | from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator | ||||
| from .optimizer import Optimizer | |||||
| from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler | from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler | ||||
| from .tester import Tester | from .tester import Tester | ||||
| from .trainer import Trainer | from .trainer import Trainer | ||||
| from .vocabulary import Vocabulary | from .vocabulary import Vocabulary | ||||
| from .optimizer import Optimizer | |||||
| from .loss import Loss | |||||
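The package's public surface is unchanged by the rename: __init__.py still re-exports Loss, now from the new losses module, and moves the Optimizer import into sorted position. A minimal sanity check (hypothetical usage, not part of the diff):

    from fastNLP.core import Loss, Optimizer

    loss_func = Loss("cross_entropy")  # now resolved via the renamed losses module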
| @@ -1,196 +0,0 @@ | |||||
| import torch | |||||
| def squash(predict , truth , **kwargs): | |||||
| '''To reshape tensors in order to fit Loss functions in pytorch | |||||
| :param predict : Tensor, model output | |||||
| :param truth : Tensor, truth from dataset | |||||
| :param **kwargs : extra arguments | |||||
| :return predict , truth: predict & truth after processing | |||||
| ''' | |||||
| return predict.view(-1 , predict.size()[-1]) , truth.view(-1,) | |||||
| def unpad(predict , truth , **kwargs): | |||||
| '''To process padded sequence output to get true loss | |||||
| Using pack_padded_sequence() method | |||||
| This method contains squash() | |||||
| :param predict : Tensor, [batch_size , max_len , tag_size] | |||||
| :param truth : Tensor, [batch_size , max_len] | |||||
| :param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist | |||||
| kwargs["lens"] : list or LongTensor, [batch_size] | |||||
| the i-th element is true lengths of i-th sequence | |||||
| :return predict , truth: predict & truth after processing | |||||
| ''' | |||||
| if kwargs.get("lens") is None: | |||||
| return predict , truth | |||||
| lens = torch.LongTensor(kwargs["lens"]) | |||||
| lens , idx = torch.sort(lens , descending = True) | |||||
| predict = torch.nn.utils.rnn.pack_padded_sequence(predict[idx] , lens , batch_first = True).data | |||||
| truth = torch.nn.utils.rnn.pack_padded_sequence(truth[idx] , lens , batch_first = True).data | |||||
| return predict , truth | |||||
| def unpad_mask(predict , truth , **kwargs): | |||||
| '''To process padded sequence output to get true loss | |||||
| Using mask() method | |||||
| This method contains squash() | |||||
| :param predict : Tensor, [batch_size , max_len , tag_size] | |||||
| :param truth : Tensor, [batch_size , max_len] | |||||
| :param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist | |||||
| kwargs["lens"] : list or LongTensor, [batch_size] | |||||
| the i-th element is true lengths of i-th sequence | |||||
| :return predict , truth: predict & truth after processing | |||||
| ''' | |||||
| if kwargs.get("lens") is None: | |||||
| return predict , truth | |||||
| mas = make_mask(kwargs["lens"] , truth.size()[1]) | |||||
| return mask(predict , truth , mask = mas) | |||||
| def mask(predict , truth , **kwargs): | |||||
| '''To select specific elements from Tensor | |||||
| This method contains squash() | |||||
| :param predict : Tensor, [batch_size , max_len , tag_size] | |||||
| :param truth : Tensor, [batch_size , max_len] | |||||
| :param **kwargs : extra arguments, kwargs["mask"] is expected to be exsist | |||||
| kwargs["mask"] : ByteTensor, [batch_size , max_len] | |||||
| the mask Tensor , the position that is 1 will be selected | |||||
| :return predict , truth: predict & truth after processing | |||||
| ''' | |||||
| if kwargs.get("mask") is None: | |||||
| return predict , truth | |||||
| mask = kwargs["mask"] | |||||
| predict , truth = squash(predict , truth) | |||||
| mask = mask.view(-1,) | |||||
| predict = torch.masked_select(predict.permute(1,0) , mask).view(predict.size()[-1] , -1).permute(1,0) | |||||
| truth = torch.masked_select(truth , mask) | |||||
| return predict , truth | |||||
| def make_mask(lens , tar_len): | |||||
| '''to generate a mask that select [:lens[i]] for i-th element | |||||
| embezzle from fastNLP.models.sequence_modeling.seq_mask | |||||
| :param lens : list or LongTensor, [batch_size] | |||||
| :param tar_len : int | |||||
| :return mask : ByteTensor | |||||
| ''' | |||||
| lens = torch.LongTensor(lens) | |||||
| mask = [torch.ge(lens, i + 1) for i in range(tar_len)] | |||||
| mask = torch.stack(mask, 1) | |||||
| return mask | |||||
| #map string to function. Just for more elegant using | |||||
| method_dict = { | |||||
| "squash" : squash, | |||||
| "unpad" : unpad, | |||||
| "unpad_mask" : unpad_mask, | |||||
| "mask" : mask, | |||||
| } | |||||
| loss_function_name = { | |||||
| "L1Loss".lower() : torch.nn.L1Loss, | |||||
| "BCELoss".lower() : torch.nn.BCELoss, | |||||
| "MSELoss".lower() : torch.nn.MSELoss, | |||||
| "NLLLoss".lower() : torch.nn.NLLLoss, | |||||
| "KLDivLoss".lower() : torch.nn.KLDivLoss, | |||||
| "NLLLoss2dLoss".lower() : torch.nn.NLLLoss2d, #every name should end with "loss" | |||||
| "SmoothL1Loss".lower() : torch.nn.SmoothL1Loss, | |||||
| "SoftMarginLoss".lower() : torch.nn.SoftMarginLoss, | |||||
| "PoissonNLLLoss".lower() : torch.nn.PoissonNLLLoss, | |||||
| "MultiMarginLoss".lower() : torch.nn.MultiMarginLoss, | |||||
| "CrossEntropyLoss".lower() : torch.nn.CrossEntropyLoss, | |||||
| "BCEWithLogitsLoss".lower() : torch.nn.BCEWithLogitsLoss, | |||||
| "MarginRankingLoss".lower() : torch.nn.MarginRankingLoss, | |||||
| "TripletMarginLoss".lower() : torch.nn.TripletMarginLoss, | |||||
| "HingeEmbeddingLoss".lower() : torch.nn.HingeEmbeddingLoss, | |||||
| "CosineEmbeddingLoss".lower() : torch.nn.CosineEmbeddingLoss, | |||||
| "MultiLabelMarginLoss".lower() : torch.nn.MultiLabelMarginLoss, | |||||
| "MultiLabelSoftMarginLoss".lower() : torch.nn.MultiLabelSoftMarginLoss, | |||||
| } | |||||
| class Loss(object): | |||||
| '''a Loss object is a callable object represents loss functions | |||||
| ''' | |||||
| def __init__(self , loss_name , pre_pro = [squash], **kwargs): | |||||
| ''' | |||||
| :param loss_name: str or None , the name of loss function | |||||
| :param pre_pro : list of function or str, methods to reform parameters before calculating loss | |||||
| the strings will be auto translated to pre-defined functions | |||||
| :param **kwargs: kwargs for torch loss function | |||||
| pre_pro funcsions should have three arguments: predict, truth, **arg | |||||
| predict and truth is the necessary parameters in loss function | |||||
| kwargs is the extra parameters passed-in when calling loss function | |||||
| pre_pro functions should return two objects, respectively predict and truth that after processed | |||||
| ''' | |||||
| if loss_name is None: | |||||
| # this is useful when Trainer.__init__ performs type check | |||||
| self._loss = None | |||||
| else: | |||||
| if not isinstance(loss_name, str): | |||||
| raise NotImplementedError | |||||
| else: | |||||
| self._loss = self._get_loss(loss_name , **kwargs) | |||||
| self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro] | |||||
| def add_pre_pro(self , func): | |||||
| '''add a pre_pro function | |||||
| :param func: a function or str, methods to reform parameters before calculating loss | |||||
| the strings will be auto translated to pre-defined functions | |||||
| ''' | |||||
| if not callable(func): | |||||
| func = method_dict.get(func) | |||||
| if func is None: | |||||
| return | |||||
| self.pre_pro.append(func) | |||||
| @staticmethod | |||||
| def _get_loss(loss_name , **kwargs): | |||||
| '''Get loss function from torch | |||||
| :param loss_name: str, the name of loss function | |||||
| :param **kwargs: kwargs for torch loss function | |||||
| :return: A callable loss function object | |||||
| ''' | |||||
| loss_name = loss_name.strip().lower() | |||||
| loss_name = "".join(loss_name.split("_")) | |||||
| if len(loss_name) < 4 or loss_name[-4 : ] != "loss": | |||||
| loss_name += "loss" | |||||
| return loss_function_name[loss_name](**kwargs) | |||||
| def get(self): | |||||
| '''This method exists just for make some existing codes run error-freely | |||||
| ''' | |||||
| return self | |||||
| def __call__(self , predict , truth , **kwargs): | |||||
| '''call a loss function | |||||
| predict and truth will be processed by pre_pro methods in order of addition | |||||
| :param predict : Tensor, model output | |||||
| :param truth : Tensor, truth from dataset | |||||
| :param **kwargs : extra arguments, pass to pre_pro functions | |||||
| for example, if used unpad_mask() in pre_pro, there should be a kwarg named lens | |||||
| ''' | |||||
| for f in self.pre_pro: | |||||
| if f is None: | |||||
| continue | |||||
| predict , truth = f(predict , truth , **kwargs) | |||||
| return self._loss(predict , truth) | |||||
| @@ -0,0 +1,219 @@ | |||||
| import torch | |||||
| class LossBase(object): | |||||
| def __init__(self): | |||||
| self.param_map = {} | |||||
| def get_loss(self, *args, **kwargs): | |||||
| raise NotImplementedError | |||||
| def __call__(self, output_dict, predict_dict): | |||||
| pass | |||||
| class Loss(LossBase): | |||||
| def __init__(self): | |||||
| pass | |||||
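LossBase and the Loss(LossBase) stub are the new abstraction this file introduces; both are still skeletons. Reading the stub, param_map presumably maps the loss function's argument names to keys of the model's output/target dicts, and get_loss() is the hook a concrete subclass implements. A hedged sketch of what such a subclass could look like under those assumptions (the class name and the pred/target keys are invented, not part of this commit):

    import torch.nn.functional as F

    class CrossEntropyLoss(LossBase):
        def __init__(self, pred="output", target="label_seq"):
            super(CrossEntropyLoss, self).__init__()
            # hypothetical mapping: loss-argument name -> output/target dict key
            self.param_map = {"pred": pred, "target": target}

        def get_loss(self, pred, target):
            # flatten [batch, max_len, n_tags] / [batch, max_len] before F.cross_entropy
            return F.cross_entropy(pred.view(-1, pred.size(-1)), target.view(-1))

The remainder of the file carries the helper functions and the string-based Loss class over from the old loss.py: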
+def squash(predict, truth, **kwargs):
+    '''Reshape tensors so that they fit the loss functions in PyTorch.
+
+    :param predict: Tensor, model output
+    :param truth: Tensor, truth from dataset
+    :param **kwargs: extra arguments
+    :return predict, truth: predict & truth after processing
+    '''
+    return predict.view(-1, predict.size()[-1]), truth.view(-1, )
+
+
+def unpad(predict, truth, **kwargs):
+    '''Process padded sequence output to get the true loss, using the
+    pack_padded_sequence() method. This method contains squash().
+
+    :param predict: Tensor, [batch_size , max_len , tag_size]
+    :param truth: Tensor, [batch_size , max_len]
+    :param **kwargs: extra arguments; kwargs["lens"] is expected to exist
+        kwargs["lens"]: list or LongTensor, [batch_size],
+            the i-th element is the true length of the i-th sequence
+    :return predict, truth: predict & truth after processing
+    '''
+    if kwargs.get("lens") is None:
+        return predict, truth
+    lens = torch.LongTensor(kwargs["lens"])
+    lens, idx = torch.sort(lens, descending=True)
+    predict = torch.nn.utils.rnn.pack_padded_sequence(predict[idx], lens, batch_first=True).data
+    truth = torch.nn.utils.rnn.pack_padded_sequence(truth[idx], lens, batch_first=True).data
+    return predict, truth
+
+
+def unpad_mask(predict, truth, **kwargs):
+    '''Process padded sequence output to get the true loss, using the
+    mask() method. This method contains squash().
+
+    :param predict: Tensor, [batch_size , max_len , tag_size]
+    :param truth: Tensor, [batch_size , max_len]
+    :param **kwargs: extra arguments; kwargs["lens"] is expected to exist
+        kwargs["lens"]: list or LongTensor, [batch_size],
+            the i-th element is the true length of the i-th sequence
+    :return predict, truth: predict & truth after processing
+    '''
+    if kwargs.get("lens") is None:
+        return predict, truth
+    mas = make_mask(kwargs["lens"], truth.size()[1])
+    return mask(predict, truth, mask=mas)
+
+
+def mask(predict, truth, **kwargs):
+    '''Select specific elements from a Tensor. This method contains squash().
+
+    :param predict: Tensor, [batch_size , max_len , tag_size]
+    :param truth: Tensor, [batch_size , max_len]
+    :param **kwargs: extra arguments; kwargs["mask"] is expected to exist
+        kwargs["mask"]: ByteTensor, [batch_size , max_len],
+            positions marked 1 will be selected
+    :return predict, truth: predict & truth after processing
+    '''
+    if kwargs.get("mask") is None:
+        return predict, truth
+    mask = kwargs["mask"]
+
+    predict, truth = squash(predict, truth)
+    mask = mask.view(-1, )
+    predict = torch.masked_select(predict.permute(1, 0), mask).view(predict.size()[-1], -1).permute(1, 0)
+    truth = torch.masked_select(truth, mask)
+
+    return predict, truth
+
+
+def make_mask(lens, tar_len):
+    '''Generate a mask that selects [:lens[i]] for the i-th element
+    (borrowed from fastNLP.models.sequence_modeling.seq_mask).
+
+    :param lens: list or LongTensor, [batch_size]
+    :param tar_len: int
+    :return mask: ByteTensor
+    '''
+    lens = torch.LongTensor(lens)
+    mask = [torch.ge(lens, i + 1) for i in range(tar_len)]
+    mask = torch.stack(mask, 1)
+    return mask
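The four helpers compose: make_mask() turns true sequence lengths into a ByteTensor mask (ByteTensor masks were the norm in the PyTorch of this era), and mask() uses it to drop padded positions before the loss is computed. A small worked example with invented shapes:

    predict = torch.randn(2, 3, 4)                    # [batch_size, max_len, tag_size]
    truth = torch.LongTensor([[1, 2, 0], [3, 0, 0]])  # 0 = padding
    m = make_mask([2, 1], 3)                          # [[1, 1, 0], [1, 0, 0]]
    p, t = mask(predict, truth, mask=m)               # p: [3, 4], t: tensor([1, 2, 3])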
+# map strings to functions, just for more elegant usage
+method_dict = {
+    "squash": squash,
+    "unpad": unpad,
+    "unpad_mask": unpad_mask,
+    "mask": mask,
+}
+
+loss_function_name = {
+    "L1Loss".lower(): torch.nn.L1Loss,
+    "BCELoss".lower(): torch.nn.BCELoss,
+    "MSELoss".lower(): torch.nn.MSELoss,
+    "NLLLoss".lower(): torch.nn.NLLLoss,
+    "KLDivLoss".lower(): torch.nn.KLDivLoss,
+    "NLLLoss2dLoss".lower(): torch.nn.NLLLoss2d,  # every name should end with "loss"
+    "SmoothL1Loss".lower(): torch.nn.SmoothL1Loss,
+    "SoftMarginLoss".lower(): torch.nn.SoftMarginLoss,
+    "PoissonNLLLoss".lower(): torch.nn.PoissonNLLLoss,
+    "MultiMarginLoss".lower(): torch.nn.MultiMarginLoss,
+    "CrossEntropyLoss".lower(): torch.nn.CrossEntropyLoss,
+    "BCEWithLogitsLoss".lower(): torch.nn.BCEWithLogitsLoss,
+    "MarginRankingLoss".lower(): torch.nn.MarginRankingLoss,
+    "TripletMarginLoss".lower(): torch.nn.TripletMarginLoss,
+    "HingeEmbeddingLoss".lower(): torch.nn.HingeEmbeddingLoss,
+    "CosineEmbeddingLoss".lower(): torch.nn.CosineEmbeddingLoss,
+    "MultiLabelMarginLoss".lower(): torch.nn.MultiLabelMarginLoss,
+    "MultiLabelSoftMarginLoss".lower(): torch.nn.MultiLabelSoftMarginLoss,
+}
+
+
+class Loss(object):
+    '''A Loss object is a callable object that represents a loss function.
+    '''
+
+    def __init__(self, loss_name, pre_pro=[squash], **kwargs):
+        '''
+        :param loss_name: str or None, the name of the loss function
+        :param pre_pro: list of functions or str; methods that reform the parameters
+            before the loss is computed (strings are translated to pre-defined functions)
+        :param **kwargs: kwargs for the torch loss function
+
+        pre_pro functions take three arguments: predict, truth, **arg.
+        predict and truth are the parameters required by the loss function;
+        kwargs holds the extra parameters passed in when the loss is called.
+        pre_pro functions should return two objects: the processed predict and truth, respectively.
+        '''
+        if loss_name is None:
+            # this is useful when Trainer.__init__ performs a type check
+            self._loss = None
+        else:
+            if not isinstance(loss_name, str):
+                raise NotImplementedError
+            else:
+                self._loss = self._get_loss(loss_name, **kwargs)
+
+        self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro]
+
+    def add_pre_pro(self, func):
+        '''Add a pre_pro function.
+
+        :param func: a function or str; a method that reforms the parameters
+            before the loss is computed (strings are translated to pre-defined functions)
+        '''
+        if not callable(func):
+            func = method_dict.get(func)
+            if func is None:
+                return
+        self.pre_pro.append(func)
+
+    @staticmethod
+    def _get_loss(loss_name, **kwargs):
+        '''Get a loss function from torch.
+
+        :param loss_name: str, the name of the loss function
+        :param **kwargs: kwargs for the torch loss function
+        :return: a callable loss function object
+        '''
+        loss_name = loss_name.strip().lower()
+        loss_name = "".join(loss_name.split("_"))
+
+        if len(loss_name) < 4 or loss_name[-4:] != "loss":
+            loss_name += "loss"
+        return loss_function_name[loss_name](**kwargs)
+
+    def get(self):
+        '''This method exists just to let some existing code run error-free.
+        '''
+        return self
+
+    def __call__(self, predict, truth, **kwargs):
+        '''Call the loss function; predict and truth are first processed by the
+        pre_pro methods, in order of addition.
+
+        :param predict: Tensor, model output
+        :param truth: Tensor, truth from dataset
+        :param **kwargs: extra arguments, passed on to the pre_pro functions
+            (for example, if unpad_mask() is used in pre_pro, a kwarg named lens is required)
+        '''
+        for f in self.pre_pro:
+            if f is None:
+                continue
+            predict, truth = f(predict, truth, **kwargs)
+
+        return self._loss(predict, truth)
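Note that because the module defines Loss twice, this string-based class shadows the earlier Loss(LossBase) stub and is what `from fastNLP.core.losses import Loss` actually yields. A hedged usage sketch with invented shapes: _get_loss() normalizes "cross_entropy" to "crossentropyloss" and instantiates torch.nn.CrossEntropyLoss, while the "unpad" pre_pro strips padding via pack_padded_sequence() before the loss sees the tensors.

    loss_func = Loss("cross_entropy", pre_pro=["unpad"])
    predict = torch.randn(2, 3, 5)                    # [batch_size, max_len, tag_size]
    truth = torch.LongTensor([[1, 2, 0], [3, 0, 0]])
    value = loss_func(predict, truth, lens=[2, 1])    # unpad() consumes the lens kwarg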
| @@ -1,29 +1,27 @@ | |||||
| import itertools | |||||
| import os | |||||
| import time | import time | ||||
| from datetime import timedelta | |||||
| from datetime import datetime | |||||
| import warnings | import warnings | ||||
| from collections import defaultdict | from collections import defaultdict | ||||
| import os | |||||
| import itertools | |||||
| import shutil | |||||
| from datetime import datetime | |||||
| from datetime import timedelta | |||||
| from tensorboardX import SummaryWriter | |||||
| import torch | import torch | ||||
| from torch import nn | from torch import nn | ||||
| from tensorboardX import SummaryWriter | |||||
| from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
| from fastNLP.core.loss import Loss | |||||
| from fastNLP.core.metrics import Evaluator | |||||
| from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
| from fastNLP.core.sampler import RandomSampler | from fastNLP.core.sampler import RandomSampler | ||||
| from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
| from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
| from fastNLP.core.utils import _check_arg_dict_list | |||||
| from fastNLP.core.utils import _build_args | from fastNLP.core.utils import _build_args | ||||
| from fastNLP.core.utils import _check_arg_dict_list | |||||
| from fastNLP.core.utils import _syn_model_data | from fastNLP.core.utils import _syn_model_data | ||||
| from fastNLP.core.utils import get_func_signature | from fastNLP.core.utils import get_func_signature | ||||
| from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
| class Trainer(object): | class Trainer(object): | ||||
| """Main Training Loop | """Main Training Loop | ||||
| @@ -1,9 +1,10 @@ | |||||
| import math | |||||
| import unittest | import unittest | ||||
| import fastNLP.core.loss as loss | |||||
| import math | |||||
| import torch as tc | import torch as tc | ||||
| import pdb | |||||
| import fastNLP.core.losses as loss | |||||
| class TestLoss(unittest.TestCase): | class TestLoss(unittest.TestCase): | ||||
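The test module now targets the renamed losses module and drops the stray pdb import. For context, a minimal method that could sit inside TestLoss (an illustrative sketch, not a test body taken from the diff):

    def test_cross_entropy(self):
        # compare the string-based Loss wrapper against torch's functional loss
        loss_func = loss.Loss("cross_entropy")
        predict = tc.randn(4, 10)                 # [batch_size, n_classes]
        truth = tc.LongTensor([1, 2, 3, 4])
        expected = tc.nn.functional.cross_entropy(predict, truth)
        self.assertAlmostEqual(float(loss_func(predict, truth)),
                               float(expected), places=5)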