diff --git a/fastNLP/core/loss.py b/fastNLP/core/loss.py index 8d866bbf..8a0eedd7 100644 --- a/fastNLP/core/loss.py +++ b/fastNLP/core/loss.py @@ -37,5 +37,7 @@ class Loss(object): """ if loss_name == "cross_entropy": return torch.nn.CrossEntropyLoss() + elif loss_name == 'nll': + return torch.nn.NLLLoss() else: raise NotImplementedError