Browse Source

- optimize package calling from test files

- add people.txt in data_for_tests
- To do: incorrect CRF param in POS_pipeline
tags/v0.1.0
FengZiYjun 7 years ago
parent
commit
cca276b8c0
5 changed files with 111 additions and 13 deletions
  1. +28
    -7
      fastNLP/action/trainer.py
  2. +1
    -1
      fastNLP/loader/dataset_loader.py
  3. +7
    -2
      fastNLP/models/sequencce_modeling.py
  4. +67
    -0
      test/data_for_tests/people.txt
  5. +8
    -3
      test/test_POS_pipeline.py

+ 28
- 7
fastNLP/action/trainer.py View File

@@ -31,12 +31,13 @@ class BaseTrainer(Action):
super(BaseTrainer, self).__init__()
self.train_args = train_args
self.n_epochs = train_args.epochs
self.validate = train_args.validate
# self.validate = train_args.validate
self.batch_size = train_args.batch_size
self.pickle_path = train_args.pickle_path
self.model = None
self.iterator = None
self.loss_func = None
self.optimizer = None

def train(self, network):
"""General training loop.
@@ -316,6 +317,8 @@ class WordSegTrainer(BaseTrainer):


class POSTrainer(BaseTrainer):
TrainConfig = namedtuple("config", ["epochs", "batch_size", "pickle_path", "num_classes", "vocab_size"])

def __init__(self, train_args):
super(POSTrainer, self).__init__(train_args)
self.vocab_size = train_args.vocab_size
@@ -328,9 +331,9 @@ class POSTrainer(BaseTrainer):
"""
To do: Load pkl files of train/dev/test and embedding
"""
data_train = _pickle.load(open(data_path + "data_train.pkl", "rb"))
data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb"))
return data_train, data_dev
data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
return data_train, data_dev, 0, 1

def data_forward(self, network, x):
seq_len = [len(seq) for seq in x]
@@ -342,10 +345,28 @@ class POSTrainer(BaseTrainer):
self.batch_x = x
return x

def mode(self, test=False):
    """Toggle the wrapped model between evaluation and training behavior.

    :param test: True selects ``model.eval()`` (inference mode);
        False (default) selects ``model.train()``.
    """
    (self.model.eval if test else self.model.train)()

def define_optimizer(self):
    """Create the optimizer used by the training loop.

    Plain SGD with momentum; the hyper-parameters are fixed here rather
    than taken from train_args.
    """
    learning_rate = 0.01
    momentum = 0.9
    self.optimizer = torch.optim.SGD(
        self.model.parameters(), lr=learning_rate, momentum=momentum
    )

def get_loss(self, predict, truth):
# NOTE(review): this span is diff residue — it interleaves the REMOVED
# implementation (the three lines up to and including the first
# `return loss`) with the ADDED implementation (the docstring plus the
# lazy loss_func resolution below). Run as-is, everything after the
# first return would be unreachable; in the actual commit only the
# second half exists.
truth = torch.LongTensor(truth)
loss, prediction = self.loss_func(self.batch_x, predict, self.mask, self.batch_size, self.max_len)
return loss
"""
Compute loss given prediction and ground truth.
:param predict: prediction label vector
:param truth: ground truth label vector
:return: a scalar
"""
# Lazily resolve the loss function on first use: prefer a `loss`
# method provided by the model itself, otherwise fall back to the
# trainer's define_loss().
if self.loss_func is None:
if hasattr(self.model, "loss"):
self.loss_func = self.model.loss
else:
self.define_loss()
# Presumably self.mask/self.batch_x/self.max_len were set by
# data_forward before this is called — TODO confirm against trainer loop.
return self.loss_func(self.batch_x, predict, self.mask, self.batch_size, self.max_len)


# BUG FIX: the guard compared __name__ against the literal string
# "__name__", which is never true, so the branch could never run.
# The standard idiom compares against "__main__".
if __name__ == "__main__":
    pass  # no script-level behavior defined in this view of the module

+ 1
- 1
fastNLP/loader/dataset_loader.py View File

@@ -23,7 +23,7 @@ class POSDatasetLoader(DatasetLoader):
return line

def load_lines(self):
    """Read the dataset file and return its raw lines.

    :return: list of newline-terminated strings read from ``self.data_path``
    """
    # Fail fast if the configured path is missing. (The diff residue
    # showed this assert duplicated — once bare, once parenthesized —
    # which are identical; keep a single copy.)
    assert os.path.exists(self.data_path)
    with open(self.data_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    return lines


+ 7
- 2
fastNLP/models/sequencce_modeling.py View File

@@ -58,8 +58,8 @@ class SeqLabeling(BaseModel):

x = self.embedding(x)
x, hidden = self.encode(x)
x = self.aggregation(x)
x = self.output(x)
x = self.aggregate(x)
x = self.decode(x)
return x

def embedding(self, x):
@@ -84,6 +84,11 @@ class SeqLabeling(BaseModel):
:return loss:
prediction:
"""
x = x.float()
y = y.long()
mask = mask.byte()
print(x.shape, y.shape, mask.shape)

if self.use_crf:
total_loss = self.crf(x, y, mask)
tag_seq = self.crf.viterbi_decode(x, mask)


+ 67
- 0
test/data_for_tests/people.txt View File

@@ -0,0 +1,67 @@
迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w

中 B-nt
共 M-nt
中 M-nt
央 E-nt
总 B-n
书 M-n
记 E-n
、 S-w
国 B-n
家 E-n
主 B-n
席 E-n
江 B-nr
泽 M-nr
民 E-nr

( S-w
一 B-t
九 M-t
九 M-t
七 M-t
年 E-t
十 B-t
二 M-t
月 E-t
三 B-t
十 M-t
一 M-t
日 E-t
) S-w

1 B-t
2 M-t
月 E-t
3 B-t
1 M-t
日 E-t
, S-w

+ 8
- 3
test/test_POS_pipeline.py View File

@@ -1,11 +1,15 @@
import sys

sys.path.append("..")

from fastNLP.action.trainer import POSTrainer
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.preprocess import POSPreprocess
from fastNLP.models.sequencce_modeling import SeqLabeling

# Fixture locations for the POS pipeline test. The stale pre-rename
# assignments ("data/people.txt" etc., left over from the diff) are
# removed: they were immediately overwritten and only the
# data_for_tests values are intended by this commit.
data_name = "people.txt"
data_path = "data_for_tests/people.txt"
pickle_path = "data_for_tests"

if __name__ == "__main__":
# Data Loader
@@ -27,3 +31,4 @@ if __name__ == "__main__":

# Start training.
trainer.train(model)


Loading…
Cancel
Save