* 更新Loss的接口形参跟metric保持一致 * 添加对几种loss的测试 * embed_loader采用维度独立的方法采样 * 对应测试代码的修改tags/v0.2.0^2
@@ -70,11 +70,11 @@ class LossBase(object): | |||
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " | |||
f"positional argument.).") | |||
def __call__(self, output_dict, target_dict, force_check=False): | |||
def __call__(self, pred_dict, target_dict, check=False): | |||
""" | |||
:param output_dict: A dict from forward function of the network. | |||
:param pred_dict: A dict from forward function of the network. | |||
:param target_dict: A dict from DataSet.batch_y. | |||
:param force_check: Boolean. Force to check the mapping functions when it is running. | |||
:param check: Boolean. Force to check the mapping functions when it is running. | |||
:return: | |||
""" | |||
args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss) | |||
@@ -88,7 +88,8 @@ class LossBase(object): | |||
raise RuntimeError( | |||
f"There is not any param in function{get_func_signature(self.get_loss)}" | |||
) | |||
self._checked = self._checked and not force_check | |||
self._checked = self._checked and not check | |||
if not self._checked: | |||
for keys in args: | |||
if keys not in param_map: | |||
@@ -105,12 +106,12 @@ class LossBase(object): | |||
duplicated = [] | |||
missing = [] | |||
if not self._checked: | |||
for keys, val in output_dict.items(): | |||
for keys, val in pred_dict.items(): | |||
if keys in target_dict.keys(): | |||
duplicated.append(keys) | |||
param_val_dict = {} | |||
for keys, val in output_dict.items(): | |||
for keys, val in pred_dict.items(): | |||
param_val_dict.update({keys: val}) | |||
for keys, val in target_dict.items(): | |||
param_val_dict.update({keys: val}) | |||
@@ -158,29 +159,31 @@ class LossFunc(LossBase): | |||
class CrossEntropyLoss(LossBase): | |||
def __init__(self, input=None, target=None): | |||
def __init__(self, pred=None, target=None): | |||
super(CrossEntropyLoss, self).__init__() | |||
self.get_loss = F.cross_entropy | |||
self._init_param_map(input=input, target=target) | |||
self._init_param_map(input=pred, target=target) | |||
class L1Loss(LossBase): | |||
def __init__(self): | |||
def __init__(self, pred=None, target=None): | |||
super(L1Loss, self).__init__() | |||
self.get_loss = F.l1_loss | |||
self._init_param_map(input=pred, target=target) | |||
class BCELoss(LossBase): | |||
def __init__(self, input=None, target=None): | |||
def __init__(self, pred=None, target=None): | |||
super(BCELoss, self).__init__() | |||
self.get_loss = F.binary_cross_entropy | |||
self._init_param_map(input=input, target=target) | |||
self._init_param_map(input=pred, target=target) | |||
class NLLLoss(LossBase): | |||
def __init__(self): | |||
def __init__(self, pred=None, target=None): | |||
super(NLLLoss, self).__init__() | |||
self.get_loss = F.nll_loss | |||
self._init_param_map(input=pred, target=target) | |||
class LossInForward(LossBase): | |||
@@ -200,9 +203,9 @@ class LossInForward(LossBase): | |||
varargs=[]) | |||
raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) | |||
def __call__(self, output_dict, predict_dict, force_check=False): | |||
def __call__(self, pred_dict, target_dict, check=False): | |||
loss = self.get_loss(**output_dict) | |||
loss = self.get_loss(**pred_dict) | |||
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | |||
if not isinstance(loss, torch.Tensor): | |||
@@ -105,9 +105,9 @@ class EmbedLoader(BaseLoader): | |||
if np.sum(hit_flags) < len(vocab): | |||
# some words from vocab are missing in pre-trained embedding | |||
# we normally sample them | |||
# we normally sample each dimension | |||
vocab_embed = embedding_matrix[np.where(hit_flags)] | |||
mean, cov = vocab_embed.mean(axis=0), np.cov(vocab_embed.T) | |||
sampled_vectors = np.random.multivariate_normal(mean, cov, size=(len(vocab) - np.sum(hit_flags),)) | |||
sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0), | |||
size=(len(vocab) - np.sum(hit_flags), emb_dim)) | |||
embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors | |||
return embedding_matrix |
@@ -271,40 +271,32 @@ class TestLoss(unittest.TestCase): | |||
loss3 = get_loss_3({'predict': predict}, {'truth': truth}) | |||
assert loss1 == loss2 and loss1 == loss3 | |||
""" | |||
get_loss_4 = LossFunc(func4) | |||
loss4 = get_loss_4({'a': 1, 'b': 3}, {}) | |||
print(loss4) | |||
assert loss4 == (1 + 3) * 2 | |||
get_loss_5 = LossFunc(func4) | |||
loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4}) | |||
print(loss5) | |||
assert loss5 == (1 + 3) * 4 | |||
get_loss_6 = LossFunc(func6) | |||
loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4}) | |||
print(loss6) | |||
assert loss6 == (1 + 3) * 4 | |||
get_loss_7 = LossFunc(func6, c='cc') | |||
loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4}) | |||
print(loss7) | |||
assert loss7 == (1 + 3) * 4 | |||
""" | |||
class TestLoss_v2(unittest.TestCase): | |||
def test_CrossEntropyLoss(self): | |||
ce = loss.CrossEntropyLoss(input="my_predict", target="my_truth") | |||
ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth") | |||
a = torch.randn(3, 5, requires_grad=False) | |||
b = torch.empty(3, dtype=torch.long).random_(5) | |||
ans = ce({"my_predict": a}, {"my_truth": b}) | |||
self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) | |||
def test_BCELoss(self): | |||
bce = loss.BCELoss(input="my_predict", target="my_truth") | |||
bce = loss.BCELoss(pred="my_predict", target="my_truth") | |||
a = torch.sigmoid(torch.randn((3, 5), requires_grad=False)) | |||
b = torch.randn((3, 5), requires_grad=False) | |||
ans = bce({"my_predict": a}, {"my_truth": b}) | |||
self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b)) | |||
def test_L1Loss(self): | |||
l1 = loss.L1Loss(pred="my_predict", target="my_truth") | |||
a = torch.randn(3, 5, requires_grad=False) | |||
b = torch.randn(3, 5) | |||
ans = l1({"my_predict": a}, {"my_truth": b}) | |||
self.assertEqual(ans, torch.nn.functional.l1_loss(a, b)) | |||
def test_NLLLoss(self): | |||
l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | |||
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | |||
b = torch.tensor([1, 0, 4]) | |||
ans = l1({"my_predict": a}, {"my_truth": b}) | |||
self.assertEqual(ans, torch.nn.functional.nll_loss(a, b)) |
@@ -32,7 +32,7 @@ class TrainerTestGround(unittest.TestCase): | |||
model = NaiveClassifier(2, 1) | |||
trainer = Trainer(train_set, model, | |||
losser=BCELoss(input="predict", target="y"), | |||
losser=BCELoss(pred="predict", target="y"), | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
n_epochs=10, | |||
batch_size=32, | |||
@@ -1,12 +1,12 @@ | |||
import unittest | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.io.embed_loader import EmbedLoader | |||
class TestEmbedLoader(unittest.TestCase): | |||
def test_case(self): | |||
vocab = Vocabulary() | |||
vocab.update(["the", "in", "I", "to", "of", "hahaha"]) | |||
# TODO: np.cov在linux上segment fault,原因未知 | |||
# embedding = EmbedLoader().fast_load_embedding(50, "../data_for_tests/glove.6B.50d_test.txt", vocab) | |||
# self.assertEqual(tuple(embedding.shape), (len(vocab), 50)) | |||
embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab) | |||
self.assertEqual(tuple(embedding.shape), (len(vocab), 50)) |
@@ -72,7 +72,7 @@ class TestTutorial(unittest.TestCase): | |||
# 实例化Trainer,传入模型和数据,进行训练 | |||
copy_model = deepcopy(model) | |||
overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data, | |||
losser=CrossEntropyLoss(input="output", target="label_seq"), | |||
losser=CrossEntropyLoss(pred="output", target="label_seq"), | |||
metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||
save_path="./save", | |||
batch_size=4, | |||
@@ -80,7 +80,7 @@ class TestTutorial(unittest.TestCase): | |||
overfit_trainer.train() | |||
trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, | |||
losser=CrossEntropyLoss(input="output", target="label_seq"), | |||
losser=CrossEntropyLoss(pred="output", target="label_seq"), | |||
metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||
save_path="./save", | |||
batch_size=4, | |||