Browse Source

- report validation (dev) loss in trainer.train

- restructure: move the reproduction directory out of the fastNLP package
- implement evaluate in tester
tags/v0.1.0
FengZiYjun 7 years ago
parent
commit
7514be6f30
30 changed files with 29 additions and 9 deletions
  1. +27
    -9
      fastNLP/action/tester.py
  2. +2
    -0
      fastNLP/action/trainer.py
  3. +0
    -0
      fastNLP/reproduction/__init__.py
  4. +0
    -0
      reproduction/CNN-sentence_classification/.gitignore
  5. +0
    -0
      reproduction/CNN-sentence_classification/README.md
  6. +0
    -0
      reproduction/CNN-sentence_classification/__init__.py
  7. +0
    -0
      reproduction/CNN-sentence_classification/dataset.py
  8. +0
    -0
      reproduction/CNN-sentence_classification/model.py
  9. +0
    -0
      reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.neg
  10. +0
    -0
      reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.pos
  11. +0
    -0
      reproduction/CNN-sentence_classification/train.py
  12. +0
    -0
      reproduction/Char-aware_NLM/LICENSE
  13. +0
    -0
      reproduction/Char-aware_NLM/README.md
  14. +0
    -0
      reproduction/Char-aware_NLM/__init__.py
  15. +0
    -0
      reproduction/Char-aware_NLM/model.py
  16. +0
    -0
      reproduction/Char-aware_NLM/test.py
  17. +0
    -0
      reproduction/Char-aware_NLM/test.txt
  18. +0
    -0
      reproduction/Char-aware_NLM/train.py
  19. +0
    -0
      reproduction/Char-aware_NLM/train.txt
  20. +0
    -0
      reproduction/Char-aware_NLM/utilities.py
  21. +0
    -0
      reproduction/Char-aware_NLM/valid.txt
  22. +0
    -0
      reproduction/HAN-document_classification/README.md
  23. +0
    -0
      reproduction/HAN-document_classification/__init__.py
  24. +0
    -0
      reproduction/HAN-document_classification/data/test_samples.pkl
  25. +0
    -0
      reproduction/HAN-document_classification/data/train_samples.pkl
  26. +0
    -0
      reproduction/HAN-document_classification/data/yelp.word2vec
  27. +0
    -0
      reproduction/HAN-document_classification/evaluate.py
  28. +0
    -0
      reproduction/HAN-document_classification/model.py
  29. +0
    -0
      reproduction/HAN-document_classification/preprocess.py
  30. +0
    -0
      reproduction/HAN-document_classification/train.py

+ 27
- 9
fastNLP/action/tester.py View File

@@ -1,5 +1,6 @@
import _pickle

import numpy as np
import torch

from fastNLP.action.action import Action
@@ -16,8 +17,7 @@ class BaseTester(Action):
"""
super(BaseTester, self).__init__()
self.validate_in_training = test_args["validate_in_training"]
self.valid_x = None
self.valid_y = None
self.save_dev_data = None
self.save_output = test_args["save_output"]
self.output = None
self.save_loss = test_args["save_loss"]
@@ -26,8 +26,14 @@ class BaseTester(Action):
self.pickle_path = test_args["pickle_path"]
self.iterator = None

self.model = None
self.eval_history = []

def test(self, network):
# print("--------------testing----------------")
self.model = network

# turn on the testing mode; clean up the history
self.mode(network, test=True)

dev_data = self.prepare_input(self.pickle_path)
@@ -35,7 +41,6 @@ class BaseTester(Action):
self.iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))

batch_output = list()
eval_history = list()
num_iter = len(dev_data) // self.batch_size

for step in range(num_iter):
@@ -47,11 +52,18 @@ class BaseTester(Action):
if self.save_output:
batch_output.append(prediction)
if self.save_loss:
eval_history.append(eval_results)
self.eval_history.append(eval_results)

def prepare_input(self, data_path):
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
return data_dev
"""
Save the dev data once it is loaded. Can return directly next time.
:param data_path: str, the path to the pickle data for dev
:return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s).
"""
if self.save_dev_data is None:
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
self.save_dev_data = data_dev
return self.save_dev_data

def batchify(self, data):
"""
@@ -99,11 +111,12 @@ class BaseTester(Action):
raise NotImplementedError

def mode(self, model, test=True):
"""To do: combine this function with Trainer"""
"""To do: combine this function with Trainer ?? """
if test:
model.eval()
else:
model.train()
self.eval_history.clear()


class POSTester(BaseTester):
@@ -115,6 +128,7 @@ class POSTester(BaseTester):
super(POSTester, self).__init__(test_args)
self.max_len = None
self.mask = None
self.batch_result = None

def data_forward(self, network, x):
"""To Do: combine with Trainer
@@ -132,5 +146,9 @@ class POSTester(BaseTester):
return y

def evaluate(self, predict, truth):
"""To Do: """
return 0
truth = torch.Tensor(truth)
loss, prediction = self.model.loss(predict, truth, self.mask, self.batch_size, self.max_len)
return loss.data

def matrices(self):
return np.mean(self.eval_history)

+ 2
- 0
fastNLP/action/trainer.py View File

@@ -89,6 +89,7 @@ class BaseTrainer(Action):
if data_dev is None:
raise RuntimeError("No validation data provided.")
validator.test(network)
print("[epoch {}] dev loss={:.2f}".format(epoch, validator.matrices()))

# finish training

@@ -386,6 +387,7 @@ class POSTrainer(BaseTrainer):
else:
self.define_loss()
loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
# print("loss={:.2f}".format(loss.data))
return loss




+ 0
- 0
fastNLP/reproduction/__init__.py View File


fastNLP/reproduction/CNN-sentence_classification/.gitignore → reproduction/CNN-sentence_classification/.gitignore View File


fastNLP/reproduction/CNN-sentence_classification/README.md → reproduction/CNN-sentence_classification/README.md View File


fastNLP/reproduction/CNN-sentence_classification/__init__.py → reproduction/CNN-sentence_classification/__init__.py View File


fastNLP/reproduction/CNN-sentence_classification/dataset.py → reproduction/CNN-sentence_classification/dataset.py View File


fastNLP/reproduction/CNN-sentence_classification/model.py → reproduction/CNN-sentence_classification/model.py View File


fastNLP/reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.neg → reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.neg View File


fastNLP/reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.pos → reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.pos View File


fastNLP/reproduction/CNN-sentence_classification/train.py → reproduction/CNN-sentence_classification/train.py View File


fastNLP/reproduction/Char-aware_NLM/LICENSE → reproduction/Char-aware_NLM/LICENSE View File


fastNLP/reproduction/Char-aware_NLM/README.md → reproduction/Char-aware_NLM/README.md View File


fastNLP/reproduction/Char-aware_NLM/__init__.py → reproduction/Char-aware_NLM/__init__.py View File


fastNLP/reproduction/Char-aware_NLM/model.py → reproduction/Char-aware_NLM/model.py View File


fastNLP/reproduction/Char-aware_NLM/test.py → reproduction/Char-aware_NLM/test.py View File


fastNLP/reproduction/Char-aware_NLM/test.txt → reproduction/Char-aware_NLM/test.txt View File


fastNLP/reproduction/Char-aware_NLM/train.py → reproduction/Char-aware_NLM/train.py View File


fastNLP/reproduction/Char-aware_NLM/train.txt → reproduction/Char-aware_NLM/train.txt View File


fastNLP/reproduction/Char-aware_NLM/utilities.py → reproduction/Char-aware_NLM/utilities.py View File


fastNLP/reproduction/Char-aware_NLM/valid.txt → reproduction/Char-aware_NLM/valid.txt View File


fastNLP/reproduction/HAN-document_classification/README.md → reproduction/HAN-document_classification/README.md View File


fastNLP/reproduction/HAN-document_classification/__init__.py → reproduction/HAN-document_classification/__init__.py View File


fastNLP/reproduction/HAN-document_classification/data/test_samples.pkl → reproduction/HAN-document_classification/data/test_samples.pkl View File


fastNLP/reproduction/HAN-document_classification/data/train_samples.pkl → reproduction/HAN-document_classification/data/train_samples.pkl View File


fastNLP/reproduction/HAN-document_classification/data/yelp.word2vec → reproduction/HAN-document_classification/data/yelp.word2vec View File


fastNLP/reproduction/HAN-document_classification/evaluate.py → reproduction/HAN-document_classification/evaluate.py View File


fastNLP/reproduction/HAN-document_classification/model.py → reproduction/HAN-document_classification/model.py View File


fastNLP/reproduction/HAN-document_classification/preprocess.py → reproduction/HAN-document_classification/preprocess.py View File


fastNLP/reproduction/HAN-document_classification/train.py → reproduction/HAN-document_classification/train.py View File


Loading…
Cancel
Save