diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index 1f05cf68..ba275e61 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -10,6 +10,7 @@ from typing import Union from ..core.dataset import DataSet from ..core.vocabulary import Vocabulary +from ..core._logger import logger class DataBundle: @@ -47,13 +48,14 @@ class DataBundle: self.vocabs[field_name] = vocab return self - def set_dataset(self, dataset, name): + def set_dataset(self, dataset, name: str): """ :param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet :param str name: dataset的名称 :return: self """ + assert isinstance(dataset, DataSet), "Only fastNLP.DataSet supports." self.datasets[name] = dataset return self @@ -64,7 +66,13 @@ class DataBundle: :param str name: dataset的名称,一般为'train', 'dev', 'test' :return: DataSet """ - return self.datasets[name] + if name in self.datasets.keys(): + return self.datasets[name] + else: + error_msg = f'DataBundle do NOT have DataSet named {name}. ' \ + f'It should be one of {self.datasets.keys()}.' + logger.error(error_msg) + raise KeyError(error_msg) def delete_dataset(self, name: str): """ @@ -83,7 +91,13 @@ class DataBundle: :param str field_name: 名称 :return: Vocabulary """ - return self.vocabs[field_name] + if field_name in self.vocabs.keys(): + return self.vocabs[field_name] + else: + error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \ + f'It should be one of {self.vocabs.keys()}.' + logger.error(error_msg) + raise KeyError(error_msg) def delete_vocab(self, field_name: str): """ @@ -94,6 +108,14 @@ class DataBundle: self.vocabs.pop(field_name, None) return self + @property + def num_dataset(self): + return len(self.datasets) + + @property + def num_vocab(self): + return len(self.vocabs) + def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): """ 将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: @@ -238,7 +260,7 @@ class DataBundle: self.vocabs.pop(field_name) return self - def iter_datasets(self)->Union[str, DataSet]: + def iter_datasets(self) -> Union[str, DataSet]: """ 迭代data_bundle中的DataSet @@ -252,7 +274,7 @@ class DataBundle: for name, dataset in self.datasets.items(): yield name, dataset - def iter_vocabs(self)->Union[str, Vocabulary]: + def iter_vocabs(self) -> Union[str, Vocabulary]: """ 迭代data_bundle中的DataSet @@ -266,7 +288,7 @@ class DataBundle: for field_name, vocab in self.vocabs.items(): yield field_name, vocab - def apply_field(self, func, field_name:str, new_field_name:str, ignore_miss_dataset=True, **kwargs): + def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs): """ 对DataBundle中所有的dataset使用apply_field方法 @@ -313,11 +335,11 @@ class DataBundle: def __repr__(self): _str = '' if len(self.datasets): - _str += 'In total {} datasets:\n'.format(len(self.datasets)) + _str += 'In total {} datasets:\n'.format(self.num_dataset) for name, dataset in self.datasets.items(): _str += '\t{} has {} instances.\n'.format(name, len(dataset)) if len(self.vocabs): - _str += 'In total {} vocabs:\n'.format(len(self.vocabs)) + _str += 'In total {} vocabs:\n'.format(self.num_vocab) for name, vocab in self.vocabs.items(): _str += '\t{} has {} entries.\n'.format(name, len(vocab)) return _str