@@ -1,87 +1,154 @@
-from collections import namedtuple
+import _pickle
+
 import numpy as np
+import torch
 
 from fastNLP.action.action import Action
+from fastNLP.action.action import RandomSampler, Batchifier
+from fastNLP.modules.utils import seq_mask
 
 
-class Tester(Action):
+class BaseTester(Action):
     """docstring for Tester"""
-    TestConfig = namedtuple("config", ["validate_in_training", "save_dev_input", "save_output",
-                                       "save_loss", "batch_size"])
 
     def __init__(self, test_args):
         """
-        :param test_args: named tuple
+        :param test_args: a dict with keys "validate_in_training", "save_output",
+            "save_loss", "batch_size" and "pickle_path"
         """
-        super(Tester, self).__init__()
-        self.validate_in_training = test_args.validate_in_training
-        self.save_dev_input = test_args.save_dev_input
-        self.valid_x = None
-        self.valid_y = None
-        self.save_output = test_args.save_output
+        super(BaseTester, self).__init__()
+        self.validate_in_training = test_args["validate_in_training"]
+        self.save_dev_data = None
+        self.save_output = test_args["save_output"]
         self.output = None
-        self.save_loss = test_args.save_loss
+        self.save_loss = test_args["save_loss"]
         self.mean_loss = None
-        self.batch_size = test_args.batch_size
-
-    def test(self, network, data):
-        print("testing")
-        network.mode(test=True)  # turn on the testing mode
-        if self.save_dev_input:
-            if self.valid_x is None:
-                valid_x, valid_y = network.prepare_input(data)
-                self.valid_x = valid_x
-                self.valid_y = valid_y
-            else:
-                valid_x = self.valid_x
-                valid_y = self.valid_y
-        else:
-            valid_x, valid_y = network.prepare_input(data)
-
-        # split into batches by self.batch_size
-        iterations, test_batch_generator = self.batchify(self.batch_size, valid_x, valid_y)
-
-        batch_output = list()
-        loss_history = list()
-        # turn on the testing mode of the network
-        network.mode(test=True)
-
-        for step in range(iterations):
-            batch_x, batch_y = test_batch_generator.__next__()
-
-            # forward pass from test input to predicted output
-            prediction = network.data_forward(batch_x)
-
-            loss = network.get_loss(prediction, batch_y)
+        self.batch_size = test_args["batch_size"]
+        self.pickle_path = test_args["pickle_path"]
+        self.iterator = None
+
+        self.model = None
+        self.eval_history = []
+
+    def test(self, network):
+        # print("--------------testing----------------")
+        self.model = network
+
+        # turn on the testing mode; clean up the history
+        self.mode(network, test=True)
+
+        dev_data = self.prepare_input(self.pickle_path)
+
+        self.iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))
+
+        batch_output = list()
+        num_iter = len(dev_data) // self.batch_size
+
+        for step in range(num_iter):
+            batch_x, batch_y = self.batchify(dev_data)
+
+            prediction = self.data_forward(network, batch_x)
+            eval_results = self.evaluate(prediction, batch_y)
 
             if self.save_output:
-                batch_output.append(prediction.data)
+                batch_output.append(prediction)
             if self.save_loss:
-                loss_history.append(loss)
-                self.log(self.make_log(step, loss))
-
-        if self.save_loss:
-            self.mean_loss = np.mean(np.array(loss_history))
-        if self.save_output:
-            self.output = self.make_output(batch_output)
-
-    @property
-    def loss(self):
-        return self.mean_loss
-
-    @property
-    def result(self):
-        return self.output
+                self.eval_history.append(eval_results)
+
+    def prepare_input(self, data_path):
+        """
+        Save the dev data once it is loaded. Can return directly next time.
+        :param data_path: str, the path to the pickle data for dev
+        :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s).
+        """
+        if self.save_dev_data is None:
+            # NOTE: reuses the training pickle as a stand-in dev set for now
+            data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
+            self.save_dev_data = data_dev
+        return self.save_dev_data
+
+    def batchify(self, data):
+        """
+        1. Perform batching over data and produce a batch of test data.
+        2. Add padding.
+        :param data: list. Each entry is a sample, which is also a list of features and label(s).
+            E.g.
+                [
+                    [[word_11, word_12, word_13], [label_11, label_12]],  # sample 1
+                    [[word_21, word_22, word_23], [label_21, label_22]],  # sample 2
+                    ...
+                ]
+        :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
+                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
+        """
+        indices = next(self.iterator)
+        batch = [data[idx] for idx in indices]
+        batch_x = [sample[0] for sample in batch]
+        batch_y = [sample[1] for sample in batch]
+        batch_x = self.pad(batch_x)
+        return batch_x, batch_y
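
A quick illustration of the batchify contract above (standalone sketch, not part of the patch; the Batchifier/RandomSampler iterator is stubbed with fixed indices):

data = [
    [[4, 7, 9], [1, 0, 0]],   # sample 0: word ids, label ids
    [[2, 5], [0, 1]],         # sample 1
    [[8, 3, 6], [1, 1, 0]],   # sample 2
]
indices = [2, 0]              # what next(self.iterator) might yield
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]  # [[8, 3, 6], [4, 7, 9]]
batch_y = [sample[1] for sample in batch]  # [[1, 1, 0], [1, 0, 0]]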
     @staticmethod
-    def make_output(batch_outputs):
-        # construct full prediction with batch outputs
-        return np.concatenate(batch_outputs, axis=0)
+    def pad(batch, fill=0):
+        """
+        Pad a batch of samples to maximum length.
+        :param batch: list of list
+        :param fill: word index to pad with, default 0.
+        :return: a padded batch
+        """
+        max_length = max([len(x) for x in batch])
+        for idx, sample in enumerate(batch):
+            if len(sample) < max_length:
+                batch[idx] = sample + [fill] * (max_length - len(sample))
+        return batch
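
And the padding step, shown on a ragged batch (illustration only, with the default fill=0):

batch = [[4, 7, 9], [2, 5]]
max_length = max(len(x) for x in batch)            # 3
padded = [s + [0] * (max_length - len(s)) for s in batch]
assert padded == [[4, 7, 9], [2, 5, 0]]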
-    def load_config(self, args):
+    def data_forward(self, network, data):
         raise NotImplementedError
 
-    def load_dataset(self, args):
+    def evaluate(self, predict, truth):
         raise NotImplementedError
 
+    def matrices(self):
+        raise NotImplementedError
+
+    def mode(self, model, test=True):
+        """To do: combine this function with Trainer ?? """
+        if test:
+            model.eval()
+        else:
+            model.train()
+        self.eval_history.clear()
+
+
+class POSTester(BaseTester):
+    """
+    Tester for sequence labeling.
+    """
+
+    def __init__(self, test_args):
+        super(POSTester, self).__init__(test_args)
+        self.max_len = None
+        self.mask = None
+        self.batch_result = None
+
+    def data_forward(self, network, x):
+        """To do: combine with Trainer
+        :param network: the PyTorch model
+        :param x: list of list, [batch_size, max_len]
+        :return y: [batch_size, max_len, num_classes]
+        """
+        seq_len = [len(seq) for seq in x]
+        x = torch.Tensor(x).long()
+        self.batch_size = x.size(0)
+        self.max_len = x.size(1)
+        self.mask = seq_mask(seq_len, self.max_len)
+        y = network(x)
+        return y
+
+    def evaluate(self, predict, truth):
+        truth = torch.Tensor(truth)
+        loss, prediction = self.model.loss(predict, truth, self.mask, self.batch_size, self.max_len)
+        return loss.data
+
+    def matrices(self):
+        return np.mean(self.eval_history)
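
Putting the new dict-based configuration together, a POSTester would be driven roughly like this (a sketch, not part of the patch; it assumes `model` is a SeqLabeling-style network exposing a .loss() method and that the dev pickle lives under "data_for_tests"):

from fastNLP.action.tester import POSTester

test_args = {"validate_in_training": True, "save_output": True, "save_loss": True,
             "batch_size": 8, "pickle_path": "data_for_tests"}
tester = POSTester(test_args)
tester.test(model)                       # runs batching, forward pass and evaluate()
print("dev loss:", tester.matrices())    # mean of the accumulated eval history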
@@ -1,12 +1,12 @@
 import _pickle
-from collections import namedtuple
 
 import numpy as np
 import torch
 
 from fastNLP.action.action import Action
 from fastNLP.action.action import RandomSampler, Batchifier
-from fastNLP.action.tester import Tester
+from fastNLP.action.tester import POSTester
+from fastNLP.modules.utils import seq_mask
 
 
 class BaseTrainer(Action):
@@ -21,23 +21,29 @@ class BaseTrainer(Action):
         - grad_backward
         - get_loss
     """
-    TrainConfig = namedtuple("config", ["epochs", "validate", "batch_size", "pickle_path"])
 
     def __init__(self, train_args):
         """
-        training parameters
+        :param train_args: dict of (key, value) pairs.
+            The base trainer requires the following keys:
+            - epochs: int, the number of epochs in training
+            - validate: bool, whether or not to validate on the dev set
+            - batch_size: int
+            - pickle_path: str, the path to pickle files for pre-processing
         """
         super(BaseTrainer, self).__init__()
-        self.n_epochs = train_args.epochs
-        self.validate = train_args.validate
-        self.batch_size = train_args.batch_size
-        self.pickle_path = train_args.pickle_path
+        self.n_epochs = train_args["epochs"]
+        self.validate = train_args["validate"]
+        self.batch_size = train_args["batch_size"]
+        self.pickle_path = train_args["pickle_path"]
         self.model = None
         self.iterator = None
         self.loss_func = None
+        self.optimizer = None
 
     def train(self, network):
-        """General training loop.
+        """General Training Steps
         :param network: a model
 
         The method is framework independent.
@@ -51,22 +57,27 @@ class BaseTrainer(Action):
         - update
         Subclasses must implement these methods with a specific framework.
         """
+        # prepare model and data
         self.model = network
         data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path)
-        test_args = Tester.TestConfig(save_output=True, validate_in_training=True,
-                                      save_dev_input=True, save_loss=True, batch_size=self.batch_size)
-        evaluator = Tester(test_args)
 
-        best_loss = 1e10
+        # define a tester over dev data
+        valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
+                      "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path}
+        validator = POSTester(valid_args)
 
+        # main training epochs
         iterations = len(data_train) // self.batch_size
 
         for epoch in range(self.n_epochs):
+            # turn on network training mode; define optimizer; prepare batch iterator
             self.mode(test=False)
             self.define_optimizer()
+            self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True))
 
+            # training iterations in one epoch
             for step in range(iterations):
-                batch_x, batch_y = self.batchify(self.batch_size, data_train)
+                batch_x, batch_y = self.batchify(data_train)
 
                 prediction = self.data_forward(network, batch_x)
 
@@ -77,9 +88,8 @@ class BaseTrainer(Action):
             if self.validate:
                 if data_dev is None:
                     raise RuntimeError("No validation data provided.")
-                evaluator.test(network, data_dev)
-                if evaluator.loss < best_loss:
-                    best_loss = evaluator.loss
+                validator.test(network)
+                print("[epoch {}] dev loss={:.2f}".format(epoch, validator.matrices()))
 
         # finish training
@@ -155,23 +165,20 @@ class BaseTrainer(Action):
         """
         raise NotImplementedError
 
-    def batchify(self, batch_size, data):
+    def batchify(self, data):
         """
         1. Perform batching from data and produce a batch of training data.
         2. Add padding.
-        :param batch_size: int, the size of a batch
         :param data: list. Each entry is a sample, which is also a list of features and label(s).
             E.g.
                 [
-                    [[feature_1, feature_2, feature_3], [label_1. label_2]],  # sample 1
-                    [[feature_1, feature_2, feature_3], [label_1. label_2]],  # sample 2
+                    [[word_11, word_12, word_13], [label_11, label_12]],  # sample 1
+                    [[word_21, word_22, word_23], [label_21, label_22]],  # sample 2
                     ...
                 ]
-        :return batch_x: list. Each entry is a list of features of a sample.
-                batch_y: list. Each entry is a list of labels of a sample.
+        :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
+                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
         """
-        if self.iterator is None:
-            self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True))
         indices = next(self.iterator)
         batch = [data[idx] for idx in indices]
         batch_x = [sample[0] for sample in batch]
@@ -195,7 +202,9 @@ class BaseTrainer(Action):
 
 class ToyTrainer(BaseTrainer):
-    """A simple trainer for a PyTorch model."""
+    """
+        deprecated
+    """
 
     def __init__(self, train_args):
         super(ToyTrainer, self).__init__(train_args)
 
@@ -230,7 +239,7 @@ class ToyTrainer(BaseTrainer):
 
 class WordSegTrainer(BaseTrainer):
     """
-        reserve for changes
+        deprecated
     """
 
     def __init__(self, train_args):
 
@@ -301,6 +310,7 @@ class WordSegTrainer(BaseTrainer):
         self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.85)
 
     def get_loss(self, predict, truth):
+        truth = torch.Tensor(truth)
         self._loss = torch.nn.CrossEntropyLoss()(predict, truth)
         return self._loss
 
@@ -313,8 +323,76 @@ class WordSegTrainer(BaseTrainer):
         self.optimizer.step()
+
+
+class POSTrainer(BaseTrainer):
+    """
+    Trainer for sequence modeling.
+    """
+
+    def __init__(self, train_args):
+        super(POSTrainer, self).__init__(train_args)
+        self.vocab_size = train_args["vocab_size"]
+        self.num_classes = train_args["num_classes"]
+        self.max_len = None
+        self.mask = None
+
+    def prepare_input(self, data_path):
+        """
+        To do: load pkl files of train/dev/test and embedding
+        """
+        data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
+        # NOTE: the training pickle doubles as a stand-in dev set for now
+        data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
+        return data_train, data_dev, 0, 1
+
+    def data_forward(self, network, x):
+        """
+        :param network: the PyTorch model
+        :param x: list of list, [batch_size, max_len]
+        :return y: [batch_size, max_len, num_classes]
+        """
+        seq_len = [len(seq) for seq in x]
+        x = torch.Tensor(x).long()
+        self.batch_size = x.size(0)
+        self.max_len = x.size(1)
+        self.mask = seq_mask(seq_len, self.max_len)
+        y = network(x)
+        return y
+
+    def mode(self, test=False):
+        if test:
+            self.model.eval()
+        else:
+            self.model.train()
+
+    def define_optimizer(self):
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
+
+    def grad_backward(self, loss):
+        self.model.zero_grad()
+        loss.backward()
+
+    def update(self):
+        self.optimizer.step()
+
+    def get_loss(self, predict, truth):
+        """
+        Compute loss given prediction and ground truth.
+        :param predict: prediction label vector, [batch_size, max_len, num_classes]
+        :param truth: ground truth label vector, [batch_size, max_len]
+        :return: a scalar
+        """
+        truth = torch.Tensor(truth)
+        if self.loss_func is None:
+            if hasattr(self.model, "loss"):
+                self.loss_func = self.model.loss
+            else:
+                self.define_loss()
+        loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
+        # print("loss={:.2f}".format(loss.data))
+        return loss
 
 if __name__ == "__main__":
-    train_args = BaseTrainer.TrainConfig(epochs=1, validate=False, batch_size=3, pickle_path="./")
+    train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"}
     trainer = BaseTrainer(train_args)
     data_train = [[[1, 2, 3, 4], [0]]] * 10 + [[[1, 3, 5, 2], [1]]] * 10
-    trainer.batchify(batch_size=3, data=data_train)
+    trainer.batchify(data=data_train)
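
One caveat on the smoke test above: batchify no longer builds its iterator lazily (that fallback was removed from BaseTrainer.batchify), so outside of train() the iterator must be prepared first. A sketch, reusing the Batchifier/RandomSampler imports from this module:

trainer.iterator = iter(Batchifier(RandomSampler(data_train), trainer.batch_size, drop_last=True))
batch_x, batch_y = trainer.batchify(data=data_train)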
@@ -15,7 +15,6 @@ class POSDatasetLoader(DatasetLoader):
 
     def __init__(self, data_name, data_path):
         super(POSDatasetLoader, self).__init__(data_name, data_path)
-        # self.data_set = self.load()
 
     def load(self):
         assert os.path.exists(self.data_path)
@@ -24,7 +23,7 @@ class POSDatasetLoader(DatasetLoader):
         return line
 
     def load_lines(self):
-        assert (os.path.exists(self.data_path))
+        assert os.path.exists(self.data_path)
         with open(self.data_path, "r", encoding="utf-8") as f:
             lines = f.readlines()
         return lines
@@ -46,19 +46,17 @@ class BasePreprocess(object):
 
 class POSPreprocess(BasePreprocess):
     """
-    This class are used to preprocess the pos datasets.
-    In these datasets, each line are divided by '\t'
-    while the first Col is the vocabulary and the second
-    Col is the label.
-    Different sentence are divided by an empty line.
+    This class is used to preprocess POS datasets.
+    In these datasets, each line is divided by '\t':
+    the first column is the word and the second column is its label.
+    Different sentences are divided by an empty line.
     e.g.:
         Tom label1
         and label2
         Jerry label1
         . label3
 
         Hello label4
         world label5
         ! label3
@@ -71,11 +69,13 @@ class POSPreprocess(BasePreprocess):
         super(POSPreprocess, self).__init__(data, pickle_path)
         self.word_dict = None
         self.label_dict = None
+        self.data = data
+        self.pickle_path = pickle_path
         self.build_dict()
         self.word2id()
-        self.id2word()
+        self.vocab_size = self.id2word()
         self.class2id()
-        self.id2class()
+        self.num_classes = self.id2class()
         self.embedding()
         self.data_train()
         self.data_dev()
@@ -87,7 +87,8 @@ class POSPreprocess(BasePreprocess):
                           DEFAULT_RESERVED_LABEL[2]: 4}
         self.label_dict = {}
         for w in self.data:
-            if len(w) == 0:
+            w = w.strip()
+            if len(w) <= 1:
                 continue
 
             word = w.split('\t')
@@ -95,10 +96,11 @@ class POSPreprocess(BasePreprocess):
                 index = len(self.word_dict)
                 self.word_dict[word[0]] = index
 
-            for label in word[1:]:
-                if label not in self.label_dict:
-                    index = len(self.label_dict)
-                    self.label_dict[label] = index
+            # only the first label column is used
+            label = word[1]
+            if label not in self.label_dict:
+                index = len(self.label_dict)
+                self.label_dict[label] = index
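
Traced on the docstring example, build_dict fills the two mappings like this (a runnable sketch, not part of the patch; it ignores the reserved entries that pre-seed word_dict):

lines = ["Tom\tlabel1", "and\tlabel2", "Jerry\tlabel1", "", "Hello\tlabel4"]
word_dict, label_dict = {}, {}
for w in lines:
    w = w.strip()
    if len(w) <= 1:
        continue                       # skip the empty sentence separator
    word = w.split('\t')
    word_dict.setdefault(word[0], len(word_dict))
    label_dict.setdefault(word[1], len(label_dict))
# word_dict  -> {'Tom': 0, 'and': 1, 'Jerry': 2, 'Hello': 3}
# label_dict -> {'label1': 0, 'label2': 1, 'label4': 2}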
     def pickle_exist(self, pickle_name):
         """
@@ -107,7 +109,7 @@ class POSPreprocess(BasePreprocess):
         """
         if not os.path.exists(self.pickle_path):
             os.makedirs(self.pickle_path)
-        file_name = self.pickle_path + pickle_name
+        file_name = os.path.join(self.pickle_path, pickle_name)
         if os.path.exists(file_name):
             return True
         else:
@@ -118,42 +120,48 @@ class POSPreprocess(BasePreprocess):
             return
         # nothing will be done if word2id.pkl exists
 
-        file_name = self.pickle_path + "word2id.pkl"
-        with open(file_name, "wb", encoding='utf-8') as f:
+        file_name = os.path.join(self.pickle_path, "word2id.pkl")
+        with open(file_name, "wb") as f:
             _pickle.dump(self.word_dict, f)
 
     def id2word(self):
         if self.pickle_exist("id2word.pkl"):
-            return
+            # the pickle already exists: load it only to report its size
+            file_name = os.path.join(self.pickle_path, "id2word.pkl")
+            id2word_dict = _pickle.load(open(file_name, "rb"))
+            return len(id2word_dict)
 
         id2word_dict = {}
         for word in self.word_dict:
             id2word_dict[self.word_dict[word]] = word
-        file_name = self.pickle_path + "id2word.pkl"
-        with open(file_name, "wb", encoding='utf-8') as f:
+        file_name = os.path.join(self.pickle_path, "id2word.pkl")
+        with open(file_name, "wb") as f:
             _pickle.dump(id2word_dict, f)
+        return len(id2word_dict)
 
     def class2id(self):
         if self.pickle_exist("class2id.pkl"):
             return
         # nothing will be done if class2id.pkl exists
 
-        file_name = self.pickle_path + "class2id.pkl"
-        with open(file_name, "wb", encoding='utf-8') as f:
+        file_name = os.path.join(self.pickle_path, "class2id.pkl")
+        with open(file_name, "wb") as f:
             _pickle.dump(self.label_dict, f)
 
     def id2class(self):
         if self.pickle_exist("id2class.pkl"):
-            return
+            # the pickle already exists: load it only to report its size
+            file_name = os.path.join(self.pickle_path, "id2class.pkl")
+            id2class_dict = _pickle.load(open(file_name, "rb"))
+            return len(id2class_dict)
 
         id2class_dict = {}
         for label in self.label_dict:
             id2class_dict[self.label_dict[label]] = label
-        file_name = self.pickle_path + "id2class.pkl"
-        with open(file_name, "wb", encoding='utf-8') as f:
+        file_name = os.path.join(self.pickle_path, "id2class.pkl")
+        with open(file_name, "wb") as f:
            _pickle.dump(id2class_dict, f)
+        return len(id2class_dict)
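
The point of the new return values: __init__ stores them as vocab_size / num_classes, which the training script later feeds to the model (a sketch with hypothetical contents):

id2word_dict = {0: "Tom", 1: "and", 2: "Jerry"}
vocab_size = len(id2word_dict)       # what id2word() returns
id2class_dict = {0: "label1", 1: "label2"}
num_classes = len(id2class_dict)     # what id2class() returns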
     def embedding(self):
         if self.pickle_exist("embedding.pkl"):
@@ -168,22 +176,26 @@ class POSPreprocess(BasePreprocess):
         data_train = []
         sentence = []
         for w in self.data:
-            if len(w) == 0:
+            w = w.strip()
+            if len(w) <= 1:
+                # an empty line marks the end of a sentence
                 wid = []
                 lid = []
                 for i in range(len(sentence)):
                     wid.append(self.word_dict[sentence[i][0]])
                     lid.append(self.label_dict[sentence[i][1]])
                 data_train.append((wid, lid))
                 sentence = []
+                continue
             sentence.append(w.split('\t'))
 
-        file_name = self.pickle_path + "data_train.pkl"
-        with open(file_name, "wb", encoding='utf-8') as f:
+        file_name = os.path.join(self.pickle_path, "data_train.pkl")
+        with open(file_name, "wb") as f:
             _pickle.dump(data_train, f)
 
     def data_dev(self):
         pass
 
     def data_test(self):
         pass
@@ -3,32 +3,12 @@ import torch
 
 class BaseModel(torch.nn.Module):
     """Base PyTorch model for all models.
-    Three network modules presented:
-        - embedding module
-        - aggregation module
-        - output module
-    Subclasses must implement these three modules with "components".
+    To do: add some useful common features
     """
 
     def __init__(self):
         super(BaseModel, self).__init__()
 
-    def forward(self, *inputs):
-        x = self.encode(*inputs)
-        x = self.aggregation(x)
-        x = self.output(x)
-        return x
-
-    def encode(self, x):
-        raise NotImplementedError
-
-    def aggregation(self, x):
-        raise NotImplementedError
-
-    def output(self, x):
-        raise NotImplementedError
-
 
 class Vocabulary(object):
     """A look-up table that allows you to access `Lexeme` objects. The `Vocab`
@@ -93,3 +73,4 @@ class Token(object):
         self.doc = doc
         self.token = doc[offset]
         self.i = offset
+
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from fastNLP.models.base_model import BaseModel
+from fastNLP.modules.CRF import ContionalRandomField
+
+
+class SeqLabeling(BaseModel):
+    """
+    PyTorch network for sequence labeling.
+    """
+
+    def __init__(self, hidden_dim,
+                 rnn_num_layer,
+                 num_classes,
+                 vocab_size,
+                 word_emb_dim=100,
+                 init_emb=None,
+                 rnn_mode="gru",
+                 bi_direction=False,
+                 dropout=0.5,
+                 use_crf=True):
+        super(SeqLabeling, self).__init__()
+        self.Emb = nn.Embedding(vocab_size, word_emb_dim)
+        if init_emb is not None:
+            self.Emb.weight = nn.Parameter(init_emb)
+
+        self.num_classes = num_classes
+        self.input_dim = word_emb_dim
+        self.layers = rnn_num_layer
+        self.hidden_dim = hidden_dim
+        self.bi_direction = bi_direction
+        self.dropout = dropout
+        self.mode = rnn_mode
+
+        if self.mode == "lstm":
+            self.rnn = nn.LSTM(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
+                               bidirectional=self.bi_direction, dropout=self.dropout)
+        elif self.mode == "gru":
+            self.rnn = nn.GRU(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
+                              bidirectional=self.bi_direction, dropout=self.dropout)
+        elif self.mode == "rnn":
+            self.rnn = nn.RNN(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
+                              bidirectional=self.bi_direction, dropout=self.dropout)
+        else:
+            raise ValueError("unsupported rnn_mode: {}".format(rnn_mode))
+
+        if bi_direction:
+            self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes)
+        else:
+            self.linear = nn.Linear(self.hidden_dim, self.num_classes)
+        self.use_crf = use_crf
+        if self.use_crf:
+            self.crf = ContionalRandomField(num_classes)
+
+    def forward(self, x):
+        """
+        :param x: LongTensor, [batch_size, max_len]
+        :return y: [batch_size, max_len, num_classes]
+        """
+        x = self.Emb(x)
+        # [batch_size, max_len, word_emb_dim]
+        x, hidden = self.rnn(x)
+        # [batch_size, max_len, hidden_size * direction]
+        y = self.linear(x)
+        # [batch_size, max_len, num_classes]
+        return y
+
+    def loss(self, x, y, mask, batch_size, max_len):
+        """
+        Negative log likelihood loss.
+        :param x: FloatTensor, [batch_size, max_len, num_classes]
+        :param y: LongTensor, [batch_size, max_len]
+        :param mask: ByteTensor, [batch_size, max_len]
+        :param batch_size: int
+        :param max_len: int
+        :return loss: a scalar
+                prediction: [batch_size, max_len]
+        """
+        x = x.float()
+        y = y.long()
+        mask = mask.byte()
+
+        if self.use_crf:
+            total_loss = self.crf(x, y, mask)
+            tag_seq = self.crf.viterbi_decode(x, mask)
+        else:
+            # non-CRF fallback: NLL over per-position log-softmax scores
+            loss_function = nn.NLLLoss(ignore_index=0, size_average=False)
+            x = x.view(batch_size * max_len, -1)
+            score = F.log_softmax(x, dim=-1)
+            total_loss = loss_function(score, y.view(batch_size * max_len))
+            _, tag_seq = torch.max(score, dim=-1)
+            tag_seq = tag_seq.view(batch_size, max_len)
+        return torch.mean(total_loss), tag_seq
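
A shape walk-through of forward on hypothetical sizes (illustration only, not part of the patch):

import torch
batch_size, max_len, vocab_size, num_classes = 2, 5, 1000, 4
x = torch.zeros(batch_size, max_len).long()    # padded word-id batch
# Emb(x)      -> [2, 5, 100]  (word_emb_dim=100)
# rnn(...)    -> [2, 5, hidden_dim * 2] with bi_direction=True
# linear(...) -> [2, 5, 4]    = [batch_size, max_len, num_classes]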
@@ -82,7 +82,7 @@ class ContionalRandomField(nn.Module):
     def _glod_score(self, feats, tags, masks):
         """
         Compute the score for the gold path.
-        :param feats: FloatTensor, batch_size x tag_size x tag_size
+        :param feats: FloatTensor, batch_size x max_len x tag_size
         :param tags: LongTensor, batch_size x max_len
         :param masks: ByteTensor, batch_size x max_len
         :return: FloatTensor, batch_size
@@ -118,7 +118,7 @@ class ContionalRandomField(nn.Module):
     def forward(self, feats, tags, masks):
         """
         Calculate the negative log likelihood.
-        :param feats: FloatTensor, batch_size x tag_size x tag_size
+        :param feats: FloatTensor, batch_size x max_len x tag_size
         :param tags: LongTensor, batch_size x max_len
         :param masks: ByteTensor, batch_size x max_len
         :return: FloatTensor, batch_size
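
For reference, the tensors the corrected docstrings describe (hypothetical sizes; the CRF itself is not instantiated here):

import torch
batch_size, max_len, tag_size = 2, 6, 5
feats = torch.randn(batch_size, max_len, tag_size)  # emission scores from the linear layer
tags = torch.zeros(batch_size, max_len).long()      # gold tag ids
masks = torch.ones(batch_size, max_len).byte()      # 1 marks real tokens, 0 padding
# crf(feats, tags, masks) -> FloatTensor of shape [batch_size] (per-sample NLL)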
@@ -1,12 +1,13 @@
-import torch
-import torch.nn as nn
-import encoder
-import time
 import aggregation
-import dataloader
 import embedding
+import encoder
 import predict
+import torch
+import torch.nn as nn
 import torch.optim as optim
+import time
+import dataloader
 
 WORD_NUM = 357361
 WORD_SIZE = 100
@@ -16,6 +17,30 @@ R = 10
 MLP_HIDDEN = 2000
 CLASSES_NUM = 5
 
+from fastNLP.models.base_model import BaseModel
+from fastNLP.action.trainer import BaseTrainer
+
+
+class MyNet(BaseModel):
+    def __init__(self):
+        super(MyNet, self).__init__()
+        self.embedding = embedding.Lookuptable(WORD_NUM, WORD_SIZE)
+        self.encoder = encoder.Lstm(WORD_SIZE, HIDDEN_SIZE, 1, 0.5, True)
+        self.aggregation = aggregation.Selfattention(2 * HIDDEN_SIZE, D_A, R)
+        self.predict = predict.MLP(R * HIDDEN_SIZE * 2, MLP_HIDDEN, CLASSES_NUM)
+        self.penalty = None
+
+    def encode(self, x):
+        return self.encoder(self.embedding(x))
+
+    def aggregate(self, x):
+        x, self.penalty = self.aggregation(x)
+        return x
+
+    def decode(self, x):
+        return [self.predict(x), self.penalty]
+
 
 class Net(nn.Module):
     """
     A model for sentiment analysis using lstm and self-attention
@@ -34,6 +59,19 @@ class Net(nn.Module):
         x = self.predict(x)
         return x, penalty
 
+
+class MyTrainer(BaseTrainer):
+    def __init__(self, args):
+        super(MyTrainer, self).__init__(args)
+        self.optimizer = None
+
+    def define_optimizer(self):
+        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
+
+    def define_loss(self):
+        self.loss_func = nn.CrossEntropyLoss()
+
 
 def train(model_dict=None, using_cuda=True, learning_rate=0.06,
           momentum=0.3, batch_size=32, epochs=5, coef=1.0, interval=10):
     """
@@ -7,3 +7,9 @@ def mask_softmax(matrix, mask):
     else:
         raise NotImplementedError
     return result
+
+
+def seq_mask(seq_len, max_len):
+    mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)]
+    mask = torch.stack(mask, 1)
+    return mask
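
seq_mask in action (illustration only): for each position i it checks seq_len >= i + 1, giving a [batch_size, max_len] mask with 1 on real tokens and 0 on padding:

import torch
seq_len, max_len = [3, 1], 3
mask = torch.stack([torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)], 1)
# mask -> tensor([[1, 1, 1],
#                 [1, 0, 0]], dtype=torch.uint8)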
@@ -1,3 +1,3 @@
-numpy==1.14.2
+numpy>=1.14.2
 torch==0.4.0
-torchvision==0.1.8
+torchvision>=0.1.8
@@ -0,0 +1,67 @@
+迈 B-v
+向 E-v
+充 B-v
+满 E-v
+希 B-n
+望 E-n
+的 S-u
+新 S-a
+世 B-n
+纪 E-n
+— B-w
+— E-w
+一 B-t
+九 M-t
+九 M-t
+八 M-t
+年 E-t
+新 B-t
+年 E-t
+讲 B-n
+话 E-n
+( S-w
+附 S-v
+图 B-n
+片 E-n
+1 S-m
+张 S-q
+) S-w
+中 B-nt
+共 M-nt
+中 M-nt
+央 E-nt
+总 B-n
+书 M-n
+记 E-n
+、 S-w
+国 B-n
+家 E-n
+主 B-n
+席 E-n
+江 B-nr
+泽 M-nr
+民 E-nr
+( S-w
+一 B-t
+九 M-t
+九 M-t
+七 M-t
+年 E-t
+十 B-t
+二 M-t
+月 E-t
+三 B-t
+十 M-t
+一 M-t
+日 E-t
+) S-w
+1 B-t
+2 M-t
+月 E-t
+3 B-t
+1 M-t
+日 E-t
+, S-w
@@ -0,0 +1,35 @@
+import sys
+
+sys.path.append("..")
+
+from fastNLP.action.trainer import POSTrainer
+from fastNLP.loader.dataset_loader import POSDatasetLoader
+from fastNLP.loader.preprocess import POSPreprocess
+from fastNLP.models.sequence_modeling import SeqLabeling
+
+data_name = "people.txt"
+data_path = "data_for_tests/people.txt"
+pickle_path = "data_for_tests"
+
+if __name__ == "__main__":
+    # Data Loader
+    pos = POSDatasetLoader(data_name, data_path)
+    train_data = pos.load_lines()
+
+    # Preprocessor
+    p = POSPreprocess(train_data, pickle_path)
+    vocab_size = p.vocab_size
+    num_classes = p.num_classes
+
+    # Trainer
+    train_args = {"epochs": 20, "batch_size": 1, "num_classes": num_classes,
+                  "vocab_size": vocab_size, "pickle_path": pickle_path, "validate": True}
+    trainer = POSTrainer(train_args)
+
+    # Model
+    model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True)
+
+    # Start training
+    trainer.train(model)
+
+    print("Training finished!")