Browse Source

update DataBundle and add two properties: num_dataset and num_vocab

tags/v0.4.10
Yige Xu 5 years ago
parent
commit
143bf8ed32
1 changed files with 30 additions and 8 deletions
  1. +30
    -8
      fastNLP/io/data_bundle.py

+ 30
- 8
fastNLP/io/data_bundle.py View File

@@ -10,6 +10,7 @@ from typing import Union

from ..core.dataset import DataSet
from ..core.vocabulary import Vocabulary
from ..core._logger import logger


class DataBundle:
@@ -47,13 +48,14 @@ class DataBundle:
self.vocabs[field_name] = vocab
return self

def set_dataset(self, dataset, name: str):
    """
    Store a DataSet in this DataBundle under the given name.

    :param ~fastNLP.DataSet dataset: the DataSet handed to the DataBundle
    :param str name: the name under which the dataset is stored
    :return: self
    """
    # Reject anything that is not a fastNLP DataSet before storing it.
    assert isinstance(dataset, DataSet), "Only fastNLP.DataSet supports."
    self.datasets[name] = dataset
    return self

@@ -64,7 +66,13 @@ class DataBundle:
:param str name: dataset的名称,一般为'train', 'dev', 'test'
:return: DataSet
"""
return self.datasets[name]
if name in self.datasets.keys():
return self.datasets[name]
else:
error_msg = f'DataBundle do NOT have DataSet named {name}. ' \
f'It should be one of {self.datasets.keys()}.'
logger.error(error_msg)
raise KeyError(error_msg)

def delete_dataset(self, name: str):
"""
@@ -83,7 +91,13 @@ class DataBundle:
:param str field_name: 名称
:return: Vocabulary
"""
return self.vocabs[field_name]
if field_name in self.vocabs.keys():
return self.vocabs[field_name]
else:
error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \
f'It should be one of {self.vocabs.keys()}.'
logger.error(error_msg)
raise KeyError(error_msg)

def delete_vocab(self, field_name: str):
"""
@@ -94,6 +108,14 @@ class DataBundle:
self.vocabs.pop(field_name, None)
return self

@property
def num_dataset(self):
return len(self.datasets)

@property
def num_vocab(self):
return len(self.vocabs)

def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True):
"""
将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作::
@@ -238,7 +260,7 @@ class DataBundle:
self.vocabs.pop(field_name)
return self

def iter_datasets(self)->Union[str, DataSet]:
def iter_datasets(self) -> Union[str, DataSet]:
"""
迭代data_bundle中的DataSet

@@ -252,7 +274,7 @@ class DataBundle:
for name, dataset in self.datasets.items():
yield name, dataset

def iter_vocabs(self)->Union[str, Vocabulary]:
def iter_vocabs(self) -> Union[str, Vocabulary]:
"""
迭代data_bundle中的Vocabulary

@@ -266,7 +288,7 @@ class DataBundle:
for field_name, vocab in self.vocabs.items():
yield field_name, vocab

def apply_field(self, func, field_name:str, new_field_name:str, ignore_miss_dataset=True, **kwargs):
def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs):
"""
对DataBundle中所有的dataset使用apply_field方法

@@ -313,11 +335,11 @@ class DataBundle:
def __repr__(self):
_str = ''
if len(self.datasets):
_str += 'In total {} datasets:\n'.format(len(self.datasets))
_str += 'In total {} datasets:\n'.format(self.num_dataset)
for name, dataset in self.datasets.items():
_str += '\t{} has {} instances.\n'.format(name, len(dataset))
if len(self.vocabs):
_str += 'In total {} vocabs:\n'.format(len(self.vocabs))
_str += 'In total {} vocabs:\n'.format(self.num_vocab)
for name, vocab in self.vocabs.items():
_str += '\t{} has {} entries.\n'.format(name, len(vocab))
return _str


Loading…
Cancel
Save