add default parameters for modules, in order to decrease required params; update POSPreprocess to use multi-level lists; add metrics.py
@@ -3,9 +3,6 @@ class Inference(object):
This is an interface focusing on predicting output based on trained models.
It does not care about evaluations of the model.
Possible improvements:
- use batch to make use of GPU
"""
def __init__(self):
@@ -0,0 +1,8 @@
""" | |||
To do: | |||
设计评判结果的各种指标。如果涉及向量,使用numpy。 | |||
参考http://scikit-learn.org/stable/modules/classes.html#classification-metrics | |||
建议是每种metric写成一个函数 (由Tester的evaluate函数调用) | |||
参数表里只需考虑基本的参数即可,可以没有像它那么多的参数配置 | |||
""" |
@@ -1,15 +1,16 @@
from torch import optim
def get_torch_optimizor(params, alg_name='sgd', **args):
'''
construct pytorch optimizor by algorithm's name
optimizor's argurments can be splicified, for different optimizor's argurments, please see pytorch doc
def get_torch_optimizer(params, alg_name='sgd', **args):
"""
construct PyTorch optimizer by algorithm's name
optimizer's arguments can be specified, for different optimizer's arguments, please see PyTorch doc
usage:
optimizor = get_torch_optimizor(model.parameters(), 'SGD', lr=0.01)
optimizer = get_torch_optimizer(model.parameters(), 'SGD', lr=0.01)
"""
'''
name = alg_name.lower()
if name == 'adadelta':
return optim.Adadelta(params, **args)
@@ -28,22 +29,22 @@ def get_torch_optimizor(params, alg_name='sgd', **args):
elif name == 'rprop':
return optim.Rprop(params, **args)
elif name == 'sgd':
#SGD's parameter lr is required
# SGD's parameter lr is required
if 'lr' not in args:
args['lr'] = 0.01
return optim.SGD(params, **args)
elif name == 'sparseadam':
return optim.SparseAdam(params, **args)
else:
raise TypeError('no such optimizor named {}'.format(alg_name))
raise TypeError('no such optimizer named {}'.format(alg_name))
# example usage
if __name__ == '__main__':
from torch.nn.modules import Linear
net = Linear(2, 5)
test1 = get_torch_optimizor(net.parameters(),'adam', lr=1e-2, weight_decay=1e-3)
test1 = get_torch_optimizer(net.parameters(), 'adam', lr=1e-2, weight_decay=1e-3)
print(test1)
test2 = get_torch_optimizor(net.parameters(), 'SGD')
print(test2)
test2 = get_torch_optimizer(net.parameters(), 'SGD')
print(test2)
@@ -1,8 +1,8 @@
import _pickle
import os
import numpy as np
import torch
import os
from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
@@ -108,7 +108,7 @@ class BaseTester(Action):
raise NotImplementedError
@property
def matrices(self):
def metrics(self):
raise NotImplementedError
def mode(self, model, test=True):
@@ -163,7 +163,7 @@ class POSTester(BaseTester):
accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0]
return [loss.data, accuracy]
def matrices(self):
def metrics(self):
batch_loss = np.mean([x[0] for x in self.eval_history])
batch_accuracy = np.mean([x[1] for x in self.eval_history])
return batch_loss, batch_accuracy
@@ -173,7 +173,7 @@ class POSTester(BaseTester):
This is called by Trainer to print evaluation on dev set.
:return print_str: str
"""
loss, accuracy = self.matrices()
loss, accuracy = self.metrics()
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)
@@ -309,7 +309,7 @@ class ClassTester(BaseTester):
y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
return [y_prob, y_true]
def matrices(self):
def metrics(self):
"""Compute accuracy."""
y_prob, y_true = zip(*self.eval_history)
y_prob = torch.cat(y_prob, dim=0)
@@ -181,7 +181,7 @@ class BaseTrainer(Action):
"""
raise NotImplementedError
def batchify(self, data):
def batchify(self, data, output_length=True):
"""
1. Perform batching from data and produce a batch of training data.
2. Add padding.
@@ -194,13 +194,18 @@ class BaseTrainer(Action):
]
:return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
seq_len: list. The length of the pre-padded sequence, if output_length is True.
"""
indices = next(self.iterator)
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x = self.pad(batch_x)
return batch_x, batch_y
batch_x_pad = self.pad(batch_x)
if output_length:
seq_len = [len(x) for x in batch_x]
return batch_x_pad, batch_y, seq_len
else:
return batch_x_pad, batch_y
@staticmethod
def pad(batch, fill=0):
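A quick illustration of the new return convention of batchify. This is a standalone sketch, assuming pad simply right-pads every sample with `fill` up to the longest sample in the batch; it is not code from this commit.

    # standalone sketch of the padding + length bookkeeping used by batchify above
    def pad(batch, fill=0):
        # right-pad every sample in the batch to the length of the longest one
        max_len = max(len(x) for x in batch)
        return [x + [fill] * (max_len - len(x)) for x in batch]

    batch_x = [[4, 7, 2], [9, 1]]
    seq_len = [len(x) for x in batch_x]   # lengths before padding: [3, 2]
    batch_x_pad = pad(batch_x)            # [[4, 7, 2], [9, 1, 0]]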
@@ -245,7 +250,10 @@ class ToyTrainer(BaseTrainer):
return data_train, data_dev, 0, 1
def mode(self, test=False):
self.model.mode(test)
if test:
self.model.eval()
else:
self.model.train()
def data_forward(self, network, x):
return network(x)
@@ -333,7 +341,7 @@ class POSTrainer(BaseTrainer):
return loss
def best_eval_result(self, validator):
loss, accuracy = validator.matrices()
loss, accuracy = validator.metrics()
if accuracy > self.best_accuracy:
self.best_accuracy = accuracy
return True
@@ -11,8 +11,24 @@ class DatasetLoader(BaseLoader):
class POSDatasetLoader(DatasetLoader):
"""loader for pos data sets"""
"""Dataset Loader for POS Tag datasets. | |||
In these datasets, each line are divided by '\t' | |||
while the first Col is the vocabulary and the second | |||
Col is the label. | |||
Different sentence are divided by an empty line. | |||
e.g: | |||
Tom label1 | |||
and label2 | |||
Jerry label1 | |||
. label3 | |||
Hello label4 | |||
world label5 | |||
! label3 | |||
In this file, there are two sentence "Tom and Jerry ." | |||
and "Hello world !". Each word has its own label from label1 | |||
to label5. | |||
""" | |||
def __init__(self, data_name, data_path):
super(POSDatasetLoader, self).__init__(data_name, data_path)
@@ -23,10 +39,42 @@ class POSDatasetLoader(DatasetLoader):
return line
def load_lines(self):
assert (os.path.exists(self.data_path))
"""
:return data: three-level list
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
"""
with open(self.data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
return lines
return self.parse(lines)
@staticmethod
def parse(lines):
data = []
sentence = []
for line in lines:
line = line.strip()
if len(line) > 1:
sentence.append(line.split('\t'))
else:
words = []
labels = []
for tokens in sentence:
words.append(tokens[0])
labels.append(tokens[1])
data.append([words, labels])
sentence = []
if len(sentence) != 0:
words = []
labels = []
for tokens in sentence:
words.append(tokens[0])
labels.append(tokens[1])
data.append([words, labels])
return data
class ClassDatasetLoader(DatasetLoader):
@@ -112,3 +160,10 @@ class LMDatasetLoader(DatasetLoader):
with open(self.data_path, "r", encoding="utf=8") as f: | |||
text = " ".join(f.readlines()) | |||
return text.strip().split() | |||
if __name__ == "__main__": | |||
data = POSDatasetLoader("xxx", "../../test/data_for_tests/people.txt").load_lines() | |||
for example in data: | |||
for w, l in zip(example[0], example[1]): | |||
print(w, l) |
@@ -7,6 +7,10 @@ DEFAULT_RESERVED_LABEL = ['<reserved-2>',
'<reserved-3>',
'<reserved-4>'] # dict index = 2~4
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
DEFAULT_RESERVED_LABEL[2]: 4}
# the first vocab in dict with the index = 5
@@ -24,69 +28,86 @@ class BasePreprocess(object):
class POSPreprocess(BasePreprocess):
"""
This class are used to preprocess the pos datasets.
In these datasets, each line are divided by '\t'
while the first Col is the vocabulary and the second
Col is the label.
Different sentence are divided by an empty line.
e.g:
Tom label1
and label2
Jerry label1
. label3
Hello label4
world label5
! label3
In this file, there are two sentence "Tom and Jerry ."
and "Hello world !". Each word has its own label from label1
to label5.
"""
"""
def __init__(self, data, pickle_path):
def __init__(self, data, pickle_path="./", train_dev_split=0):
"""
Preprocess pipeline, including building mapping from words to index, from index to words,
from labels/classes to index, from index to labels/classes.
:param data: three-level list
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
:param pickle_path: str, the directory to the pickle files. Default: "./"
:param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0.
To do:
1. simplify __init__
"""
super(POSPreprocess, self).__init__(data, pickle_path)
self.word_dict = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
DEFAULT_RESERVED_LABEL[2]: 4}
self.label_dict = None
self.data = data
self.pickle_path = pickle_path
self.build_dict(data)
if not self.pickle_exist("word2id.pkl"):
self.word_dict.update(self.word2id(data))
file_name = os.path.join(self.pickle_path, "word2id.pkl")
with open(file_name, "wb") as f:
_pickle.dump(self.word_dict, f)
if self.pickle_exist("word2id.pkl"):
# load word2index because the construction of the following objects needs it
with open(os.path.join(self.pickle_path, "word2id.pkl"), "rb") as f:
self.word2index = _pickle.load(f)
else:
self.word2index, self.label2index = self.build_dict(data)
with open(os.path.join(self.pickle_path, "word2id.pkl"), "wb") as f:
_pickle.dump(self.word2index, f)
self.vocab_size = self.id2word()
self.class2id()
self.num_classes = self.id2class()
self.embedding()
self.data_train()
self.data_dev()
self.data_test()
if self.pickle_exist("class2id.pkl"):
with open(os.path.join(self.pickle_path, "class2id.pkl"), "rb") as f:
self.label2index = _pickle.load(f)
else:
with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f:
_pickle.dump(self.label2index, f)
# something will be wrong if word2id.pkl is found but class2id.pkl is not found
if not self.pickle_exist("id2word.pkl"):
index2word = self.build_reverse_dict(self.word2index)
with open(os.path.join(self.pickle_path, "id2word.pkl"), "wb") as f:
_pickle.dump(index2word, f)
if not self.pickle_exist("id2class.pkl"):
index2label = self.build_reverse_dict(self.label2index)
with open(os.path.join(self.pickle_path, "word2id.pkl"), "wb") as f: | |||
_pickle.dump(index2label, f)
if not self.pickle_exist("data_train.pkl"):
data_train = self.to_index(data)
if train_dev_split > 0 and not self.pickle_exist("data_dev.pkl"):
data_dev = data_train[: int(len(data_train) * train_dev_split)]
with open(os.path.join(self.pickle_path, "data_dev.pkl"), "wb") as f:
_pickle.dump(data_dev, f)
with open(os.path.join(self.pickle_path, "data_train.pkl"), "wb") as f:
_pickle.dump(data_train, f)
def build_dict(self, data):
"""
Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
:param data: list of list [word, label]
:param data: three-level list
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
:return word2index: dict of {str, int}
label2index: dict of {str, int}
"""
self.label_dict = {}
for line in data:
line = line.strip()
if len(line) <= 1:
continue
tokens = line.split('\t')
if tokens[0] not in self.word_dict:
# add (word, index) into the dict
self.word_dict[tokens[0]] = len(self.word_dict)
# for label in tokens[1: ]:
if tokens[1] not in self.label_dict:
self.label_dict[tokens[1]] = len(self.label_dict)
label2index = {}
word2index = DEFAULT_WORD_TO_INDEX
for example in data:
for word, label in zip(example[0], example[1]):
if word not in word2index:
word2index[word] = len(word2index)
if label not in label2index:
label2index[label] = len(label2index)
return word2index, label2index
def pickle_exist(self, pickle_name):
"""
@@ -101,90 +122,38 @@ class POSPreprocess(BasePreprocess):
else:
return False
def word2id(self):
if self.pickle_exist("word2id.pkl"):
return
# nothing will be done if word2id.pkl exists
file_name = os.path.join(self.pickle_path, "word2id.pkl")
with open(file_name, "wb") as f:
_pickle.dump(self.word_dict, f)
def id2word(self):
if self.pickle_exist("id2word.pkl"):
file_name = os.path.join(self.pickle_path, "id2word.pkl")
id2word_dict = _pickle.load(open(file_name, "rb"))
return len(id2word_dict)
# nothing will be done if id2word.pkl exists
def build_reverse_dict(self, word_dict):
id2word = {word_dict[w]: w for w in word_dict}
return id2word
id2word_dict = {}
for word in self.word_dict:
id2word_dict[self.word_dict[word]] = word
file_name = os.path.join(self.pickle_path, "id2word.pkl")
with open(file_name, "wb") as f:
_pickle.dump(id2word_dict, f)
return len(id2word_dict)
def class2id(self):
if self.pickle_exist("class2id.pkl"):
return
# nothing will be done if class2id.pkl exists
file_name = os.path.join(self.pickle_path, "class2id.pkl")
with open(file_name, "wb") as f:
_pickle.dump(self.label_dict, f)
def id2class(self):
if self.pickle_exist("id2class.pkl"):
file_name = os.path.join(self.pickle_path, "id2class.pkl")
id2class_dict = _pickle.load(open(file_name, "rb"))
return len(id2class_dict)
# nothing will be done if id2class.pkl exists
id2class_dict = {}
for label in self.label_dict:
id2class_dict[self.label_dict[label]] = label
file_name = os.path.join(self.pickle_path, "id2class.pkl")
with open(file_name, "wb") as f:
_pickle.dump(id2class_dict, f)
return len(id2class_dict)
def embedding(self):
if self.pickle_exist("embedding.pkl"):
return
# nothing will be done if embedding.pkl exists
def data_train(self):
if self.pickle_exist("data_train.pkl"):
return
# nothing will be done if data_train.pkl exists
data_train = []
sentence = []
for w in self.data:
w = w.strip()
if len(w) <= 1:
wid = []
lid = []
for i in range(len(sentence)):
# if sentence[i][0]=="":
# print("")
wid.append(self.word_dict[sentence[i][0]])
lid.append(self.label_dict[sentence[i][1]])
data_train.append((wid, lid))
sentence = []
continue
sentence.append(w.split('\t'))
file_name = os.path.join(self.pickle_path, "data_train.pkl")
with open(file_name, "wb") as f:
_pickle.dump(data_train, f)
def data_dev(self):
pass
def data_test(self):
pass
def to_index(self, data):
"""
Convert word strings and label strings into indices.
:param data: three-level list
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
:return data_index: the shape of data, but each string is replaced by its corresponding index
"""
data_index = []
for example in data:
word_list = []
label_list = []
for word, label in zip(example[0], example[1]):
word_list.append(self.word2index[word])
label_list.append(self.label2index[label])
data_index.append([word_list, label_list])
return data_index
@property
def vocab_size(self):
return len(self.word2index)
@property
def num_classes(self):
return len(self.label2index)
class ClassPreprocess(BasePreprocess):
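Putting the loader and the refactored preprocessor together, a typical call sequence might look like the sketch below. The import path of POSPreprocess and the data path are assumptions for illustration, not part of this commit.

    from fastNLP.loader.dataset_loader import POSDatasetLoader
    from fastNLP.loader.preprocess import POSPreprocess  # module path assumed

    # three-level list: [ [ [words...], [labels...] ], ... ]
    data = POSDatasetLoader("people", "data/people.txt").load_lines()

    # builds word2id/class2id (and their reverse dicts), pickles them under "./",
    # and splits 30% of the indexed training data into data_dev.pkl
    p = POSPreprocess(data, pickle_path="./", train_dev_split=0.3)
    print(p.vocab_size, p.num_classes)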
@@ -9,7 +9,7 @@ import torch.nn as nn
class KMaxPool(nn.Module):
"""K max-pooling module."""
def __init__(self, k):
def __init__(self, k=1):
super(KMaxPool, self).__init__()
self.k = k
@@ -1,9 +0,0 @@
from fastNLP.modules.aggregation.attention import Attention
class LinearAttention(Attention):
def __init__(self, normalize=False):
super(LinearAttention, self).__init__(normalize)
def _atten_forward(self, query, memory):
raise NotImplementedError
@@ -8,14 +8,15 @@ class SelfAttention(nn.Module):
Self Attention Module.
Args:
input_size : the size for the input vector
d_a : the width of weight matrix
r : the number of encoded vectors
input_size: int, the size for the input vector
dim: int, the width of weight matrix.
num_vec: int, the number of encoded vectors
"""
def __init__(self, input_size, d_a, r):
def __init__(self, input_size, dim=10, num_vec=10):
super(SelfAttention, self).__init__()
self.W_s1 = nn.Parameter(torch.randn(d_a, input_size), requires_grad=True)
self.W_s2 = nn.Parameter(torch.randn(r, d_a), requires_grad=True)
self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True)
self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True)
self.softmax = nn.Softmax(dim=2)
self.tanh = nn.Tanh()
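For orientation, these parameter shapes match the structured self-attention formulation A = softmax(W_s2 * tanh(W_s1 * H^T)). The forward pass is not shown in this hunk, so the snippet below is only a shape check under that assumption, with placeholder sizes:

    import torch
    # assumed shapes, following the renamed parameters above
    batch, seq_len, input_size, dim, num_vec = 2, 7, 300, 10, 10
    H = torch.randn(batch, seq_len, input_size)      # encoder outputs
    W_s1 = torch.randn(dim, input_size)
    W_s2 = torch.randn(num_vec, dim)
    A = torch.softmax(W_s2 @ torch.tanh(W_s1 @ H.transpose(1, 2)), dim=2)
    print(A.shape)    # torch.Size([2, 10, 7]): num_vec attention rows per sentence
    M = A @ H         # [2, 10, 300]: num_vec encoded vectors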
@@ -5,13 +5,15 @@ from torch import nn
class ConvCharEmbedding(nn.Module):
def __init__(self, char_emb_size, feature_maps=(40, 30, 30), kernels=(3, 4, 5)):
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5)):
"""
Character Level Word Embedding
:param char_emb_size: the size of character level embedding,
:param char_emb_size: the size of character level embedding. Default: 50
say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
:param feature_maps: table of feature maps (for each kernel width)
:param kernels: table of kernel widths
:param feature_maps: tuple of int. The length of the tuple is the number of convolution operations
over characters. The i-th integer is the number of filters (dim of out channels) for the i-th
convolution.
:param kernels: tuple of int. The width of each kernel.
"""
super(ConvCharEmbedding, self).__init__()
self.convs = nn.ModuleList([
@@ -23,29 +25,35 @@ class ConvCharEmbedding(nn.Module):
:param x: [batch_size * sent_length, word_length, char_emb_size]
:return: [batch_size * sent_length, sum(feature_maps), 1]
"""
x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2)) # [batch_size*sent_length, channel, width, height]
x = x.transpose(2, 3) # [batch_size*sent_length, channel, height, width]
x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2))
# [batch_size*sent_length, channel, width, height]
x = x.transpose(2, 3)
# [batch_size*sent_length, channel, height, width]
return self.convolute(x).unsqueeze(2)
def convolute(self, x):
feats = []
for conv in self.convs:
y = conv(x) # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1]
y = torch.squeeze(y, 2) # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1]
y = conv(x)
# [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1]
y = torch.squeeze(y, 2)
# [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1]
y = F.tanh(y)
y, __ = torch.max(y, 2) # [batch_size*sent_length, feature_maps[i]]
y, __ = torch.max(y, 2)
# [batch_size*sent_length, feature_maps[i]]
feats.append(y)
return torch.cat(feats, 1) # [batch_size*sent_length, sum(feature_maps)]
class LSTMCharEmbedding(nn.Module):
"""
Character Level Word Embedding with LSTM
:param char_emb_size: the size of character level embedding,
Character Level Word Embedding with LSTM with a single layer.
:param char_emb_size: int, the size of character level embedding. Default: 50
say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
:param hidden_size: int, the number of hidden units. Default: equal to char_emb_size.
"""
def __init__(self, char_emb_size, hidden_size=None):
def __init__(self, char_emb_size=50, hidden_size=None):
super(LSTMCharEmbedding, self).__init__()
self.hidden_size = char_emb_size if hidden_size is None else hidden_size
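A shape walk-through for the convolutional character embedding above, as a usage sketch. The import path is assumed, the tensors are random placeholders, and the expected output shape follows the docstring and the per-line comments rather than code verified here:

    import torch
    from fastNLP.modules.encoder.char_embedding import ConvCharEmbedding  # module path assumed

    emb = ConvCharEmbedding()              # char_emb_size defaults to 50
    x = torch.randn(32 * 20, 12, 50)       # [batch_size*sent_length, word_length, char_emb_size]
    out = emb(x)
    print(out.shape)                       # expected per the docstring: [640, 40+30+30, 1] = [640, 100, 1]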
@@ -2,12 +2,14 @@
# encoding: utf-8
import torch.nn as nn
from torch.nn.init import xavier_uniform
# import torch.nn.functional as F
class Conv(nn.Module):
"""
Basic 1-d convolution module.
initialize with xavier_uniform
"""
def __init__(self, in_channels, out_channels, kernel_size,
@@ -23,6 +25,7 @@ class Conv(nn.Module):
dilation=dilation,
groups=groups,
bias=bias)
xavier_uniform(self.conv.weight)
def forward(self, x):
return self.conv(x) # [N,C,L]
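A minimal sketch of using the 1-d convolution wrapper above. Only the constructor parameters visible in this hunk are passed; it assumes the remaining arguments (stride, padding, etc.) have defaults:

    import torch
    from fastNLP.modules.encoder.conv import Conv  # module path assumed

    conv = Conv(in_channels=50, out_channels=100, kernel_size=3)
    x = torch.randn(8, 50, 20)     # [N, C, L]
    print(conv(x).shape)           # [8, 100, L_out]; L_out depends on the padding/stride defaults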
@@ -7,12 +7,13 @@ class Lookuptable(nn.Module):
Args:
nums : the size of the lookup table
dims : the size of each vector
dims : the size of each vector. Default: 50.
padding_idx : pads the tensor with zeros whenever it encounters this index
sparse : If True, gradient matrix will be a sparse tensor. In this case,
only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
"""
def __init__(self, nums, dims, padding_idx=0, sparse=False):
def __init__(self, nums, dims=50, padding_idx=0, sparse=False):
super(Lookuptable, self).__init__()
self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
@@ -8,11 +8,12 @@ class Lstm(nn.Module):
Args:
input_size : input size
hidden_size : hidden size
num_layers : number of hidden layers
dropout : dropout rate
bidirectional : If True, becomes a bidirectional RNN
num_layers : number of hidden layers. Default: 1
dropout : dropout rate. Default: 0.5
bidirectional : If True, becomes a bidirectional RNN. Default: False.
"""
def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional):
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.5, bidirectional=False):
super(Lstm, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
dropout=dropout, bidirectional=bidirectional)
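With the new defaults, both modules can be constructed from just the required sizes. The sketch below uses only the attributes visible in these hunks (self.embed, self.lstm); the import paths are assumptions:

    import torch
    from fastNLP.modules.encoder.embedding import Lookuptable  # module path assumed
    from fastNLP.modules.encoder.lstm import Lstm              # module path assumed

    embed = Lookuptable(nums=10000)              # dims defaults to 50
    lstm = Lstm(input_size=50, hidden_size=128)  # 1 layer, dropout 0.5, unidirectional

    ids = torch.randint(0, 10000, (4, 20))       # [batch, seq_len]
    out, _ = lstm.lstm(embed.embed(ids))         # batch_first=True -> out is [4, 20, 128]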
@@ -1,9 +1,23 @@
import unittest
from fastNLP.loader.dataset_loader import POSDatasetLoader
class MyTestCase(unittest.TestCase):
def test_something(self):
self.assertEqual(True, False)
class TestPreprocess(unittest.TestCase):
def test_case_1(self):
data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], | |||
["Hello", "world", "!"], ["T", "F", "F"]] | |||
pickle_path = "./data_for_tests/" | |||
# POSPreprocess(data, pickle_path) | |||
class TestDatasetLoader(unittest.TestCase): | |||
def test_case_1(self): | |||
data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF""" | |||
lines = data.split("\n") | |||
answer = POSDatasetLoader.parse(lines) | |||
truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]] | |||
self.assertListEqual(answer, truth, "POS Dataset Loader") | |||
if __name__ == '__main__': | |||