|
|
@@ -3,21 +3,69 @@ from functools import reduce |
|
|
|
from typing import Sequence, Mapping, Dict |
|
|
|
|
|
|
|
|
|
|
|
def unpack_batch_mapping(batch:Sequence[Mapping], ignore_fields:set)->Dict: |
|
|
|
""" |
|
|
|
将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]} |
|
|
|
class MappingPackerUnPacker:
    """Convert a batch of flat mappings to a per-field dict and back."""

    @staticmethod
    def unpack_batch(batch: Sequence[Mapping], ignore_fields: set, input_fields: Dict) -> Dict:
        """
        Turn a ``Sequence[Mapping]`` into a dict of per-field lists, e.g.
        ``[{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}]`` ->
        ``{'a': [[1, 2], [3]], 'b': [1, 2]}``.

        :param batch: the samples to regroup; each sample is a mapping from field name to value.
        :param ignore_fields: field names that are dropped from the result.
        :param input_fields: not used by this implementation; accepted so the signature
            matches ``NestedMappingPackerUnpacker.unpack_batch``.
        :return: a ``defaultdict(list)`` mapping each kept field name to the list of its
            values, in sample order.
        """
        dict_batch = defaultdict(list)
        for sample in batch:
            for key, value in sample.items():
                if key in ignore_fields:
                    continue
                dict_batch[key].append(value)
        return dict_batch

    @staticmethod
    def pack_batch(batch):
        """
        Return ``batch`` unchanged; a flat dict needs no re-nesting.

        :param batch: the per-field dict to pack.
        :return: the very same object that was passed in.
        """
        return batch
|
|
|
|
|
|
|
|
|
|
|
class NestedMappingPackerUnpacker:
    """Convert a batch of (possibly nested) mappings to a flat per-field dict and back."""

    @staticmethod
    def unpack_batch(batch: Sequence[Mapping], ignore_fields: set, input_fields: Dict):
        """
        Expand the content of nested dicts into one flat dict of per-field lists.

        :param batch: the samples to regroup; each sample is a mapping from field name to value.
        :param ignore_fields: fields that should be ignored (dropped from the result).
        :param input_fields: top-level fields whose mapping values must NOT be expanded any
            further; they are collected as-is.
        :return: a ``defaultdict(list)`` mapping each key to the list of collected values,
            in sample order.
        """
        dict_batch = defaultdict(list)
        for sample in batch:
            for key, value in sample.items():
                if key in ignore_fields:
                    continue
                if isinstance(value, Mapping) and key not in input_fields:
                    # `_unpack_batch_nested_mapping` is defined elsewhere in this file;
                    # presumably it returns a flat dict keyed by the tuple path starting
                    # at `_parent` -- confirm against its definition.
                    flattened = _unpack_batch_nested_mapping(
                        value, ignore_fields, input_fields, _parent=(key,)
                    )
                    # Distinct loop names so the outer `key`/`value` are not rebound.
                    for flat_key, flat_value in flattened.items():
                        dict_batch[flat_key].append(flat_value)
                else:
                    dict_batch[key].append(value)
        return dict_batch

    @staticmethod
    def pack_batch(batch):
        """
        Rebuild a nested dict from a flat one whose keys may be tuples.

        A tuple key is treated as a path: ``('a', 'b')`` with value ``v`` becomes
        ``{'a': {'b': v}}``. Any other key is treated as a one-element path.

        :param batch: flat dict to re-nest.
        :return: the per-key nested dicts folded together with ``_merge_dict``; an empty
            dict when ``batch`` is empty.
        """
        dicts = []
        for key, value in batch.items():
            path = key if isinstance(key, tuple) else (key,)
            nested = {path[-1]: value}
            # Wrap from the innermost key outwards.
            for part in reversed(path[:-1]):
                nested = {part: nested}
            dicts.append(nested)
        if not dicts:
            # reduce() without an initial value raises TypeError on an empty sequence.
            return {}
        # `_merge_dict` is defined elsewhere in this file; presumably it merges two
        # nested dicts -- confirm against its definition.
        return reduce(_merge_dict, dicts)
|
|
|
|
|
|
|
|
|
|
|
class |
|
|
|
|
|
|
|
|
|
|
|
def unpack_batch_nested_mapping(batch:Sequence[Mapping], ignore_fields:set, stop_deep_fields:set)->Dict: |