Merge remote-tracking branch 'origin/dev0.5.0' into dev0.5.0

tags/v0.4.10
xuyige 6 years ago
commit 238d4fbcd0
2 changed files with 18 additions and 16 deletions:
  1. fastNLP/core/losses.py (+12 -8)
  2. fastNLP/core/trainer.py (+6 -8)

fastNLP/core/losses.py (+12 -8)

@@ -26,7 +26,7 @@ from .utils import _build_args
 from .utils import _check_arg_dict_list
 from .utils import _check_function_or_method
 from .utils import _get_func_signature
+from .utils import seq_len_to_mask
 
 class LossBase(object):
     """
@@ -223,7 +223,9 @@ class CrossEntropyLoss(LossBase):
     :param pred: the mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
     :param target: the mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
-    :param padding_idx: the padding index; entries of target equal to padding_idx are ignored when computing the loss
+    :param seq_len: sentence lengths; tokens beyond a sentence's length do not contribute to the loss.
+    :param padding_idx: the padding index; entries of target equal to padding_idx are ignored when computing the loss.
+        This can be used instead of passing seq_len.
 
     Example::

@@ -231,16 +233,18 @@ class CrossEntropyLoss(LossBase):
     """
 
-    def __init__(self, pred=None, target=None, padding_idx=-100):
+    def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100):
         super(CrossEntropyLoss, self).__init__()
-        self._init_param_map(pred=pred, target=target)
+        self._init_param_map(pred=pred, target=target, seq_len=seq_len)
         self.padding_idx = padding_idx
 
-    def get_loss(self, pred, target):
+    def get_loss(self, pred, target, seq_len=None):
         if pred.dim() > 2:
             if pred.size()[:2] == target.size():
                 # F.cross_entropy takes log_softmax over dim 1, so a pred of shape (16, 10, 4) needs its dims swapped
                 pred = pred.transpose(1, 2)
             pred = pred.view(-1, pred.size(-1))
         target = target.view(-1)
+        if seq_len is not None:
+            mask = seq_len_to_mask(seq_len).view(-1).eq(0)
+            target = target.masked_fill(mask, self.padding_idx)
 
         return F.cross_entropy(input=pred, target=target,
                                ignore_index=self.padding_idx)

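The new branch in get_loss relies on seq_len_to_mask from fastNLP/core/utils.py. A minimal self-contained sketch of the mechanic (the mask literal below stands in for what seq_len_to_mask returns: 1 for real tokens, 0 for padding):

    import torch

    target = torch.tensor([[3, 1, 4, 0, 0]])   # (batch=1, max_len=5)
    seq_len = torch.tensor([3])                # only the first 3 tokens are real
    mask = torch.tensor([[1, 1, 1, 0, 0]])     # what seq_len_to_mask(seq_len) yields
    target = target.view(-1).masked_fill(mask.view(-1).eq(0), -100)
    # tensor([   3,    1,    4, -100, -100])
    # F.cross_entropy(..., ignore_index=-100) then skips the padded positions.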

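A hedged sketch of the intended usage per the updated docstring (the key names 'pred'/'target'/'seq_len' follow the parameter map above; the tensors are made up for illustration):

    import torch
    from fastNLP.core.losses import CrossEntropyLoss

    loss_fn = CrossEntropyLoss(pred='pred', target='target', seq_len='seq_len')
    pred = torch.randn(2, 5, 7)            # (batch, max_len, n_classes)
    target = torch.randint(0, 7, (2, 5))   # (batch, max_len)
    seq_len = torch.tensor([5, 3])         # the second sample has 2 padded positions
    loss = loss_fn.get_loss(pred, target, seq_len=seq_len)
    # the 2 positions beyond length 3 no longer contribute to the loss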
fastNLP/core/trainer.py (+6 -8)

@@ -452,17 +452,15 @@ class Trainer(object):
         else:
             raise TypeError("train_data type {} not support".format(type(train_data)))
 
-        self.model = _move_model_to_device(model, device=device)
-
         if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter):
-            _check_code(dataset=train_data, model=self.model, losser=losser, metrics=metrics, dev_data=dev_data,
+            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                         metric_key=metric_key, check_level=check_code_level,
                         batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))
             # _check_code is how fastNLP verifies that your code is correct; if you see this comment in an error traceback, check your code carefully
+        self.model = _move_model_to_device(model, device=device)
 
         self.train_data = train_data
         self.dev_data = dev_data  # If None, No validation.
-        self.model = model
         self.losser = losser
         self.metrics = metrics
         self.n_epochs = int(n_epochs)
@@ -480,16 +478,16 @@ class Trainer(object):
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         elif isinstance(optimizer, Optimizer):
-            self.optimizer = optimizer.construct_from_pytorch(model.parameters())
+            self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
         elif optimizer is None:
-            self.optimizer = torch.optim.Adam(model.parameters(), lr=4e-3)
+            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
         else:
             raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
 
         self.use_tqdm = use_tqdm
         self.pbar = None
         self.print_every = abs(self.print_every)
         if self.dev_data is not None:
             self.tester = Tester(model=self.model,
                                  data=self.dev_data,


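Taken together, the trainer changes ensure self.model is assigned exactly once, to the device-moved model, and that the default optimizer is built from that same object's parameters. A hedged construction sketch that exercises both changed code paths (train_data, dev_data, and model are placeholders, and the top-level imports assume fastNLP's usual re-exports):

    from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric

    trainer = Trainer(train_data=train_data, model=model,
                      loss=CrossEntropyLoss(pred='pred', target='target', seq_len='seq_len'),
                      metrics=AccuracyMetric(), dev_data=dev_data,
                      optimizer=None,       # falls back to Adam(self.model.parameters(), lr=4e-3)
                      device='cuda:0',      # model is moved via _move_model_to_device
                      check_code_level=0)   # _check_code runs on the raw model before the move
    trainer.train()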