@@ -169,6 +169,8 @@ class LossFunc(LossBase): | |||||
class CrossEntropyLoss(LossBase): | class CrossEntropyLoss(LossBase): | ||||
def __init__(self, pred=None, target=None): | def __init__(self, pred=None, target=None): | ||||
# TODO 需要做一些检查,F.cross_entropy在计算时,如果pred是(16, 10 ,4), target的形状按道理应该是(16, 10), 但实际却需要 | |||||
# TODO (16, 4) | |||||
super(CrossEntropyLoss, self).__init__() | super(CrossEntropyLoss, self).__init__() | ||||
self.get_loss = F.cross_entropy | self.get_loss = F.cross_entropy | ||||
self._init_param_map(input=pred, target=target) | self._init_param_map(input=pred, target=target) | ||||
@@ -125,6 +125,13 @@ class TestDataSet(unittest.TestCase): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ||||
self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target]) | self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target]) | ||||
def test_apply2(self): | |||||
def split_sent(ins): | |||||
return ins['raw_sentence'].split() | |||||
dataset = DataSet.read_csv('../../sentence.csv', headers=('raw_sentence', 'label'), sep='\t') | |||||
dataset.apply(split_sent, new_field_name='words') | |||||
# print(dataset) | |||||
class TestDataSetIter(unittest.TestCase): | class TestDataSetIter(unittest.TestCase): | ||||
def test__repr__(self): | def test__repr__(self): | ||||
@@ -311,9 +311,17 @@ class TestLosserError(unittest.TestCase): | |||||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | print(los(pred_dict=pred_dict, target_dict=target_dict)) | ||||
# | # | ||||
def test_AccuracyMetric2(self): | |||||
def test_losser2(self): | |||||
# (2) with corrupted size | # (2) with corrupted size | ||||
pred_dict = {"pred": torch.zeros(16, 3, 4)} | |||||
pred_dict = {"pred": torch.zeros(16, 3)} | |||||
target_dict = {'target': torch.zeros(16, 3).long()} | |||||
los = loss.CrossEntropyLoss() | |||||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | |||||
def test_losser3(self): | |||||
# (2) with corrupted size | |||||
pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param':0} | |||||
target_dict = {'target': torch.zeros(16, 3).long()} | target_dict = {'target': torch.zeros(16, 3).long()} | ||||
los = loss.CrossEntropyLoss() | los = loss.CrossEntropyLoss() | ||||