From 3139d3fb4d8dd40f7fe4114102a79edc2901e33f Mon Sep 17 00:00:00 2001 From: yh_cc Date: Thu, 5 May 2022 20:21:20 +0800 Subject: [PATCH] add instance to mapping --- .../{utils.py => packer_unpacker.py} | 76 +++++++++++++++---- fastNLP/core/dataset/instance.py | 9 ++- 2 files changed, 70 insertions(+), 15 deletions(-) rename fastNLP/core/collators/{utils.py => packer_unpacker.py} (58%) diff --git a/fastNLP/core/collators/utils.py b/fastNLP/core/collators/packer_unpacker.py similarity index 58% rename from fastNLP/core/collators/utils.py rename to fastNLP/core/collators/packer_unpacker.py index 1a82aa23..f71b4113 100644 --- a/fastNLP/core/collators/utils.py +++ b/fastNLP/core/collators/packer_unpacker.py @@ -3,21 +3,69 @@ from functools import reduce from typing import Sequence, Mapping, Dict -def unpack_batch_mapping(batch:Sequence[Mapping], ignore_fields:set)->Dict: - """ - 将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]} +class MappingPackerUnPacker: + @staticmethod + def unpack_batch(batch:Sequence[Mapping], ignore_fields:set, input_fields:Dict)->Dict: + """ + 将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]} + + :param batch: + :param ignore_fields: + :param input_fields: + :return: + """ + dict_batch = defaultdict(list) + for sample in batch: + for key, value in sample.items(): + if key in ignore_fields: + continue + dict_batch[key].append(value) + return dict_batch + + @staticmethod + def pack_batch(batch): + return batch + + +class NestedMappingPackerUnpacker: + @staticmethod + def unpack_batch(batch:Sequence[Mapping], ignore_fields:set, input_fields:Dict): + """ + 将 nested 的 dict 中的内容展开到一个 flat dict 中 + + :param batch: + :param ignore_fields: 需要忽略的 field 。 + :param input_fields: 不需要继续往下衍射的 + :return: + """ + dict_batch = defaultdict(list) + for sample in batch: + for key, value in sample.items(): + if key in ignore_fields: + continue + if isinstance(value, Mapping) and key not in input_fields: + _dict_batch = _unpack_batch_nested_mapping(value, ignore_fields, input_fields, _parent=(key,)) + for key, value in _dict_batch.items(): + dict_batch[key].append(value) + else: + dict_batch[key].append(value) + return dict_batch - :param batch: - :param ignore_fields: - :return: - """ - dict_batch = defaultdict(list) - for sample in batch: - for key, value in sample.items(): - if key in ignore_fields: - continue - dict_batch[key].append(value) - return dict_batch + @staticmethod + def pack_batch(batch): + dicts = [] + + for key, value in batch.items(): + if not isinstance(key, tuple): + key = [key] + d = {key[-1]: value} + for k in key[:-1:][::-1]: + d = {k: d} + dicts.append(d) + return reduce(_merge_dict, dicts) + + +class def unpack_batch_nested_mapping(batch:Sequence[Mapping], ignore_fields:set, stop_deep_fields:set)->Dict: diff --git a/fastNLP/core/dataset/instance.py b/fastNLP/core/dataset/instance.py index db3d4be7..74ed3c9f 100644 --- a/fastNLP/core/dataset/instance.py +++ b/fastNLP/core/dataset/instance.py @@ -8,10 +8,11 @@ __all__ = [ "Instance" ] +from typing import Mapping from fastNLP.core.utils.utils import pretty_table_printer -class Instance: +class Instance(Mapping): r""" Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。 Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示:: @@ -69,3 +70,9 @@ class Instance: def __repr__(self): return str(pretty_table_printer(self)) + + def __len__(self): + return len(self.fields) + + def __iter__(self): + return iter(self.fields)