Browse Source

optimize code style

tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
982503d033
7 changed files with 88 additions and 66 deletions
  1. +0
    -35
      fastNLP/loader/base_preprocess.py
  2. +5
    -6
      fastNLP/loader/config_loader.py
  3. +0
    -1
      fastNLP/loader/dataset_loader.py
  4. +49
    -24
      fastNLP/loader/preprocess.py
  5. +14
    -0
      fastNLP/saver/base_saver.py
  6. +12
    -0
      fastNLP/saver/logger.py
  7. +8
    -0
      fastNLP/saver/model_saver.py

+ 0
- 35
fastNLP/loader/base_preprocess.py View File

@@ -1,35 +0,0 @@
class BasePreprocess(object):
    """Abstract base for dataset preprocessors.

    Holds the raw data plus a pickle directory path (normalized to end
    with '/'). Every conversion/serialization hook below must be
    overridden by a concrete subclass.
    """

    def __init__(self, data, pickle_path):
        super(BasePreprocess, self).__init__()
        self.data = data
        # Keep a trailing slash so later code can concatenate file names.
        self.pickle_path = pickle_path if pickle_path.endswith('/') else pickle_path + '/'

    def word2id(self):
        raise NotImplementedError

    def id2word(self):
        raise NotImplementedError

    def class2id(self):
        raise NotImplementedError

    def id2class(self):
        raise NotImplementedError

    def embedding(self):
        raise NotImplementedError

    def data_train(self):
        raise NotImplementedError

    def data_dev(self):
        raise NotImplementedError

    def data_test(self):
        raise NotImplementedError

+ 5
- 6
fastNLP/loader/config_loader.py View File

@@ -1,9 +1,8 @@
from fastNLP.loader.base_loader import BaseLoader

import configparser
import traceback
import json

from fastNLP.loader.base_loader import BaseLoader


class ConfigLoader(BaseLoader):
"""loader for configuration files"""
@@ -17,14 +16,14 @@ class ConfigLoader(BaseLoader):
raise NotImplementedError

@staticmethod
def loadConfig(filePath, sections):
def load_config(file_path, sections):
"""
:param filePath: the path of config file
:param file_path: the path of config file
:param sections: the dict of sections
:return:
"""
cfg = configparser.ConfigParser()
cfg.read(filePath)
cfg.read(file_path)
for s in sections:
attr_list = [i for i in type(sections[s]).__dict__.keys() if
not callable(getattr(sections[s], i)) and not i.startswith("__")]


+ 0
- 1
fastNLP/loader/dataset_loader.py View File

@@ -30,7 +30,6 @@ class POSDatasetLoader(DatasetLoader):
return lines



class ClassificationDatasetLoader(DatasetLoader):
"""loader for classification data sets"""



+ 49
- 24
fastNLP/loader/preprocess.py View File

@@ -1,25 +1,57 @@
import pickle
import _pickle
import os
from fastNLP.loader.base_preprocess import BasePreprocess
DEFAULT_PADDING_LABEL = '<pad>' #dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>' #dict index = 1
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1
DEFAULT_RESERVED_LABEL = ['<reserved-2>',
'<reserved-3>',
'<reserved-4>'] #dict index = 2~4
#the first vocab in dict with the index = 5
'<reserved-4>'] # dict index = 2~4
# the first vocab in dict with the index = 5
class BasePreprocess(object):
    """Common preprocessor interface.

    Stores the input data and the directory used for pickled artifacts;
    subclasses implement each of the hook methods declared here.
    """

    def __init__(self, data, pickle_path):
        super(BasePreprocess, self).__init__()
        self.data = data
        self.pickle_path = pickle_path
        # Normalize the directory path to always end with a slash.
        if not self.pickle_path.endswith('/'):
            self.pickle_path += '/'

    def word2id(self):
        raise NotImplementedError

    def id2word(self):
        raise NotImplementedError

    def class2id(self):
        raise NotImplementedError

    def id2class(self):
        raise NotImplementedError

    def embedding(self):
        raise NotImplementedError

    def data_train(self):
        raise NotImplementedError

    def data_dev(self):
        raise NotImplementedError

    def data_test(self):
        raise NotImplementedError
class POSPreprocess(BasePreprocess):
"""
This class are used to preprocess the pos datasets.
In these datasets, each line are divided by '\t'
while the first Col is the vocabulary and the second
Col is the label.
In these datasets, each line is divided by '\t'
The first Col is the vocabulary.
The second Col is the labels.
Different sentences are divided by an empty line.
e.g:
Tom label1
@@ -36,7 +68,9 @@ class POSPreprocess(BasePreprocess):
"""
def __init__(self, data, pickle_path):
super(POSPreprocess, self).__init(data, pickle_path)
super(POSPreprocess, self).__init__(data, pickle_path)
self.word_dict = None
self.label_dict = None
self.build_dict()
self.word2id()
self.id2word()
@@ -46,8 +80,6 @@ class POSPreprocess(BasePreprocess):
self.data_train()
self.data_dev()
self.data_test()
#...
def build_dict(self):
self.word_dict = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
@@ -68,7 +100,6 @@ class POSPreprocess(BasePreprocess):
index = len(self.label_dict)
self.label_dict[label] = index
def pickle_exist(self, pickle_name):
"""
:param pickle_name: the filename of target pickle file
@@ -82,7 +113,6 @@ class POSPreprocess(BasePreprocess):
else:
return False
def word2id(self):
if self.pickle_exist("word2id.pkl"):
return
@@ -92,11 +122,10 @@ class POSPreprocess(BasePreprocess):
with open(file_name, "wb", encoding='utf-8') as f:
_pickle.dump(self.word_dict, f)
def id2word(self):
if self.pickle_exist("id2word.pkl"):
return
#nothing will be done if id2word.pkl exists
# nothing will be done if id2word.pkl exists
id2word_dict = {}
for word in self.word_dict:
@@ -105,7 +134,6 @@ class POSPreprocess(BasePreprocess):
with open(file_name, "wb", encoding='utf-8') as f:
_pickle.dump(id2word_dict, f)
def class2id(self):
if self.pickle_exist("class2id.pkl"):
return
@@ -115,11 +143,10 @@ class POSPreprocess(BasePreprocess):
with open(file_name, "wb", encoding='utf-8') as f:
_pickle.dump(self.label_dict, f)
def id2class(self):
if self.pickle_exist("id2class.pkl"):
return
#nothing will be done if id2class.pkl exists
# nothing will be done if id2class.pkl exists
id2class_dict = {}
for label in self.label_dict:
@@ -128,17 +155,15 @@ class POSPreprocess(BasePreprocess):
with open(file_name, "wb", encoding='utf-8') as f:
_pickle.dump(id2class_dict, f)
def embedding(self):
if self.pickle_exist("embedding.pkl"):
return
#nothing will be done if embedding.pkl exists
# nothing will be done if embedding.pkl exists
def data_train(self):
if self.pickle_exist("data_train.pkl"):
return
#nothing will be done if data_train.pkl exists
# nothing will be done if data_train.pkl exists
data_train = []
sentence = []


+ 14
- 0
fastNLP/saver/base_saver.py View File

@@ -0,0 +1,14 @@
class BaseSaver(object):
    """Base class for all savers.

    Keeps the destination path and declares the write hooks that
    concrete savers must provide.
    """

    def __init__(self, save_path):
        # Destination on disk for whatever the subclass writes.
        self.save_path = save_path

    def save_bytes(self):
        """Persist binary content; must be overridden."""
        raise NotImplementedError

    def save_str(self):
        """Persist text content; must be overridden."""
        raise NotImplementedError

    def compress(self):
        """Compress the saved output; must be overridden."""
        raise NotImplementedError

+ 12
- 0
fastNLP/saver/logger.py View File

@@ -0,0 +1,12 @@
from saver.base_saver import BaseSaver


class Logger(BaseSaver):
    """Append-only text logger writing to ``save_path``."""

    def __init__(self, save_path):
        super(Logger, self).__init__(save_path)

    def log(self, string):
        # Append mode so successive calls accumulate log entries.
        with open(self.save_path, "a") as out_file:
            out_file.write(string)

+ 8
- 0
fastNLP/saver/model_saver.py View File

@@ -0,0 +1,8 @@
from saver.base_saver import BaseSaver


class ModelSaver(BaseSaver):
    """Saver for a trained model.

    Currently only records the target path via BaseSaver; the actual
    serialization hooks are inherited and still unimplemented.
    """

    def __init__(self, save_path):
        super(ModelSaver, self).__init__(save_path)

Loading…
Cancel
Save