From 982503d03329b9942ef2fb143cb6f7e8e176e65a Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Wed, 4 Jul 2018 22:56:24 +0800 Subject: [PATCH] optimize code style --- fastNLP/loader/base_preprocess.py | 35 --------------- fastNLP/loader/config_loader.py | 11 +++-- fastNLP/loader/dataset_loader.py | 1 - fastNLP/loader/preprocess.py | 73 +++++++++++++++++++++---------- fastNLP/saver/base_saver.py | 14 ++++++ fastNLP/saver/logger.py | 12 +++++ fastNLP/saver/model_saver.py | 8 ++++ 7 files changed, 88 insertions(+), 66 deletions(-) delete mode 100644 fastNLP/loader/base_preprocess.py create mode 100644 fastNLP/saver/base_saver.py create mode 100644 fastNLP/saver/logger.py create mode 100644 fastNLP/saver/model_saver.py diff --git a/fastNLP/loader/base_preprocess.py b/fastNLP/loader/base_preprocess.py deleted file mode 100644 index 806fbd18..00000000 --- a/fastNLP/loader/base_preprocess.py +++ /dev/null @@ -1,35 +0,0 @@ - - -class BasePreprocess(object): - - - def __init__(self, data, pickle_path): - super(BasePreprocess, self).__init__() - self.data = data - self.pickle_path = pickle_path - if not self.pickle_path.endswith('/'): - self.pickle_path = self.pickle_path + '/' - - def word2id(self): - raise NotImplementedError - - def id2word(self): - raise NotImplementedError - - def class2id(self): - raise NotImplementedError - - def id2class(self): - raise NotImplementedError - - def embedding(self): - raise NotImplementedError - - def data_train(self): - raise NotImplementedError - - def data_dev(self): - raise NotImplementedError - - def data_test(self): - raise NotImplementedError diff --git a/fastNLP/loader/config_loader.py b/fastNLP/loader/config_loader.py index 371de4f1..e57d9891 100644 --- a/fastNLP/loader/config_loader.py +++ b/fastNLP/loader/config_loader.py @@ -1,9 +1,8 @@ -from fastNLP.loader.base_loader import BaseLoader - import configparser -import traceback import json +from fastNLP.loader.base_loader import BaseLoader + class ConfigLoader(BaseLoader): 
"""loader for configuration files""" @@ -17,14 +16,14 @@ class ConfigLoader(BaseLoader): raise NotImplementedError @staticmethod - def loadConfig(filePath, sections): + def load_config(file_path, sections): """ - :param filePath: the path of config file + :param file_path: the path of config file :param sections: the dict of sections :return: """ cfg = configparser.ConfigParser() - cfg.read(filePath) + cfg.read(file_path) for s in sections: attr_list = [i for i in type(sections[s]).__dict__.keys() if not callable(getattr(sections[s], i)) and not i.startswith("__")] diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py index f8bcb276..7132eb3b 100644 --- a/fastNLP/loader/dataset_loader.py +++ b/fastNLP/loader/dataset_loader.py @@ -30,7 +30,6 @@ class POSDatasetLoader(DatasetLoader): return lines - class ClassificationDatasetLoader(DatasetLoader): """loader for classfication data sets""" diff --git a/fastNLP/loader/preprocess.py b/fastNLP/loader/preprocess.py index 8e880107..b8d88c35 100644 --- a/fastNLP/loader/preprocess.py +++ b/fastNLP/loader/preprocess.py @@ -1,25 +1,57 @@ -import pickle import _pickle import os -from fastNLP.loader.base_preprocess import BasePreprocess - -DEFAULT_PADDING_LABEL = '' #dict index = 0 -DEFAULT_UNKNOWN_LABEL = '' #dict index = 1 +DEFAULT_PADDING_LABEL = '' # dict index = 0 +DEFAULT_UNKNOWN_LABEL = '' # dict index = 1 DEFAULT_RESERVED_LABEL = ['', '', - ''] #dict index = 2~4 -#the first vocab in dict with the index = 5 + ''] # dict index = 2~4 + + +# the first vocab in dict with the index = 5 + + +class BasePreprocess(object): + + def __init__(self, data, pickle_path): + super(BasePreprocess, self).__init__() + self.data = data + self.pickle_path = pickle_path + if not self.pickle_path.endswith('/'): + self.pickle_path = self.pickle_path + '/' + + def word2id(self): + raise NotImplementedError + + def id2word(self): + raise NotImplementedError + + def class2id(self): + raise NotImplementedError + + def 
id2class(self): + raise NotImplementedError + def embedding(self): + raise NotImplementedError + + def data_train(self): + raise NotImplementedError + + def data_dev(self): + raise NotImplementedError + + def data_test(self): + raise NotImplementedError class POSPreprocess(BasePreprocess): """ This class are used to preprocess the pos datasets. - In these datasets, each line are divided by '\t' - while the first Col is the vocabulary and the second - Col is the label. + In these datasets, each line is divided by '\t' + The first Col is the vocabulary. + The second Col is the labels. Different sentence are divided by an empty line. e.g: Tom label1 @@ -36,7 +68,9 @@ class POSPreprocess(BasePreprocess): """ def __init__(self, data, pickle_path): - super(POSPreprocess, self).__init(data, pickle_path) + super(POSPreprocess, self).__init__(data, pickle_path) + self.word_dict = None + self.label_dict = None self.build_dict() self.word2id() self.id2word() @@ -46,8 +80,6 @@ class POSPreprocess(BasePreprocess): self.data_train() self.data_dev() self.data_test() - #... 
- def build_dict(self): self.word_dict = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, @@ -68,7 +100,6 @@ class POSPreprocess(BasePreprocess): index = len(self.label_dict) self.label_dict[label] = index - def pickle_exist(self, pickle_name): """ :param pickle_name: the filename of target pickle file @@ -82,7 +113,6 @@ class POSPreprocess(BasePreprocess): else: return False - def word2id(self): if self.pickle_exist("word2id.pkl"): return @@ -92,11 +122,10 @@ class POSPreprocess(BasePreprocess): with open(file_name, "wb", encoding='utf-8') as f: _pickle.dump(self.word_dict, f) - def id2word(self): if self.pickle_exist("id2word.pkl"): return - #nothing will be done if id2word.pkl exists + # nothing will be done if id2word.pkl exists id2word_dict = {} for word in self.word_dict: @@ -105,7 +134,6 @@ class POSPreprocess(BasePreprocess): with open(file_name, "wb", encoding='utf-8') as f: _pickle.dump(id2word_dict, f) - def class2id(self): if self.pickle_exist("class2id.pkl"): return @@ -115,11 +143,10 @@ class POSPreprocess(BasePreprocess): with open(file_name, "wb", encoding='utf-8') as f: _pickle.dump(self.label_dict, f) - def id2class(self): if self.pickle_exist("id2class.pkl"): return - #nothing will be done if id2class.pkl exists + # nothing will be done if id2class.pkl exists id2class_dict = {} for label in self.label_dict: @@ -128,17 +155,15 @@ class POSPreprocess(BasePreprocess): with open(file_name, "wb", encoding='utf-8') as f: _pickle.dump(id2class_dict, f) - def embedding(self): if self.pickle_exist("embedding.pkl"): return - #nothing will be done if embedding.pkl exists - + # nothing will be done if embedding.pkl exists def data_train(self): if self.pickle_exist("data_train.pkl"): return - #nothing will be done if data_train.pkl exists + # nothing will be done if data_train.pkl exists data_train = [] sentence = [] diff --git a/fastNLP/saver/base_saver.py b/fastNLP/saver/base_saver.py new file mode 100644 index 00000000..d721da2c --- /dev/null +++ 
b/fastNLP/saver/base_saver.py @@ -0,0 +1,14 @@ +class BaseSaver(object): + """base class for all savers""" + + def __init__(self, save_path): + self.save_path = save_path + + def save_bytes(self): + raise NotImplementedError + + def save_str(self): + raise NotImplementedError + + def compress(self): + raise NotImplementedError diff --git a/fastNLP/saver/logger.py b/fastNLP/saver/logger.py new file mode 100644 index 00000000..be38de40 --- /dev/null +++ b/fastNLP/saver/logger.py @@ -0,0 +1,12 @@ +from fastNLP.saver.base_saver import BaseSaver + + +class Logger(BaseSaver): + """Logging""" + + def __init__(self, save_path): + super(Logger, self).__init__(save_path) + + def log(self, string): + with open(self.save_path, "a") as f: + f.write(string) diff --git a/fastNLP/saver/model_saver.py b/fastNLP/saver/model_saver.py new file mode 100644 index 00000000..3b3cbeca --- /dev/null +++ b/fastNLP/saver/model_saver.py @@ -0,0 +1,8 @@ +from fastNLP.saver.base_saver import BaseSaver + + +class ModelSaver(BaseSaver): + """Save a model""" + + def __init__(self, save_path): + super(ModelSaver, self).__init__(save_path)