From cca276b8c09add219bbbcaa8cbf78d786358cea3 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Sat, 7 Jul 2018 16:57:57 +0800 Subject: [PATCH] - optimize package calling from test files - add people.txt in data_for_tests - To do: incorrect CRF param in POS_pipeline --- fastNLP/action/trainer.py | 35 ++++++++++++--- fastNLP/loader/dataset_loader.py | 2 +- fastNLP/models/sequencce_modeling.py | 9 +++- test/data_for_tests/people.txt | 67 ++++++++++++++++++++++++++++ test/test_POS_pipeline.py | 11 +++-- 5 files changed, 111 insertions(+), 13 deletions(-) create mode 100644 test/data_for_tests/people.txt diff --git a/fastNLP/action/trainer.py b/fastNLP/action/trainer.py index ac7138e5..94a704f9 100644 --- a/fastNLP/action/trainer.py +++ b/fastNLP/action/trainer.py @@ -31,12 +31,13 @@ class BaseTrainer(Action): super(BaseTrainer, self).__init__() self.train_args = train_args self.n_epochs = train_args.epochs - self.validate = train_args.validate + # self.validate = train_args.validate self.batch_size = train_args.batch_size self.pickle_path = train_args.pickle_path self.model = None self.iterator = None self.loss_func = None + self.optimizer = None def train(self, network): """General training loop. 
@@ -316,6 +317,8 @@ class WordSegTrainer(BaseTrainer): class POSTrainer(BaseTrainer): + TrainConfig = namedtuple("config", ["epochs", "batch_size", "pickle_path", "num_classes", "vocab_size"]) + def __init__(self, train_args): super(POSTrainer, self).__init__(train_args) self.vocab_size = train_args.vocab_size @@ -328,9 +331,9 @@ class POSTrainer(BaseTrainer): """ To do: Load pkl files of train/dev/test and embedding """ - data_train = _pickle.load(open(data_path + "data_train.pkl", "rb")) - data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb")) - return data_train, data_dev + data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb")) + data_dev = _pickle.load(open(data_path + "/data_dev.pkl", "rb")) + return data_train, data_dev, 0, 1 @@ -342,10 +345,28 @@ class POSTrainer(BaseTrainer): self.batch_x = x return x + def mode(self, test=False): + if test: + self.model.eval() + else: + self.model.train() + + def define_optimizer(self): + self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9) + def get_loss(self, predict, truth): - truth = torch.LongTensor(truth) - loss, prediction = self.loss_func(self.batch_x, predict, self.mask, self.batch_size, self.max_len) - return loss + """ + Compute loss given prediction and ground truth. 
+ :param predict: prediction label vector + :param truth: ground truth label vector + :return: a scalar + """ + if self.loss_func is None: + if hasattr(self.model, "loss"): + self.loss_func = self.model.loss + else: + self.define_loss() + return self.loss_func(self.batch_x, predict, self.mask, self.batch_size, self.max_len) if __name__ == "__name__": diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py index 284be715..d57a48db 100644 --- a/fastNLP/loader/dataset_loader.py +++ b/fastNLP/loader/dataset_loader.py @@ -23,7 +23,7 @@ class POSDatasetLoader(DatasetLoader): return line def load_lines(self): - assert os.path.exists(self.data_path) + assert (os.path.exists(self.data_path)) with open(self.data_path, "r", encoding="utf-8") as f: lines = f.readlines() return lines diff --git a/fastNLP/models/sequencce_modeling.py b/fastNLP/models/sequencce_modeling.py index af6931e4..ba96d4b6 100644 --- a/fastNLP/models/sequencce_modeling.py +++ b/fastNLP/models/sequencce_modeling.py @@ -58,8 +58,8 @@ class SeqLabeling(BaseModel): x = self.embedding(x) x, hidden = self.encode(x) - x = self.aggregation(x) - x = self.output(x) + x = self.aggregate(x) + x = self.decode(x) return x def embedding(self, x): @@ -84,6 +84,11 @@ class SeqLabeling(BaseModel): :return loss: prediction: """ + x = x.float() + y = y.long() + mask = mask.byte() + print(x.shape, y.shape, mask.shape) + if self.use_crf: total_loss = self.crf(x, y, mask) tag_seq = self.crf.viterbi_decode(x, mask) diff --git a/test/data_for_tests/people.txt b/test/data_for_tests/people.txt new file mode 100644 index 00000000..f34c85cb --- /dev/null +++ b/test/data_for_tests/people.txt @@ -0,0 +1,67 @@ +迈 B-v +向 E-v +充 B-v +满 E-v +希 B-n +望 E-n +的 S-u +新 S-a +世 B-n +纪 E-n +— B-w +— E-w +一 B-t +九 M-t +九 M-t +八 M-t +年 E-t +新 B-t +年 E-t +讲 B-n +话 E-n +( S-w +附 S-v +图 B-n +片 E-n +1 S-m +张 S-q +) S-w + +中 B-nt +共 M-nt +中 M-nt +央 E-nt +总 B-n +书 M-n +记 E-n +、 S-w +国 B-n +家 E-n +主 B-n +席 E-n +江 B-nr +泽 M-nr +民 
E-nr + +( S-w +一 B-t +九 M-t +九 M-t +七 M-t +年 E-t +十 B-t +二 M-t +月 E-t +三 B-t +十 M-t +一 M-t +日 E-t +) S-w + +1 B-t +2 M-t +月 E-t +3 B-t +1 M-t +日 E-t +, S-w \ No newline at end of file diff --git a/test/test_POS_pipeline.py b/test/test_POS_pipeline.py index db4232e7..66e418c6 100644 --- a/test/test_POS_pipeline.py +++ b/test/test_POS_pipeline.py @@ -1,11 +1,15 @@ +import sys + +sys.path.append("..") + from fastNLP.action.trainer import POSTrainer from fastNLP.loader.dataset_loader import POSDatasetLoader from fastNLP.loader.preprocess import POSPreprocess from fastNLP.models.sequencce_modeling import SeqLabeling -data_name = "people" -data_path = "data/people.txt" -pickle_path = "data" +data_name = "people.txt" +data_path = "data_for_tests/people.txt" +pickle_path = "data_for_tests" if __name__ == "__main__": # Data Loader @@ -27,3 +31,4 @@ if __name__ == "__main__": # Start training. trainer.train(model) +