Browse Source

在dataset中添加concat函数,支持将两个dataset concat起来

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
50d7cfdd7a
2 changed files with 93 additions and 5 deletions
  1. +42
    -5
      fastNLP/core/dataset.py
  2. +51
    -0
      tests/core/test_dataset.py

+ 42
- 5
fastNLP/core/dataset.py View File

@@ -531,11 +531,11 @@ class DataSet(object):
| pad_value | 0 | |
+-------------+-------+-------+

:param field_names: DataSet中field的名称
:param is_input: field是否为input
:param is_target: field是否为target
:param ignore_type: 是否忽略该field的type, 一般仅在该field至少为input或target时才有意义
:param pad_value: 该field的pad的值,仅在该field为input或target时有意义
str field_names: DataSet中field的名称
bool is_input: field是否为input
bool is_target: field是否为target
bool ignore_type: 是否忽略该field的type, 一般仅在该field至少为input或target时才有意义
int pad_value: 该field的pad的值,仅在该field为input或target时有意义
:return:
"""
if len(self.field_arrays)>0:
@@ -1146,3 +1146,40 @@ class DataSet(object):

def _collate_batch(self, ins_list):
return self.collater.collate_batch(ins_list)

def concat(self, dataset, inplace=True, field_mapping=None):
"""
将当前dataset与输入的dataset结合成一个更大的dataset,需要保证两个dataset都包含了相同的field。结合后的dataset的input,target
以及collate_fn以当前dataset为准。当dataset中包含的field多于当前的dataset,则多余的field会被忽略;若dataset中未包含所有
当前dataset含有field,则会报错。

:param DataSet, dataset: 需要和当前dataset concat的dataset
:param bool, inplace: 是否直接将dataset组合到当前dataset中
:param dict, field_mapping: 当dataset中的field名称和当前dataset不一致时,需要通过field_mapping把输入的dataset中的field
名称映射到当前field. field_mapping为dict类型,key为dataset中的field名称,value是需要映射成的名称

:return: DataSet
"""
assert isinstance(dataset, DataSet), "Can only concat two datasets."

fns_in_this_dataset = set(self.get_field_names())
fns_in_other_dataset = dataset.get_field_names()
reverse_field_mapping = {}
if field_mapping is not None:
fns_in_other_dataset = [field_mapping.get(fn, fn) for fn in fns_in_other_dataset]
reverse_field_mapping = {v:k for k, v in field_mapping.items()}
fns_in_other_dataset = set(fns_in_other_dataset)
fn_not_seen = list(fns_in_this_dataset - fns_in_other_dataset)

if fn_not_seen:
raise RuntimeError(f"The following fields are not provided in the dataset:{fn_not_seen}")

if inplace:
ds = self
else:
ds = deepcopy(self)

for fn in fns_in_this_dataset:
ds.get_field(fn).content.extend(deepcopy(dataset.get_field(reverse_field_mapping.get(fn, fn)).content))

return ds

+ 51
- 0
tests/core/test_dataset.py View File

@@ -268,6 +268,57 @@ class TestDataSetMethods(unittest.TestCase):
with self.assertRaises(RuntimeError) as RE:
ds.add_field('test', [])

def test_concat(self):
"""
测试两个dataset能否正确concat

"""
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
ds2 = DataSet({"x": [[4,3,2,1] for i in range(10)], "y": [[6,5] for i in range(10)]})
ds3 = ds1.concat(ds2)

self.assertEqual(len(ds3), 20)

self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4])
self.assertListEqual(ds1[10]['x'], [4,3,2,1])

ds2[0]['x'][0] = 100
self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了

ds3[10]['x'][0] = -100
self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了

# 测试inplace
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
ds2 = DataSet({"x": [[4, 3, 2, 1] for i in range(10)], "y": [[6, 5] for i in range(10)]})
ds3 = ds1.concat(ds2, inplace=True)

ds2[0]['x'][0] = 100
self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了

ds3[10]['x'][0] = -100
self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了

ds3[0]['x'][0] = 100
self.assertEqual(ds1[0]['x'][0], 100) # 改变copy前的field了

# 测试mapping
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]})
ds3 = ds1.concat(ds2, field_mapping={'X':'x', 'Y':'y'})
self.assertEqual(len(ds3), 20)

# 测试忽略掉多余的
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)], 'Z':[0]*10})
ds3 = ds1.concat(ds2, field_mapping={'X':'x', 'Y':'y'})

# 测试报错
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]})
with self.assertRaises(RuntimeError):
ds3 = ds1.concat(ds2, field_mapping={'X':'x'})


class TestDataSetIter(unittest.TestCase):
def test__repr__(self):


Loading…
Cancel
Save