@@ -26,7 +26,7 @@ from .utils import _build_args
 from .utils import _check_arg_dict_list
 from .utils import _check_function_or_method
 from .utils import _get_func_signature
+from .utils import seq_len_to_mask


 class LossBase(object):
     """
@@ -223,7 +223,9 @@ class CrossEntropyLoss(LossBase):
     :param pred: the mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
     :param target: the mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
-    :param padding_idx: the padding index; entries in target equal to padding_idx are ignored when computing the loss
+    :param seq_len: lengths of the sentences; tokens beyond each length are not counted in the loss.
+    :param padding_idx: the padding index; entries in target equal to padding_idx are ignored when computing the loss.
+        This value can be used instead of passing seq_len.

     Example::
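A hedged usage sketch of the updated signature; the dataset field names `'label'` and `'seq_len'` here are illustrative assumptions, not fixed by the change:

```python
from fastNLP import CrossEntropyLoss

# Map the model output 'pred' and the (assumed) dataset fields 'label' and
# 'seq_len' onto the loss; tokens past each sequence's seq_len are excluded.
loss = CrossEntropyLoss(pred='pred', target='label', seq_len='seq_len', padding_idx=0)
```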
@@ -231,16 +233,18 @@ class CrossEntropyLoss(LossBase):
     """

-    def __init__(self, pred=None, target=None, padding_idx=-100):
+    def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100):
         super(CrossEntropyLoss, self).__init__()
-        self._init_param_map(pred=pred, target=target)
+        self._init_param_map(pred=pred, target=target, seq_len=seq_len)
         self.padding_idx = padding_idx

-    def get_loss(self, pred, target):
+    def get_loss(self, pred, target, seq_len=None):
         if pred.dim() > 2:
             if pred.size()[:2] == target.size():
                 # F.cross_entropy applies log_softmax over dim 1, so a pred of shape (16, 10, 4) needs its class dimension swapped into position 1
                 pred = pred.transpose(1, 2)
             pred = pred.view(-1, pred.size(-1))
             target = target.view(-1)
+        if seq_len is not None:
+            mask = seq_len_to_mask(seq_len).view(-1).eq(0)
+            target = target.masked_fill(mask, self.padding_idx)
         return F.cross_entropy(input=pred, target=target,
                                ignore_index=self.padding_idx)
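To make the new masking path concrete, here is a minimal self-contained sketch of the same idea; the toy shapes and the local `toy_seq_len_to_mask` helper are illustrative stand-ins (fastNLP's own `seq_len_to_mask` infers the maximum length from the batch):

```python
import torch
import torch.nn.functional as F

def toy_seq_len_to_mask(seq_len, max_len):
    # Stand-in for fastNLP's seq_len_to_mask: True for real tokens, False for padding.
    return torch.arange(max_len).unsqueeze(0) < seq_len.unsqueeze(1)  # (batch, max_len)

B, L, C, padding_idx = 2, 4, 5, -100
pred = torch.randn(B, L, C).view(-1, C)        # flattened to (B*L, C), class dim last
target = torch.randint(0, C, (B, L)).view(-1)  # flattened to (B*L,)
seq_len = torch.tensor([4, 2])                 # second sequence ends after 2 tokens

# Padding positions are overwritten with padding_idx, which
# F.cross_entropy then skips via ignore_index.
mask = toy_seq_len_to_mask(seq_len, L).view(-1).eq(0)
target = target.masked_fill(mask, padding_idx)
loss = F.cross_entropy(pred, target, ignore_index=padding_idx)
```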
@@ -452,17 +452,15 @@ class Trainer(object):
         else:
             raise TypeError("train_data type {} not support".format(type(train_data)))

-        self.model = _move_model_to_device(model, device=device)
         if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter):
-            _check_code(dataset=train_data, model=self.model, losser=losser, metrics=metrics, dev_data=dev_data,
+            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                         metric_key=metric_key, check_level=check_code_level,
                         batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))
             # _check_code is how fastNLP helps you verify that your code is correct; if you see this comment in an error traceback, check your code carefully
+        self.model = _move_model_to_device(model, device=device)

         self.train_data = train_data
         self.dev_data = dev_data  # If None, no validation.
-        self.model = model
         self.losser = losser
         self.metrics = metrics
         self.n_epochs = int(n_epochs)
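The reordering above also removes a subtle overwrite: `self.model = model` used to run after `self.model = _move_model_to_device(model, device=device)`, discarding whatever the helper returned. A minimal sketch of the hazard, assuming the helper may return a new object such as a multi-GPU wrapper:

```python
import torch

model = torch.nn.Linear(4, 2)

# Old order (hazard): the device-prepared object is silently discarded.
prepared = torch.nn.DataParallel(model)  # stand-in for _move_model_to_device
prepared = model                         # the removed `self.model = model` did this

# New order: sanity-check the raw model first, then prepare it exactly once.
with torch.no_grad():
    _ = model(torch.randn(2, 4))         # stand-in for _check_code's small-batch forward
prepared = torch.nn.DataParallel(model)
```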
@@ -480,16 +478,16 @@ class Trainer(object):
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         elif isinstance(optimizer, Optimizer):
-            self.optimizer = optimizer.construct_from_pytorch(model.parameters())
+            self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
         elif optimizer is None:
-            self.optimizer = torch.optim.Adam(model.parameters(), lr=4e-3)
+            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
         else:
             raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))

         self.use_tqdm = use_tqdm
         self.pbar = None
         self.print_every = abs(self.print_every)

         if self.dev_data is not None:
             self.tester = Tester(model=self.model,
                                  data=self.dev_data,
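With `self.model` now the device-placed model, the optimizer is always built against its parameters. A standalone sketch of the same dispatch, where the `hasattr` check stands in for fastNLP's `Optimizer` wrapper class (which exposes `construct_from_pytorch`):

```python
import torch

def build_optimizer(model, optimizer=None):
    # Mirrors the Trainer dispatch above.
    if isinstance(optimizer, torch.optim.Optimizer):
        return optimizer
    if hasattr(optimizer, "construct_from_pytorch"):  # fastNLP Optimizer wrapper
        return optimizer.construct_from_pytorch(model.parameters())
    if optimizer is None:
        return torch.optim.Adam(model.parameters(), lr=4e-3)
    raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))

model = torch.nn.Linear(4, 2)
opt = build_optimizer(model)  # default: Adam with lr=4e-3
opt = build_optimizer(model, torch.optim.SGD(model.parameters(), lr=0.1))
```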