Browse Source

To Do (save commit)

tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
843b7c0e7e
3 changed files with 75 additions and 20 deletions
  1. +12
    -0
      fastNLP/loader/dataset_loader.py
  2. +35
    -20
      fastNLP/loader/preprocess.py
  3. +28
    -0
      test/test_keras_like.py

+ 12
- 0
fastNLP/loader/dataset_loader.py View File

@@ -95,3 +95,15 @@ class ConllLoader(DatasetLoader):
continue
tokens.append(line.split())
return sentences


class LMDatasetLoader(DatasetLoader):
    """Loader that reads a plain-text corpus and returns it as a flat token list."""

    def __init__(self, data_name, data_path):
        super(LMDatasetLoader, self).__init__(data_name, data_path)

    def load(self):
        """Read the file at ``self.data_path`` and split it on whitespace.

        :return: list of whitespace-separated tokens from the whole file
        :raises FileNotFoundError: if ``self.data_path`` does not exist
        """
        # Explicit check kept so callers see the same error message as before.
        if not os.path.exists(self.data_path):
            raise FileNotFoundError("file {} not found.".format(self.data_path))
        with open(self.data_path, "r", encoding="utf-8") as f:
            # split() with no argument treats newlines like any other
            # whitespace, so the file does not need to be re-joined first.
            return f.read().split()

+ 35
- 20
fastNLP/loader/preprocess.py View File

@@ -41,14 +41,24 @@ class POSPreprocess(BasePreprocess):
to label5.
"""
def __init__(self, data, pickle_path):
super(POSPreprocess, self).__init__(data, pickle_path)
self.word_dict = None
self.word_dict = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
DEFAULT_RESERVED_LABEL[2]: 4}
self.label_dict = None
self.data = data
self.pickle_path = pickle_path
self.build_dict()
self.word2id()
self.build_dict(data)
if not self.pickle_exist("word2id.pkl"):
self.word_dict.update(self.word2id(data))
file_name = os.path.join(self.pickle_path, "word2id.pkl")
with open(file_name, "wb") as f:
_pickle.dump(self.word_dict, f)
self.vocab_size = self.id2word()
self.class2id()
self.num_classes = self.id2class()
@@ -57,26 +67,26 @@ class POSPreprocess(BasePreprocess):
self.data_dev()
self.data_test()
def build_dict(self, data):
    """
    Add new words with indices into self.word_dict, new labels with indices into self.label_dict.

    self.word_dict is extended in place (it must already exist);
    self.label_dict is rebuilt from scratch on every call.

    :param data: iterable of strings, each of the form "word<TAB>label";
        after stripping, lines of length <= 1 are skipped
    """
    self.label_dict = {}
    for line in data:
        line = line.strip()
        if len(line) <= 1:
            continue
        tokens = line.split('\t')
        word, label = tokens[0], tokens[1]
        if word not in self.word_dict:
            # next free index is the current dict size
            self.word_dict[word] = len(self.word_dict)
        # Only the first label column is used; further columns are ignored.
        if label not in self.label_dict:
            self.label_dict[label] = len(self.label_dict)
def pickle_exist(self, pickle_name):
"""
@@ -174,4 +184,9 @@ class POSPreprocess(BasePreprocess):
pass
def data_test(self):
    """Placeholder: no test-set handling is implemented yet."""
    pass
class LMPreprocess(BasePreprocess):
    """Preprocessor stub for language-model data; adds nothing to the base class yet."""

    def __init__(self, data, pickle_path):
        # All set-up is delegated to the base class.
        BasePreprocess.__init__(self, data, pickle_path)

+ 28
- 0
test/test_keras_like.py View File

@@ -0,0 +1,28 @@
import aggregation
import decoder
import encoder


class Input(object):
    """Placeholder input node for the Keras-like API sketch; holds no state."""

    def __init__(self):
        """Create an empty input placeholder."""


class Trainer(object):
    """Stub trainer for the Keras-like API sketch.

    The constructor accepts its arguments but stores none of them, and
    ``train`` currently does nothing.
    """

    def __init__(self, input, target, truth):
        """Accept the training inputs, model output and ground truth (all unused)."""

    def train(self):
        """Run training; not implemented yet, returns None."""
        return None


def test_keras_like():
    """Sketch of the intended Keras-like workflow: load data, chain layers, train."""
    # NOTE(review): `dataLoader` is not defined or imported in the visible
    # part of this file - confirm where it is supposed to come from.
    data_train, label_train = dataLoader("./data_path")

    # Chain the layers: input -> LSTM encoder -> max-pool -> CRF decoder.
    placeholder = Input()
    encoded = encoder.LSTM(input=placeholder)
    pooled = aggregation.max_pool(input=encoded)
    y = decoder.CRF(input=pooled)

    Trainer(input=data_train, target=y, truth=label_train).train()

Loading…
Cancel
Save