diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 5e80a6fb..ec64d484 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -531,11 +531,11 @@ class DataSet(object):
             | pad_value   | 0     |       |
             +-------------+-------+-------+
 
-        :param field_names: name of the field in the DataSet
-        :param is_input: whether the field is an input
-        :param is_target: whether the field is a target
-        :param ignore_type: whether to ignore the type of this field; usually only meaningful when the field is at least an input or a target
-        :param pad_value: the padding value of this field; only meaningful when the field is an input or a target
+        str field_names: name of the field in the DataSet
+        bool is_input: whether the field is an input
+        bool is_target: whether the field is a target
+        bool ignore_type: whether to ignore the type of this field; usually only meaningful when the field is at least an input or a target
+        int pad_value: the padding value of this field; only meaningful when the field is an input or a target
         :return:
         """
         if len(self.field_arrays)>0:
@@ -1146,3 +1146,40 @@ class DataSet(object):
 
     def _collate_batch(self, ins_list):
         return self.collater.collate_batch(ins_list)
+
+    def concat(self, dataset, inplace=True, field_mapping=None):
+        """
+        Concatenate the current dataset with the given dataset into one larger dataset; both datasets must contain the
+        same fields. The input, target and collate_fn settings of the combined dataset follow the current dataset. Extra
+        fields in the given dataset are ignored; if it is missing any field of the current dataset, an error is raised.
+
+        :param DataSet dataset: the dataset to concatenate with the current dataset
+        :param bool inplace: whether to merge the given dataset directly into the current dataset
+        :param dict field_mapping: when field names in the given dataset differ from those of the current dataset, use
+            field_mapping to rename them; keys are field names in the given dataset, values are the names to map them to
+
+        :return: DataSet
+        """
+        assert isinstance(dataset, DataSet), "Can only concat two datasets."
+
+        fns_in_this_dataset = set(self.get_field_names())
+        fns_in_other_dataset = dataset.get_field_names()
+        reverse_field_mapping = {}
+        if field_mapping is not None:
+            fns_in_other_dataset = [field_mapping.get(fn, fn) for fn in fns_in_other_dataset]
+            reverse_field_mapping = {v: k for k, v in field_mapping.items()}
+        fns_in_other_dataset = set(fns_in_other_dataset)
+        fn_not_seen = list(fns_in_this_dataset - fns_in_other_dataset)
+
+        if fn_not_seen:
+            raise RuntimeError(f"The following fields are not provided in the dataset: {fn_not_seen}")
+
+        if inplace:
+            ds = self
+        else:
+            ds = deepcopy(self)
+
+        for fn in fns_in_this_dataset:
+            ds.get_field(fn).content.extend(deepcopy(dataset.get_field(reverse_field_mapping.get(fn, fn)).content))
+
+        return ds
diff --git a/tests/core/test_dataset.py b/tests/core/test_dataset.py
index 94dd3bdb..d0d08d97 100644
--- a/tests/core/test_dataset.py
+++ b/tests/core/test_dataset.py
@@ -268,6 +268,57 @@ class TestDataSetMethods(unittest.TestCase):
         with self.assertRaises(RuntimeError) as RE:
             ds.add_field('test', [])
 
+    def test_concat(self):
+        """
+        Test that two datasets can be concatenated correctly.
+
+        """
+        ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
+        ds2 = DataSet({"x": [[4, 3, 2, 1] for i in range(10)], "y": [[6, 5] for i in range(10)]})
+        ds3 = ds1.concat(ds2)
+
+        self.assertEqual(len(ds3), 20)
+
+        self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4])
+        self.assertListEqual(ds1[10]['x'], [4, 3, 2, 1])
+
+        ds2[0]['x'][0] = 100
+        self.assertEqual(ds3[10]['x'][0], 4)  # the copied field in ds3 is not affected
+
+        ds3[10]['x'][0] = -100
+        self.assertEqual(ds2[0]['x'][0], 100)  # the source field in ds2 is not affected
+
+        # test inplace
+        ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
+        ds2 = DataSet({"x": [[4, 3, 2, 1] for i in range(10)], "y": [[6, 5] for i in range(10)]})
+        ds3 = ds1.concat(ds2, inplace=True)
+
+        ds2[0]['x'][0] = 100
+        self.assertEqual(ds3[10]['x'][0], 4)  # the copied field in ds3 is not affected
+
+        ds3[10]['x'][0] = -100
+        self.assertEqual(ds2[0]['x'][0], 100)  # the source field in ds2 is not affected
+
+        ds3[0]['x'][0] = 100
+        self.assertEqual(ds1[0]['x'][0], 100)  # ds1 is changed, since the concat was done in place
+
+        # test field_mapping
+        ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
+        ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]})
+        ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'})
+        self.assertEqual(len(ds3), 20)
+
+        # test that extra fields are ignored
+        ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
+        ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)], 'Z': [0] * 10})
+        ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'})
+
+        # test that a missing field raises an error
+        ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
+        ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]})
+        with self.assertRaises(RuntimeError):
+            ds3 = ds1.concat(ds2, field_mapping={'X': 'x'})
+
 
 class TestDataSetIter(unittest.TestCase):
     def test__repr__(self):
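For quick reference, a minimal usage sketch of the new DataSet.concat follows. It is not part of the patch: the field names, the sample data, and the `from fastNLP import DataSet` import path are illustrative assumptions; the behaviour shown follows the docstring and the tests above.

# Minimal usage sketch, assuming this patch is applied.
from fastNLP import DataSet

train = DataSet({"x": [[1, 2], [3, 4]], "y": [[0], [1]]})
extra = DataSet({"X": [[5, 6]], "Y": [[1]]})

# Rename the extra dataset's fields to match `train` and return a new DataSet
# instead of extending `train` in place; the appended content is deep-copied,
# so later edits to `extra` do not leak into the result.
merged = train.concat(extra, inplace=False, field_mapping={"X": "x", "Y": "y"})
print(len(merged))  # 3

Note that inplace defaults to True, so a bare ds1.concat(ds2) extends ds1 and returns it, which is what the tests above rely on.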