|
@@ -10,6 +10,7 @@ from typing import Union |
|
|
|
|
|
|
|
|
from ..core.dataset import DataSet |
|
|
from ..core.dataset import DataSet |
|
|
from ..core.vocabulary import Vocabulary |
|
|
from ..core.vocabulary import Vocabulary |
|
|
|
|
|
from ..core._logger import logger |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DataBundle: |
|
|
class DataBundle: |
|
@@ -47,13 +48,14 @@ class DataBundle: |
|
|
self.vocabs[field_name] = vocab |
|
|
self.vocabs[field_name] = vocab |
|
|
return self |
|
|
return self |
|
|
|
|
|
|
|
|
def set_dataset(self, dataset, name): |
|
|
|
|
|
|
|
|
def set_dataset(self, dataset, name: str): |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
:param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet |
|
|
:param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet |
|
|
:param str name: dataset的名称 |
|
|
:param str name: dataset的名称 |
|
|
:return: self |
|
|
:return: self |
|
|
""" |
|
|
""" |
|
|
|
|
|
assert isinstance(dataset, DataSet), "Only fastNLP.DataSet supports." |
|
|
self.datasets[name] = dataset |
|
|
self.datasets[name] = dataset |
|
|
return self |
|
|
return self |
|
|
|
|
|
|
|
@@ -64,7 +66,13 @@ class DataBundle: |
|
|
:param str name: dataset的名称,一般为'train', 'dev', 'test' |
|
|
:param str name: dataset的名称,一般为'train', 'dev', 'test' |
|
|
:return: DataSet |
|
|
:return: DataSet |
|
|
""" |
|
|
""" |
|
|
return self.datasets[name] |
|
|
|
|
|
|
|
|
if name in self.datasets.keys(): |
|
|
|
|
|
return self.datasets[name] |
|
|
|
|
|
else: |
|
|
|
|
|
error_msg = f'DataBundle do NOT have DataSet named {name}. ' \ |
|
|
|
|
|
f'It should be one of {self.datasets.keys()}.' |
|
|
|
|
|
logger.error(error_msg) |
|
|
|
|
|
raise KeyError(error_msg) |
|
|
|
|
|
|
|
|
def delete_dataset(self, name: str): |
|
|
def delete_dataset(self, name: str): |
|
|
""" |
|
|
""" |
|
@@ -83,7 +91,13 @@ class DataBundle: |
|
|
:param str field_name: 名称 |
|
|
:param str field_name: 名称 |
|
|
:return: Vocabulary |
|
|
:return: Vocabulary |
|
|
""" |
|
|
""" |
|
|
return self.vocabs[field_name] |
|
|
|
|
|
|
|
|
if field_name in self.vocabs.keys(): |
|
|
|
|
|
return self.vocabs[field_name] |
|
|
|
|
|
else: |
|
|
|
|
|
error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \ |
|
|
|
|
|
f'It should be one of {self.vocabs.keys()}.' |
|
|
|
|
|
logger.error(error_msg) |
|
|
|
|
|
raise KeyError(error_msg) |
|
|
|
|
|
|
|
|
def delete_vocab(self, field_name: str): |
|
|
def delete_vocab(self, field_name: str): |
|
|
""" |
|
|
""" |
|
@@ -94,6 +108,14 @@ class DataBundle: |
|
|
self.vocabs.pop(field_name, None) |
|
|
self.vocabs.pop(field_name, None) |
|
|
return self |
|
|
return self |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
|
|
def num_dataset(self): |
|
|
|
|
|
return len(self.datasets) |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
|
|
def num_vocab(self): |
|
|
|
|
|
return len(self.vocabs) |
|
|
|
|
|
|
|
|
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): |
|
|
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): |
|
|
""" |
|
|
""" |
|
|
将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: |
|
|
将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: |
|
@@ -238,7 +260,7 @@ class DataBundle: |
|
|
self.vocabs.pop(field_name) |
|
|
self.vocabs.pop(field_name) |
|
|
return self |
|
|
return self |
|
|
|
|
|
|
|
|
def iter_datasets(self)->Union[str, DataSet]: |
|
|
|
|
|
|
|
|
def iter_datasets(self) -> Union[str, DataSet]: |
|
|
""" |
|
|
""" |
|
|
迭代data_bundle中的DataSet |
|
|
迭代data_bundle中的DataSet |
|
|
|
|
|
|
|
@@ -252,7 +274,7 @@ class DataBundle: |
|
|
for name, dataset in self.datasets.items(): |
|
|
for name, dataset in self.datasets.items(): |
|
|
yield name, dataset |
|
|
yield name, dataset |
|
|
|
|
|
|
|
|
def iter_vocabs(self)->Union[str, Vocabulary]: |
|
|
|
|
|
|
|
|
def iter_vocabs(self) -> Union[str, Vocabulary]: |
|
|
""" |
|
|
""" |
|
|
迭代data_bundle中的DataSet |
|
|
迭代data_bundle中的DataSet |
|
|
|
|
|
|
|
@@ -266,7 +288,7 @@ class DataBundle: |
|
|
for field_name, vocab in self.vocabs.items(): |
|
|
for field_name, vocab in self.vocabs.items(): |
|
|
yield field_name, vocab |
|
|
yield field_name, vocab |
|
|
|
|
|
|
|
|
def apply_field(self, func, field_name:str, new_field_name:str, ignore_miss_dataset=True, **kwargs): |
|
|
|
|
|
|
|
|
def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs): |
|
|
""" |
|
|
""" |
|
|
对DataBundle中所有的dataset使用apply_field方法 |
|
|
对DataBundle中所有的dataset使用apply_field方法 |
|
|
|
|
|
|
|
@@ -313,11 +335,11 @@ class DataBundle: |
|
|
def __repr__(self): |
|
|
def __repr__(self): |
|
|
_str = '' |
|
|
_str = '' |
|
|
if len(self.datasets): |
|
|
if len(self.datasets): |
|
|
_str += 'In total {} datasets:\n'.format(len(self.datasets)) |
|
|
|
|
|
|
|
|
_str += 'In total {} datasets:\n'.format(self.num_dataset) |
|
|
for name, dataset in self.datasets.items(): |
|
|
for name, dataset in self.datasets.items(): |
|
|
_str += '\t{} has {} instances.\n'.format(name, len(dataset)) |
|
|
_str += '\t{} has {} instances.\n'.format(name, len(dataset)) |
|
|
if len(self.vocabs): |
|
|
if len(self.vocabs): |
|
|
_str += 'In total {} vocabs:\n'.format(len(self.vocabs)) |
|
|
|
|
|
|
|
|
_str += 'In total {} vocabs:\n'.format(self.num_vocab) |
|
|
for name, vocab in self.vocabs.items(): |
|
|
for name, vocab in self.vocabs.items(): |
|
|
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) |
|
|
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) |
|
|
return _str |
|
|
return _str |
|
|