diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index c1e8de0e..58847c31 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -169,6 +169,8 @@ class LossFunc(LossBase):
 
 class CrossEntropyLoss(LossBase):
     def __init__(self, pred=None, target=None):
+        # TODO Some checks are needed here: when F.cross_entropy computes the loss, if pred is (16, 10, 4), the target shape should by rights be (16, 10), but it actually has to be
+        # TODO (16, 4)
         super(CrossEntropyLoss, self).__init__()
         self.get_loss = F.cross_entropy
         self._init_param_map(input=pred, target=target)
diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py
index 8ca2ed86..697bcd78 100644
--- a/test/core/test_dataset.py
+++ b/test/core/test_dataset.py
@@ -125,6 +125,13 @@ class TestDataSet(unittest.TestCase):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
         self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target])
 
+    def test_apply2(self):
+        def split_sent(ins):
+            return ins['raw_sentence'].split()
+        dataset = DataSet.read_csv('../../sentence.csv', headers=('raw_sentence', 'label'), sep='\t')
+        dataset.apply(split_sent, new_field_name='words')
+        # print(dataset)
+
 
 class TestDataSetIter(unittest.TestCase):
     def test__repr__(self):
diff --git a/test/core/test_loss.py b/test/core/test_loss.py
index 53b889c6..270b4d3b 100644
--- a/test/core/test_loss.py
+++ b/test/core/test_loss.py
@@ -311,9 +311,17 @@ class TestLosserError(unittest.TestCase):
         print(los(pred_dict=pred_dict, target_dict=target_dict))
 
     # 
-    def test_AccuracyMetric2(self):
+    def test_losser2(self):
         # (2) with corrupted size
-        pred_dict = {"pred": torch.zeros(16, 3, 4)}
+        pred_dict = {"pred": torch.zeros(16, 3)}
+        target_dict = {'target': torch.zeros(16, 3).long()}
+        los = loss.CrossEntropyLoss()
+
+        print(los(pred_dict=pred_dict, target_dict=target_dict))
+
+    def test_losser3(self):
+        # (2) with corrupted size
+        pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param':0}
         target_dict = {'target': torch.zeros(16, 3).long()}
         los = loss.CrossEntropyLoss()
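
Note on the shape question raised in the losses.py TODO above: F.cross_entropy treats dimension 1 of its input as the class dimension, so a pred of shape (16, 10, 4) is read as batch=16, num_classes=10, extra dim=4, and the matching target is (16, 4). NLP models usually emit logits as (batch, seq_len, num_classes), which then needs a transpose before the call. A minimal sketch of that contract follows; it is illustration only, not part of the patch, and the variable names are made up:

    import torch
    import torch.nn.functional as F

    batch, seq_len, num_classes = 16, 10, 4

    # NLP-style logits: (batch, seq_len, num_classes) = (16, 10, 4)
    pred = torch.randn(batch, seq_len, num_classes)
    # One class index per token: (batch, seq_len) = (16, 10)
    target = torch.randint(num_classes, (batch, seq_len))

    # Passing pred as-is would make F.cross_entropy expect a (16, 4) target,
    # which is the surprise noted in the TODO. Moving the class dimension to
    # index 1 gives input (16, 4, 10), which pairs with the (16, 10) target.
    loss = F.cross_entropy(pred.transpose(1, 2), target)
    print(loss)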