Browse Source

tags/v0.2.0^2
yh 6 years ago
parent
commit
87e5d44b01
3 changed files with 19 additions and 2 deletions
  1. +2
    -0
      fastNLP/core/losses.py
  2. +7
    -0
      test/core/test_dataset.py
  3. +10
    -2
      test/core/test_loss.py

+ 2
- 0
fastNLP/core/losses.py View File

@@ -169,6 +169,8 @@ class LossFunc(LossBase):

class CrossEntropyLoss(LossBase):
def __init__(self, pred=None, target=None):
# TODO 需要做一些检查,F.cross_entropy在计算时,如果pred是(16, 10 ,4), target的形状按道理应该是(16, 10), 但实际却需要
# TODO (16, 4)
super(CrossEntropyLoss, self).__init__()
self.get_loss = F.cross_entropy
self._init_param_map(input=pred, target=target)


+ 7
- 0
test/core/test_dataset.py View File

@@ -125,6 +125,13 @@ class TestDataSet(unittest.TestCase):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target])

def test_apply2(self):
def split_sent(ins):
return ins['raw_sentence'].split()
dataset = DataSet.read_csv('../../sentence.csv', headers=('raw_sentence', 'label'), sep='\t')
dataset.apply(split_sent, new_field_name='words')
# print(dataset)


class TestDataSetIter(unittest.TestCase):
def test__repr__(self):


+ 10
- 2
test/core/test_loss.py View File

@@ -311,9 +311,17 @@ class TestLosserError(unittest.TestCase):
print(los(pred_dict=pred_dict, target_dict=target_dict))

#
def test_AccuracyMetric2(self):
def test_losser2(self):
# (2) with corrupted size
pred_dict = {"pred": torch.zeros(16, 3, 4)}
pred_dict = {"pred": torch.zeros(16, 3)}
target_dict = {'target': torch.zeros(16, 3).long()}
los = loss.CrossEntropyLoss()

print(los(pred_dict=pred_dict, target_dict=target_dict))

def test_losser3(self):
# (2) with corrupted size
pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param':0}
target_dict = {'target': torch.zeros(16, 3).long()}
los = loss.CrossEntropyLoss()



Loading…
Cancel
Save