Browse Source

add instance to mapping

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
3139d3fb4d
2 changed files with 70 additions and 15 deletions
  1. +62
    -14
      fastNLP/core/collators/packer_unpacker.py
  2. +8
    -1
      fastNLP/core/dataset/instance.py

fastNLP/core/collators/utils.py → fastNLP/core/collators/packer_unpacker.py View File

@@ -3,21 +3,69 @@ from functools import reduce
from typing import Sequence, Mapping, Dict from typing import Sequence, Mapping, Dict




def unpack_batch_mapping(batch:Sequence[Mapping], ignore_fields:set)->Dict:
"""
将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]}
class MappingPackerUnPacker:
    """Regroup a batch of flat mappings into one dict of per-field lists."""

    @staticmethod
    def unpack_batch(batch: Sequence[Mapping], ignore_fields: set, input_fields: Dict) -> Dict:
        """
        Turn a ``Sequence[Mapping]`` into a ``Dict`` of lists. For example
        ``[{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}]`` becomes
        ``{'a': [[1, 2], [3]], 'b': [1, 2]}``.

        :param batch: the samples to regroup; each sample maps a field name to a value.
        :param ignore_fields: field names that are left out of the result.
        :param input_fields: not read by this method.
        :return: a ``defaultdict(list)`` mapping each kept field name to the list of
            its values, in the order the samples appear in ``batch``.
        """
        dict_batch = defaultdict(list)
        for sample in batch:
            for key, value in sample.items():
                if key not in ignore_fields:
                    dict_batch[key].append(value)
        return dict_batch

    @staticmethod
    def pack_batch(batch):
        """Return ``batch`` unchanged; a flat mapping batch needs no re-packing."""
        return batch


class NestedMappingPackerUnpacker:
    """Flatten a batch of (possibly nested) mappings into one dict of per-field lists."""

    @staticmethod
    def unpack_batch(batch: Sequence[Mapping], ignore_fields: set, input_fields: Dict):
        """
        Expand the contents of nested dicts into a single flat dict of lists.

        :param batch: the samples to regroup; each sample maps a field name to a value,
            and a value may itself be a mapping.
        :param ignore_fields: fields that should be ignored (left out of the result).
        :param input_fields: fields that must not be expanded any further; a mapping
            value stored under one of these keys is appended as-is.
        :return: a ``defaultdict(list)``. Non-expanded fields are keyed by their own
            name; expanded fields use whatever keys ``_unpack_batch_nested_mapping``
            produces (presumably tuples spelling the path from the top-level key --
            confirm against that helper).
        """
        dict_batch = defaultdict(list)
        for sample in batch:
            for key, value in sample.items():
                if key in ignore_fields:
                    continue
                if isinstance(value, Mapping) and key not in input_fields:
                    flattened = _unpack_batch_nested_mapping(value, ignore_fields, input_fields, _parent=(key,))
                    # Distinct names here so the outer loop's key/value are not rebound.
                    for flat_key, flat_value in flattened.items():
                        dict_batch[flat_key].append(flat_value)
                else:
                    dict_batch[key].append(value)
        return dict_batch


:param batch:
:param ignore_fields:
:return:
"""
dict_batch = defaultdict(list)
for sample in batch:
for key, value in sample.items():
if key in ignore_fields:
continue
dict_batch[key].append(value)
return dict_batch
@staticmethod
def pack_batch(batch):
dicts = []

for key, value in batch.items():
if not isinstance(key, tuple):
key = [key]
d = {key[-1]: value}
for k in key[:-1:][::-1]:
d = {k: d}
dicts.append(d)
return reduce(_merge_dict, dicts)


class




def unpack_batch_nested_mapping(batch:Sequence[Mapping], ignore_fields:set, stop_deep_fields:set)->Dict: def unpack_batch_nested_mapping(batch:Sequence[Mapping], ignore_fields:set, stop_deep_fields:set)->Dict:

+ 8
- 1
fastNLP/core/dataset/instance.py View File

@@ -8,10 +8,11 @@ __all__ = [
"Instance" "Instance"
] ]


from typing import Mapping
from fastNLP.core.utils.utils import pretty_table_printer from fastNLP.core.utils.utils import pretty_table_printer




class Instance:
class Instance(Mapping):
r""" r"""
Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。 Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。
Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示:: Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示::
@@ -69,3 +70,9 @@ class Instance:


def __repr__(self): def __repr__(self):
return str(pretty_table_printer(self)) return str(pretty_table_printer(self))

def __len__(self):
    """Return the number of entries in ``self.fields``."""
    return len(self.fields)

def __iter__(self):
    """Return an iterator over ``self.fields``."""
    return iter(self.fields)

Loading…
Cancel
Save