From 43d3380b730398ac4594edfbfc28b9e8fc55ce77 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 24 Jun 2019 18:31:38 +0800 Subject: [PATCH] =?UTF-8?q?1.=E4=BF=AE=E5=A4=8DTrainer=E5=88=9D=E5=A7=8B?= =?UTF-8?q?=E5=8C=96=E7=9A=84=E5=A4=9Adevice=20bug;=202.=E5=9C=A8CrossEntr?= =?UTF-8?q?opyLoss=E4=B8=AD=E5=A2=9E=E5=8A=A0seq=5Flen?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/losses.py | 20 ++++++++++++-------- fastNLP/core/trainer.py | 14 ++++++-------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 62e7a8c8..526bf37a 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -26,7 +26,7 @@ from .utils import _build_args from .utils import _check_arg_dict_list from .utils import _check_function_or_method from .utils import _get_func_signature - +from .utils import seq_len_to_mask class LossBase(object): """ @@ -223,7 +223,9 @@ class CrossEntropyLoss(LossBase): :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` - :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容 + :param seq_len: 句子的长度, 长度之外的token不会计算loss。。 + :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 + 传入seq_len. Example:: @@ -231,16 +233,18 @@ class CrossEntropyLoss(LossBase): """ - def __init__(self, pred=None, target=None, padding_idx=-100): + def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100): super(CrossEntropyLoss, self).__init__() - self._init_param_map(pred=pred, target=target) + self._init_param_map(pred=pred, target=target, seq_len=seq_len) self.padding_idx = padding_idx - def get_loss(self, pred, target): + def get_loss(self, pred, target, seq_len=None): if pred.dim()>2: - if pred.size()[:2]==target.size(): - # F.cross_entropy在计算时,如果pred是(16, 10 ,4), 会在第二维上去log_softmax, 所以需要交换一下位置 - pred = pred.transpose(1, 2) + pred = pred.view(-1, pred.size(-1)) + target = target.view(-1) + if seq_len is not None: + mask = seq_len_to_mask(seq_len).view(-1).eq(0) + target = target.masked_fill(mask, self.padding_idx) return F.cross_entropy(input=pred, target=target, ignore_index=self.padding_idx) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index a303f742..e8dfa814 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -452,17 +452,15 @@ class Trainer(object): else: raise TypeError("train_data type {} not support".format(type(train_data))) - self.model = _move_model_to_device(model, device=device) - if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): - _check_code(dataset=train_data, model=self.model, losser=losser, metrics=metrics, dev_data=dev_data, + _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, metric_key=metric_key, check_level=check_code_level, batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 - + self.model = _move_model_to_device(model, device=device) + self.train_data = train_data self.dev_data = dev_data # If None, No validation. - self.model = model self.losser = losser self.metrics = metrics self.n_epochs = int(n_epochs) @@ -480,16 +478,16 @@ class Trainer(object): if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer elif isinstance(optimizer, Optimizer): - self.optimizer = optimizer.construct_from_pytorch(model.parameters()) + self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) elif optimizer is None: - self.optimizer = torch.optim.Adam(model.parameters(), lr=4e-3) + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3) else: raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) self.use_tqdm = use_tqdm self.pbar = None self.print_every = abs(self.print_every) - + if self.dev_data is not None: self.tester = Tester(model=self.model, data=self.dev_data,