|
@@ -37,5 +37,7 @@ class Loss(object): |
|
|
""" |
|
|
""" |
|
|
if loss_name == "cross_entropy": |
|
|
if loss_name == "cross_entropy": |
|
|
return torch.nn.CrossEntropyLoss() |
|
|
return torch.nn.CrossEntropyLoss() |
|
|
|
|
|
elif loss_name == 'nll': |
|
|
|
|
|
return torch.nn.NLLLoss() |
|
|
else: |
|
|
else: |
|
|
raise NotImplementedError |
|
|
raise NotImplementedError |