@@ -531,11 +531,11 @@ class DataSet(object): | |||
| pad_value | 0 | | | |||
+-------------+-------+-------+ | |||
:param field_names: DataSet中field的名称 | |||
:param is_input: field是否为input | |||
:param is_target: field是否为target | |||
:param ignore_type: 是否忽略该field的type, 一般仅在该field至少为input或target时才有意义 | |||
:param pad_value: 该field的pad的值,仅在该field为input或target时有意义 | |||
str field_names: DataSet中field的名称 | |||
bool is_input: field是否为input | |||
bool is_target: field是否为target | |||
bool ignore_type: 是否忽略该field的type, 一般仅在该field至少为input或target时才有意义 | |||
int pad_value: 该field的pad的值,仅在该field为input或target时有意义 | |||
:return: | |||
""" | |||
if len(self.field_arrays)>0: | |||
@@ -1146,3 +1146,40 @@ class DataSet(object): | |||
def _collate_batch(self, ins_list): | |||
return self.collater.collate_batch(ins_list) | |||
def concat(self, dataset, inplace=True, field_mapping=None): | |||
""" | |||
将当前dataset与输入的dataset结合成一个更大的dataset,需要保证两个dataset都包含了相同的field。结合后的dataset的input,target | |||
以及collate_fn以当前dataset为准。当dataset中包含的field多于当前的dataset,则多余的field会被忽略;若dataset中未包含所有 | |||
当前dataset含有field,则会报错。 | |||
:param DataSet, dataset: 需要和当前dataset concat的dataset | |||
:param bool, inplace: 是否直接将dataset组合到当前dataset中 | |||
:param dict, field_mapping: 当dataset中的field名称和当前dataset不一致时,需要通过field_mapping把输入的dataset中的field | |||
名称映射到当前field. field_mapping为dict类型,key为dataset中的field名称,value是需要映射成的名称 | |||
:return: DataSet | |||
""" | |||
assert isinstance(dataset, DataSet), "Can only concat two datasets." | |||
fns_in_this_dataset = set(self.get_field_names()) | |||
fns_in_other_dataset = dataset.get_field_names() | |||
reverse_field_mapping = {} | |||
if field_mapping is not None: | |||
fns_in_other_dataset = [field_mapping.get(fn, fn) for fn in fns_in_other_dataset] | |||
reverse_field_mapping = {v:k for k, v in field_mapping.items()} | |||
fns_in_other_dataset = set(fns_in_other_dataset) | |||
fn_not_seen = list(fns_in_this_dataset - fns_in_other_dataset) | |||
if fn_not_seen: | |||
raise RuntimeError(f"The following fields are not provided in the dataset:{fn_not_seen}") | |||
if inplace: | |||
ds = self | |||
else: | |||
ds = deepcopy(self) | |||
for fn in fns_in_this_dataset: | |||
ds.get_field(fn).content.extend(deepcopy(dataset.get_field(reverse_field_mapping.get(fn, fn)).content)) | |||
return ds |
@@ -268,6 +268,57 @@ class TestDataSetMethods(unittest.TestCase): | |||
with self.assertRaises(RuntimeError) as RE: | |||
ds.add_field('test', []) | |||
def test_concat(self): | |||
""" | |||
测试两个dataset能否正确concat | |||
""" | |||
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) | |||
ds2 = DataSet({"x": [[4,3,2,1] for i in range(10)], "y": [[6,5] for i in range(10)]}) | |||
ds3 = ds1.concat(ds2) | |||
self.assertEqual(len(ds3), 20) | |||
self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4]) | |||
self.assertListEqual(ds1[10]['x'], [4,3,2,1]) | |||
ds2[0]['x'][0] = 100 | |||
self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了 | |||
ds3[10]['x'][0] = -100 | |||
self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了 | |||
# 测试inplace | |||
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) | |||
ds2 = DataSet({"x": [[4, 3, 2, 1] for i in range(10)], "y": [[6, 5] for i in range(10)]}) | |||
ds3 = ds1.concat(ds2, inplace=True) | |||
ds2[0]['x'][0] = 100 | |||
self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了 | |||
ds3[10]['x'][0] = -100 | |||
self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了 | |||
ds3[0]['x'][0] = 100 | |||
self.assertEqual(ds1[0]['x'][0], 100) # 改变copy前的field了 | |||
# 测试mapping | |||
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) | |||
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]}) | |||
ds3 = ds1.concat(ds2, field_mapping={'X':'x', 'Y':'y'}) | |||
self.assertEqual(len(ds3), 20) | |||
# 测试忽略掉多余的 | |||
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) | |||
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)], 'Z':[0]*10}) | |||
ds3 = ds1.concat(ds2, field_mapping={'X':'x', 'Y':'y'}) | |||
# 测试报错 | |||
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) | |||
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]}) | |||
with self.assertRaises(RuntimeError): | |||
ds3 = ds1.concat(ds2, field_mapping={'X':'x'}) | |||
class TestDataSetIter(unittest.TestCase): | |||
def test__repr__(self): | |||