@@ -10,10 +10,10 @@ class Batch(object):
         for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):
             # ...
 
-    :param dataset: a DataSet object
-    :param batch_size: int, the size of the batch
-    :param sampler: a Sampler object
-    :param as_numpy: bool. If True, return Numpy array. Otherwise, return torch tensors.
+    :param DataSet dataset: a DataSet object
+    :param int batch_size: the size of the batch
+    :param Sampler sampler: a Sampler object
+    :param bool as_numpy: If True, return NumPy arrays. Otherwise, return torch tensors.
     """
@@ -3,7 +3,9 @@ import os
 class BaseLoader(object):
+    """Base loader for all loaders.
+    """
     def __init__(self):
         super(BaseLoader, self).__init__()
@@ -32,7 +34,9 @@ class BaseLoader(object):
 class DataLoaderRegister:
-    """"register for data sets"""
+    """Register for all data sets.
+    """
     _readers = {}
     @classmethod
@@ -6,7 +6,11 @@ from fastNLP.io.base_loader import BaseLoader
 class ConfigLoader(BaseLoader):
-    """loader for configuration files"""
+    """Loader for configuration.
+
+    :param str data_path: path to the config
+    """
     def __init__(self, data_path=None):
         super(ConfigLoader, self).__init__()
@@ -19,13 +23,15 @@ class ConfigLoader(BaseLoader):
     @staticmethod
     def load_config(file_path, sections):
-        """
-        :param file_path: the path of config file
-        :param sections: the dict of {section_name(string): Section instance}
-        Example:
+        """Load section(s) of configuration into the ``sections`` provided. Returns nothing.
+
+        :param str file_path: the path of the config file
+        :param dict sections: the dict of ``{section_name(string): ConfigSection object}``
+        Example::
+
             test_args = ConfigSection()
             ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
-        :return: return nothing, but the value of attributes are saved in sessions
         """
         assert isinstance(sections, dict)
         cfg = configparser.ConfigParser()
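A small usage sketch of the in-place behavior documented above. The config path is hypothetical; the no-argument constructor follows the `__init__(self, data_path=None)` signature shown in this PR.

```python
from fastNLP.io.config_loader import ConfigLoader, ConfigSection  # assumed module path

test_args = ConfigSection()
# Fills test_args in place from the [POS_test] section; returns nothing.
ConfigLoader().load_config("./data_for_tests/config", {"POS_test": test_args})
print(test_args.__dict__)  # loaded key-value pairs land on the object as attributes
```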
@@ -60,9 +66,12 @@ class ConfigLoader(BaseLoader):
 class ConfigSection(object):
+    """ConfigSection is the data structure storing all key-value pairs in one section in a config file.
+    """
     def __init__(self):
-        pass
+        super(ConfigSection, self).__init__()
 
     def __getitem__(self, key):
         """
@@ -132,25 +141,12 @@ class ConfigSection(object):
         return self.__dict__
 
-if __name__ == "__main__":
-    config = ConfigLoader('there is no data')
-    section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()}
-    """
-    General and My can be found in config file, so the attr and
-    value will be updated
-    A cannot be found in config file, so nothing will be done
-    """
-    config.load_config("../../test/data_for_tests/config", section)
-    for s in section:
-        print(s)
-        for attr in section[s].__dict__.keys():
-            print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr)))
 
 class ConfigSaver(object):
+    """ConfigSaver is used to save config file and solve related conflicts.
+
+    :param str file_path: path to the config file
+    """
     def __init__(self, file_path):
         self.file_path = file_path
         if not os.path.exists(self.file_path):
@@ -244,9 +240,8 @@ class ConfigSaver(object):
     def save_config_file(self, section_name, section):
         """This is the function to be called to change the config file with a single section and its name.
 
-        :param section_name: The name of section what needs to be changed and saved.
-        :param section: The section with key and value what needs to be changed and saved.
-        :return:
+        :param str section_name: The name of the section to be changed and saved.
+        :param ConfigSection section: The section with keys and values to be changed and saved.
         """
         section_file = self._get_section(section_name)
         if len(section_file.__dict__.keys()) == 0:  # the section not in the file before
@@ -9,11 +9,12 @@ def convert_seq_dataset(data):
     """Create a DataSet instance that contains no labels.
 
     :param data: list of list of strings, [num_examples, *].
-        ::
-            [
-                [word_11, word_12, ...],
-                ...
-            ]
+        Example::
+
+            [
+                [word_11, word_12, ...],
+                ...
+            ]
     :return: a DataSet.
     """
@@ -24,15 +25,16 @@ def convert_seq_dataset(data):
 def convert_seq2tag_dataset(data):
-    """Convert list of data into DataSet
+    """Convert a list of data into a DataSet.
 
     :param data: list of list of strings, [num_examples, *].
-        ::
-            [
-                [ [word_11, word_12, ...], label_1 ],
-                [ [word_21, word_22, ...], label_2 ],
-                ...
-            ]
+        Example::
+
+            [
+                [ [word_11, word_12, ...], label_1 ],
+                [ [word_21, word_22, ...], label_2 ],
+                ...
+            ]
     :return: a DataSet.
     """
@@ -43,15 +45,16 @@ def convert_seq2tag_dataset(data):
 def convert_seq2seq_dataset(data):
-    """Convert list of data into DataSet
+    """Convert a list of data into a DataSet.
 
     :param data: list of list of strings, [num_examples, *].
-        ::
-            [
-                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
-                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
-                ...
-            ]
+        Example::
+
+            [
+                [ [word_11, word_12, ...], [label_11, label_12, ...] ],
+                [ [word_21, word_22, ...], [label_21, label_22, ...] ],
+                ...
+            ]
     :return: a DataSet.
     """
@@ -62,20 +65,31 @@ def convert_seq2seq_dataset(data):
 class DataSetLoader:
-    """"loader for data sets"""
+    """Interface for all DataSetLoaders.
+    """
     def load(self, path):
-        """ load data in `path` into a dataset
+        """Load data from a given file.
+
+        :param str path: file path
+        :return: a DataSet object
         """
         raise NotImplementedError
 
     def convert(self, data):
-        """convert list of data into dataset
+        """Optional operation to build a DataSet.
+
+        :param data: inner data structure (user-defined) to represent the data.
+        :return: a DataSet object
         """
         raise NotImplementedError
 
 class NativeDataSetLoader(DataSetLoader):
+    """A simple example of DataSetLoader
+    """
     def __init__(self):
         super(NativeDataSetLoader, self).__init__()
@@ -90,6 +104,9 @@ DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive')
 class RawDataSetLoader(DataSetLoader):
+    """A simple example of raw data reader
+    """
     def __init__(self):
         super(RawDataSetLoader, self).__init__()
@@ -108,37 +125,35 @@ DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')
 class POSDataSetLoader(DataSetLoader):
-    """Dataset Loader for POS Tag datasets.
-    In these datasets, each line are divided by '\t'
-    while the first Col is the vocabulary and the second
-    Col is the label.
-    Different sentence are divided by an empty line.
-    e.g:
-    Tom label1
-    and label2
-    Jerry   label1
-    .   label3
-    (separated by an empty line)
-    Hello   label4
-    world   label5
-    !   label3
-    In this file, there are two sentence "Tom and Jerry ."
-    and "Hello world !". Each word has its own label from label1
-    to label5.
+    """Dataset Loader for a POS Tag dataset.
+
+    In this dataset, each line is divided by "\t": the first column is the word and the second
+    column is its label. Different sentences are separated by an empty line.
+    E.g.::
+
+        Tom label1
+        and label2
+        Jerry   label1
+        .   label3
+        (separated by an empty line)
+        Hello   label4
+        world   label5
+        !   label3
+
+    In this example, there are two sentences, "Tom and Jerry ." and "Hello world !". Each word has its own label.
     """
     def __init__(self):
         super(POSDataSetLoader, self).__init__()
 
     def load(self, data_path):
         """
         :return data: three-level list
-            [
-                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
-                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
-                ...
-            ]
+            Example::
+
+                [
+                    [ [word_11, word_12, ...], [label_11, label_12, ...] ],
+                    [ [word_21, word_22, ...], [label_21, label_22, ...] ],
+                    ...
+                ]
         """
         with open(data_path, "r", encoding="utf-8") as f:
             lines = f.readlines()
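A standalone sketch of the file format the rewritten docstring describes. This is not the loader's actual implementation, just the documented contract: one "word\tlabel" pair per line, sentences separated by blank lines.

```python
def parse_pos_lines(lines):
    """Parse tab-separated word/label lines into [[words, labels], ...]."""
    dataset, words, labels = [], [], []
    for line in lines:
        line = line.strip()
        if not line:                      # a blank line closes the current sentence
            if words:
                dataset.append([words, labels])
                words, labels = [], []
            continue
        word, label = line.split("\t")
        words.append(word)
        labels.append(label)
    if words:                             # flush the final sentence
        dataset.append([words, labels])
    return dataset

print(parse_pos_lines(["Tom\tlabel1", "and\tlabel2", "", "Hello\tlabel4"]))
# [[['Tom', 'and'], ['label1', 'label2']], [['Hello'], ['label4']]]
```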
@@ -188,17 +203,17 @@ class TokenizeDataSetLoader(DataSetLoader):
         super(TokenizeDataSetLoader, self).__init__()
 
     def load(self, data_path, max_seq_len=32):
-        """
-        load pku dataset for Chinese word segmentation
+        """Load the pku dataset for Chinese word segmentation.
+
         CWS (Chinese Word Segmentation) pku training dataset format:
-    1. Each line is a sentence.
-    2. Each word in a sentence is separated by space.
+
+        1. Each line is a sentence.
+        2. Each word in a sentence is separated by space.
+
         This function converts the pku dataset into three-level lists with labels <BMES>.
-            B: beginning of a word
-            M: middle of a word
-            E: ending of a word
-            S: single character
+
+        B: beginning of a word
+        M: middle of a word
+        E: ending of a word
+        S: single character
+
+        :param str data_path: path to the data set.
         :param max_seq_len: int, the maximum length of a sequence. If a sequence is longer than it, split it into
             several sequences.
         :return: three-level lists
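A minimal sketch of the <BMES> scheme itself, not TokenizeDataSetLoader's code: each space-separated word yields one label per character.

```python
def bmes_labels(sentence):
    """Turn a space-segmented sentence into characters and <BMES> labels."""
    chars, labels = [], []
    for word in sentence.split():
        chars.extend(word)
        if len(word) == 1:
            labels.append("S")            # single-character word
        else:                             # begin, zero or more middles, end
            labels.extend(["B"] + ["M"] * (len(word) - 2) + ["E"])
    return chars, labels

chars, labels = bmes_labels("迈向 充满 希望 的 新 世纪")
print(labels)  # ['B', 'E', 'B', 'E', 'B', 'E', 'S', 'S', 'B', 'E']
```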
@@ -254,11 +269,9 @@ class ClassDataSetLoader(DataSetLoader):
     @staticmethod
     def parse(lines):
         """
-        Params
-            lines: lines from dataset
-        Return
-            list(list(list())): the three level of lists are
-                words, sentence, and dataset
+        :param lines: lines from the dataset
+        :return: list(list(list())): the three levels of lists are words, sentence, and dataset
         """
         dataset = list()
         for line in lines:
@@ -280,15 +293,9 @@ class ConllLoader(DataSetLoader):
     """loader for conll format files"""
 
     def __init__(self):
-        """
-        :param str data_path: the path to the conll data set
-        """
         super(ConllLoader, self).__init__()
 
     def load(self, data_path):
-        """
-        :return: list lines: all lines in a conll file
-        """
         with open(data_path, "r", encoding="utf-8") as f:
             lines = f.readlines()
         data = self.parse(lines)
@@ -320,8 +327,8 @@ class ConllLoader(DataSetLoader):
 class LMDataSetLoader(DataSetLoader):
     """Language Model Dataset Loader
 
-This loader produces data for language model training in a supervised way.
-That means it has X and Y.
+    This loader produces data for language model training in a supervised way.
+    That means it has X and Y.
     """
@@ -467,6 +474,7 @@ class Conll2003Loader(DataSetLoader):
         return dataset
 
+
 class SNLIDataSetLoader(DataSetLoader):
     """A data set loader for SNLI data set.
@@ -478,8 +486,8 @@ class SNLIDataSetLoader(DataSetLoader):
     def load(self, path_list):
         """
-        :param path_list: A list of file name, in the order of premise file, hypothesis file, and label file.
-        :return: data_set: A DataSet object.
+        :param list path_list: A list of file names, in the order of premise file, hypothesis file, and label file.
+        :return: A DataSet object.
         """
         assert len(path_list) == 3
         line_set = []
@@ -507,12 +515,14 @@ class SNLIDataSetLoader(DataSetLoader):
         """Convert a 3D list to a DataSet object.
 
         :param data: A 3D tensor.
-            [
-                [ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ],
-                [ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ],
-                ...
-            ]
-        :return: data_set: A DataSet object.
+            Example::
+
+                [
+                    [ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ],
+                    [ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ],
+                    ...
+                ]
+        :return: A DataSet object.
         """
         data_set = DataSet()
@@ -38,7 +38,7 @@ class EmbedLoader(BaseLoader):
         :param str emb_file: the pre-trained embedding file path
         :param str emb_type: the pre-trained embedding data format
-        :return dict embedding: `{str: np.array}`
+        :return: a dict of ``{str: np.array}``
         """
         if emb_type == 'glove':
             return EmbedLoader._load_glove(emb_file)
@@ -53,8 +53,9 @@ class EmbedLoader(BaseLoader):
         :param str emb_file: the pre-trained embedding file path.
         :param str emb_type: the pre-trained embedding format, support glove now
         :param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
-        :return embedding_tensor: Tensor of shape (len(word_dict), emb_dim)
-            vocab: input vocab or vocab built by pre-train
+        :return (embedding_tensor, vocab):
+            embedding_tensor - Tensor of shape (len(word_dict), emb_dim);
+            vocab - input vocab or vocab built by pre-train
         """
         pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
@@ -87,7 +88,7 @@ class EmbedLoader(BaseLoader):
         :param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding.
         :param str emb_file: the pre-trained embedding file path.
         :param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
-        :return numpy.ndarray embedding_matrix:
+        :return embedding_matrix: numpy.ndarray
         """
         if vocab is None:
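A hypothetical call illustrating the `(embedding_tensor, vocab)` return convention documented above. The method name `load_embedding`, its argument order, and the module path are assumptions from context; the diff only documents the parameters and return values.

```python
from fastNLP.io.embed_loader import EmbedLoader  # assumed module path

# Assumed method name and signature; the file path is hypothetical.
embedding_tensor, vocab = EmbedLoader.load_embedding(
    emb_dim=50, emb_file="glove.6B.50d.txt", emb_type="glove", vocab=None)
# Per the docstring: shape is (len(word_dict), emb_dim)
assert embedding_tensor.shape == (len(vocab), 50)
```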
@@ -3,15 +3,16 @@ import os
 def create_logger(logger_name, log_path, log_format=None, log_level=logging.INFO):
-    """Return a logger.
-    :param logger_name: str
-    :param log_path: str
+    """Create a logger.
+
+    :param str logger_name:
+    :param str log_path:
     :param log_format:
     :param log_level:
     :return: logger
 
-    to use a logger:
+    To use a logger::
+
         logger.debug("this is a debug message")
         logger.info("this is a info message")
         logger.warning("this is a warning message")
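Usage sketch for the signature shown above; the module path and log file name are assumptions.

```python
import logging
from fastNLP.io.logger import create_logger  # assumed module path

logger = create_logger("trainer", "./train.log", log_level=logging.DEBUG)
logger.debug("this is a debug message")
logger.info("this is a info message")
logger.warning("this is a warning message")
```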
@@ -13,10 +13,10 @@ class ModelLoader(BaseLoader):
     @staticmethod
     def load_pytorch(empty_model, model_path):
-        """
-        Load model parameters from .pkl files into the empty PyTorch model.
+        """Load model parameters from ".pkl" files into the empty PyTorch model.
+
         :param empty_model: a PyTorch model with initialized parameters.
-        :param model_path: str, the path to the saved model.
+        :param str model_path: the path to the saved model.
         """
         empty_model.load_state_dict(torch.load(model_path))
@@ -24,30 +24,30 @@ class ModelLoader(BaseLoader):
     def load_pytorch_model(model_path):
         """Load the entire model.
 
+        :param str model_path: the path to the saved model.
         """
         return torch.load(model_path)
 
 class ModelSaver(object):
     """Save a model
 
+    :param str save_path: the path to the saving directory.
+
     Example::
 
         saver = ModelSaver("./save/model_ckpt_100.pkl")
         saver.save_pytorch(model)
     """
     def __init__(self, save_path):
-        """
-        :param save_path: str, the path to the saving directory.
-        """
         self.save_path = save_path
 
     def save_pytorch(self, model, param_only=True):
-        """Save a pytorch model into .pkl file.
+        """Save a PyTorch model into a ".pkl" file.
+
         :param model: a PyTorch model
-        :param param_only: bool, whether only to save the model parameters or the entire model.
+        :param bool param_only: whether to save only the model parameters or the entire model.
         """
         if param_only is True:
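A round-trip sketch combining the two classes as documented above. The module paths are assumptions, the model is a stand-in, and `./save/` must already exist.

```python
import torch
from fastNLP.io.model_loader import ModelLoader  # assumed module paths
from fastNLP.io.model_saver import ModelSaver

model = torch.nn.Linear(4, 2)                # stand-in for any PyTorch model
saver = ModelSaver("./save/model_ckpt_100.pkl")
saver.save_pytorch(model, param_only=True)   # writes only the state dict

empty_model = torch.nn.Linear(4, 2)          # same architecture, fresh weights
ModelLoader.load_pytorch(empty_model, "./save/model_ckpt_100.pkl")
```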