@@ -2,6 +2,7 @@ import _pickle | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
import os | |||||
from fastNLP.action.action import Action | from fastNLP.action.action import Action | ||||
from fastNLP.action.action import RandomSampler, Batchifier | from fastNLP.action.action import RandomSampler, Batchifier | ||||
@@ -174,3 +175,153 @@ class POSTester(BaseTester): | |||||
""" | """ | ||||
loss, accuracy = self.matrices() | loss, accuracy = self.matrices() | ||||
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) | return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) | ||||
class ClassTester(BaseTester):
    """Tester for classification.

    Runs a network over a pickled data set and records, for every batch,
    the class probabilities and the true labels so that accuracy can be
    computed afterwards with ``matrices``.
    """

    # Optional keys of ``test_args`` and the value used when a key is absent.
    _DEFAULT_ARGS = (
        ("test_name", "data_test.pkl"),
        ("validate_in_training", False),
        ("save_output", False),
        ("save_loss", True),
        ("batch_size", 50),
        ("use_cuda", True),
        ("max_len", None),
    )

    def __init__(self, test_args):
        """
        :param test_args: a dict-like object supporting ``key in test_args``
            and ``test_args[key]``. ``"pickle_path"`` is required; the
            optional keys and their defaults are listed in ``_DEFAULT_ARGS``.
        """
        # super(ClassTester, self).__init__()
        self.pickle_path = test_args["pickle_path"]

        self.save_dev_data = None
        self.output = None
        self.mean_loss = None
        self.iterator = None

        for key, default in self._DEFAULT_ARGS:
            value = test_args[key] if key in test_args else default
            setattr(self, key, value)

        self.model = None
        self.eval_history = []  # one [y_prob, y_true] pair per batch
        self.batch_output = []  # raw network outputs, kept if save_output

    def test(self, network):
        """Run ``network`` over the test set and record the results.

        :param network: the model to evaluate; it is moved to the GPU when
            CUDA is available and ``use_cuda`` is set.
        """
        # prepare model
        if torch.cuda.is_available() and self.use_cuda:
            self.model = network.cuda()
        else:
            self.model = network

        # No backward pass is needed while testing.
        # NOTE(review): the flags are never restored here, so the parameters
        # of the passed-in network stay frozen after this call.
        for param in self.model.parameters():
            param.requires_grad = False

        # turn on the testing mode; clean up the history
        self.mode(network, test=True)

        # prepare test data
        data_test = self.prepare_input(self.pickle_path, self.test_name)

        # data generator
        self.iterator = iter(Batchifier(
            RandomSampler(data_test), self.batch_size, drop_last=False))

        # Report progress roughly ten times per run. The interval is clamped
        # to 1 because ``n_batches // 10`` is 0 for fewer than ten batches,
        # which would make the modulo below raise ZeroDivisionError.
        n_batches = len(data_test) // self.batch_size
        n_print = max(1, n_batches // 10)

        batches = self.batchify(data_test, max_len=self.max_len)
        for step, (batch_x, batch_y) in enumerate(batches):
            prediction = self.data_forward(network, batch_x)
            eval_results = self.evaluate(prediction, batch_y)

            if self.save_output:
                self.batch_output.append(prediction)
            if self.save_loss:
                self.eval_history.append(eval_results)

            if step % n_print == 0:
                print("step: {:>5}".format(step))

    def prepare_input(self, data_dir, file_name):
        """Load and return the pickled data set ``data_dir/file_name``."""
        file_path = os.path.join(data_dir, file_name)
        with open(file_path, 'rb') as f:
            data = _pickle.load(f)
        return data

    def batchify(self, data, max_len=None):
        """Yield padded ``(batch_x, batch_y)`` tensor pairs.

        Index batches are drawn from ``self.iterator``, which ``test`` sets
        up before calling this method.

        :param data: indexable collection of ``[x, y]`` samples.
        :param max_len: if given, ``batch_x`` is cut to at most this many
            columns.
        """
        on_gpu = torch.cuda.is_available() and self.use_cuda
        for indices in self.iterator:
            # generate batch and pad
            batch = [data[idx] for idx in indices]
            batch_x = [sample[0] for sample in batch]
            batch_y = [sample[1] for sample in batch]
            # ``pad`` is not defined in this class; presumably inherited
            # from BaseTester -- TODO confirm.
            batch_x = self.pad(batch_x)

            # convert to tensor
            batch_x = torch.tensor(batch_x, dtype=torch.long)
            batch_y = torch.tensor(batch_y, dtype=torch.long)
            if on_gpu:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()

            # trim data to max_len
            if max_len is not None and batch_x.size(1) > max_len:
                batch_x = batch_x[:, :max_len]

            yield batch_x, batch_y

    def data_forward(self, network, x):
        """Forward ``x`` through ``network`` and return its output."""
        logits = network(x)
        return logits

    def evaluate(self, y_logit, y_true):
        """Return ``[y_prob, y_true]``; ``y_prob`` is softmax of the logits."""
        y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
        return [y_prob, y_true]

    def matrices(self):
        """Compute accuracy over everything recorded by ``test``.

        :return: ``(y_true, y_prob, acc)`` where the first two are numpy
            arrays and ``acc`` is a float.
        :raises ValueError: if no evaluation results have been recorded
            (``test`` not run, or run with ``save_loss`` disabled).
        """
        if not self.eval_history:
            raise ValueError(
                "No evaluation results recorded; run test() with "
                "save_loss enabled first.")
        y_prob, y_true = zip(*self.eval_history)
        y_prob = torch.cat(y_prob, dim=0)
        y_pred = torch.argmax(y_prob, dim=-1)
        y_true = torch.cat(y_true, dim=0)
        acc = float(torch.sum(y_pred == y_true)) / len(y_true)
        return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc

    def mode(self, model, test=True):
        """Switch ``model`` to eval/train mode and clear the eval history.

        To do: combine this function with Trainer ??
        """
        if test:
            model.eval()
        else:
            model.train()
        self.eval_history.clear()
@@ -2,6 +2,10 @@ import _pickle | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
import torch.nn as nn | |||||
import os | |||||
from time import time | |||||
from datetime import timedelta | |||||
from fastNLP.action.action import Action | from fastNLP.action.action import Action | ||||
from fastNLP.action.action import RandomSampler, Batchifier | from fastNLP.action.action import RandomSampler, Batchifier | ||||
@@ -348,6 +352,201 @@ class LanguageModelTrainer(BaseTrainer): | |||||
pass | pass | ||||
class ClassTrainer(BaseTrainer):
    """Trainer for classification."""

    # Optional keys of ``train_args`` and the value used when a key is absent.
    _DEFAULT_ARGS = (
        ("validate", False),
        ("learn_rate", 1e-3),
        ("momentum", 0.9),
        ("use_cuda", True),
    )

    def __init__(self, train_args):
        """
        :param train_args: a dict-like object supporting ``key in
            train_args`` and ``train_args[key]``. ``"epochs"``,
            ``"batch_size"`` and ``"pickle_path"`` are required; the optional
            keys and their defaults are listed in ``_DEFAULT_ARGS``.
        """
        # super(ClassTrainer, self).__init__(train_args)
        self.n_epochs = train_args["epochs"]
        self.batch_size = train_args["batch_size"]
        self.pickle_path = train_args["pickle_path"]

        for key, default in self._DEFAULT_ARGS:
            value = train_args[key] if key in train_args else default
            setattr(self, key, value)

        self.model = None
        self.iterator = None
        self.loss_func = None
        self.optimizer = None

    def train(self, network):
        """General Training Steps

        :param network: a model

        The method is framework independent.
        Work by calling the following methods:
            - prepare_input
            - mode
            - define_optimizer
            - data_forward
            - get_loss
            - grad_backward
            - update
        Subclasses must implement these methods with a specific framework.
        """
        # prepare model and data, transfer model to gpu if available
        if torch.cuda.is_available() and self.use_cuda:
            self.model = network.cuda()
        else:
            self.model = network

        data_train, data_dev, data_test, embedding = self.prepare_input(
            self.pickle_path)

        # TODO: validation over ``data_dev`` after each epoch is not wired in
        # yet, so ``self.validate`` currently has no effect.

        # turn on network training mode, define loss and optimizer
        self.define_loss()
        self.define_optimizer()
        self.mode(test=False)

        # main training epochs
        start = time()
        n_samples = len(data_train)
        n_batches = n_samples // self.batch_size
        # Report roughly ten times per epoch. The interval is clamped to 1
        # because ``n_batches // 10`` is 0 for fewer than ten batches, which
        # would make the modulo below raise ZeroDivisionError.
        n_print = max(1, n_batches // 10)

        for epoch in range(self.n_epochs):
            # prepare batch iterator
            self.iterator = iter(Batchifier(
                RandomSampler(data_train), self.batch_size, drop_last=False))

            # training iterations in one epoch
            for step, (batch_x, batch_y) in enumerate(
                    self.batchify(data_train)):
                prediction = self.data_forward(network, batch_x)

                loss = self.get_loss(prediction, batch_y)
                self.grad_backward(loss)
                self.update()

                if step % n_print == 0:
                    acc = self.get_acc(prediction, batch_y)
                    end = time()
                    diff = timedelta(seconds=round(end - start))
                    print("epoch: {:>3} step: {:>4} loss: {:>4.2}"
                          " train acc: {:>5.1%} time: {}".format(
                              epoch, step, loss, acc, diff))

        # finish training

    def prepare_input(self, data_path):
        """Load the pickled train/dev/test sets and the embedding.

        :param data_path: directory holding the pickle files.
        :return: tuple ``(data_train, data_dev, data_test, embedding)``; an
            entry is an empty list when its file does not exist.
        """
        names = [
            "data_train.pkl", "data_dev.pkl",
            "data_test.pkl", "embedding.pkl"]

        files = []
        for name in names:
            file_path = os.path.join(data_path, name)
            if os.path.exists(file_path):
                with open(file_path, 'rb') as f:
                    data = _pickle.load(f)
            else:
                data = []
            files.append(data)

        return tuple(files)

    def mode(self, test=False):
        """
        Tell the network to be trained or not.
        :param test: bool
        """
        if test:
            self.model.eval()
        else:
            self.model.train()

    def define_loss(self):
        """
        Assign an instance of loss function to self.loss_func
        E.g. self.loss_func = nn.CrossEntropyLoss()

        An already assigned ``self.loss_func`` is kept. Otherwise the model's
        own ``loss`` attribute is used when it has one, falling back to
        cross entropy.
        """
        if self.loss_func is None:
            if hasattr(self.model, "loss"):
                self.loss_func = self.model.loss
            else:
                self.loss_func = nn.CrossEntropyLoss()

    def define_optimizer(self):
        """
        Define framework-specific optimizer specified by the models.

        Uses SGD with the configured learning rate and momentum.
        """
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.learn_rate,
            momentum=self.momentum)

    def data_forward(self, network, x):
        """Forward ``x`` through ``network`` and return its output."""
        logits = network(x)
        return logits

    def get_loss(self, predict, truth):
        """Calculate loss of ``predict`` against ``truth``."""
        return self.loss_func(predict, truth)

    def grad_backward(self, loss):
        """Clear the model's old gradients, then back-propagate ``loss``."""
        self.model.zero_grad()
        loss.backward()

    def update(self):
        """Apply gradient."""
        self.optimizer.step()

    def batchify(self, data):
        """Yield padded ``(batch_x, batch_y)`` tensor pairs.

        Index batches are drawn from ``self.iterator``, which ``train``
        recreates at the start of every epoch.

        :param data: indexable collection of ``[x, y]`` samples.
        """
        on_gpu = torch.cuda.is_available() and self.use_cuda
        for indices in self.iterator:
            batch = [data[idx] for idx in indices]
            batch_x = [sample[0] for sample in batch]
            batch_y = [sample[1] for sample in batch]
            # ``pad`` is not defined in this class; presumably inherited
            # from BaseTrainer -- TODO confirm.
            batch_x = self.pad(batch_x)

            batch_x = torch.tensor(batch_x, dtype=torch.long)
            batch_y = torch.tensor(batch_y, dtype=torch.long)
            if on_gpu:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()

            yield batch_x, batch_y

    def get_acc(self, y_logit, y_true):
        """Compute accuracy: share of samples whose argmax matches y_true."""
        y_pred = torch.argmax(y_logit, dim=-1)
        return int(torch.sum(y_true == y_pred)) / len(y_true)
if __name__ == "__name__": | if __name__ == "__name__": | ||||
train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"} | train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"} | ||||
trainer = BaseTrainer(train_args) | trainer = BaseTrainer(train_args) | ||||
@@ -29,11 +29,11 @@ class POSDatasetLoader(DatasetLoader): | |||||
return lines | return lines | ||||
class ClassificationDatasetLoader(DatasetLoader): | |||||
"""loader for classfication data sets""" | |||||
class ClassDatasetLoader(DatasetLoader): | |||||
"""Loader for classification data sets""" | |||||
def __init__(self, data_name, data_path):
    """Create the loader by delegating to the base class constructor.

    :param data_name: name of the data set, forwarded to the base class.
    :param data_path: location of the data set, forwarded to the base
        class (``load`` later reads ``self.data_path``).
    """
    # The earlier call ``super(ClassificationDatasetLoader, data_name)``
    # passed ``data_name`` where the instance belongs and dropped both
    # arguments. Zero-argument ``super()`` binds to the enclosing class and
    # ``self`` automatically, so it stays correct if the class is renamed.
    # NOTE(review): assumes the base ``__init__`` accepts
    # ``(data_name, data_path)`` -- confirm against DatasetLoader.
    super().__init__(data_name, data_path)
def load(self): | def load(self): | ||||
assert os.path.exists(self.data_path) | assert os.path.exists(self.data_path) | ||||
@@ -44,16 +44,21 @@ class ClassificationDatasetLoader(DatasetLoader): | |||||
@staticmethod | @staticmethod | ||||
def parse(lines): | def parse(lines): | ||||
""" | """ | ||||
:param lines: lines from dataset | |||||
:return: list(list(list())): the three level of lists are | |||||
Params | |||||
lines: lines from dataset | |||||
Return | |||||
list(list(list())): the three level of lists are | |||||
words, sentence, and dataset | words, sentence, and dataset | ||||
""" | """ | ||||
dataset = list() | dataset = list() | ||||
for line in lines: | for line in lines: | ||||
label = line.split(" ")[0] | |||||
words = line.split(" ")[1:] | |||||
word = list([w for w in words]) | |||||
sentence = list([word, label]) | |||||
line = line.strip().split() | |||||
label = line[0] | |||||
words = line[1:] | |||||
if len(words) <= 1: | |||||
continue | |||||
sentence = [words, label] | |||||
dataset.append(sentence) | dataset.append(sentence) | ||||
return dataset | return dataset | ||||
@@ -187,6 +187,191 @@ class POSPreprocess(BasePreprocess): | |||||
pass | pass | ||||
class ClassPreprocess(BasePreprocess):
    """
    Pre-process the classification datasets.

    Params:
        pickle_path - directory to save result of pre-processing
    Saves:
        word2id.pkl
        id2word.pkl
        class2id.pkl
        id2class.pkl
        embedding.pkl
        data_train.pkl
        data_dev.pkl
        data_test.pkl

    Every step reuses an already existing pickle file instead of
    regenerating it, so delete stale files to force a rebuild.
    """

    def __init__(self, pickle_path):
        # super(ClassPreprocess, self).__init__(data, pickle_path)
        self.word_dict = None   # {word: id}, filled by build_dict
        self.label_dict = None  # {class label: id}, filled by build_dict
        self.pickle_path = pickle_path  # save directory

    def _load(self, pickle_name):
        """Return the object pickled as ``pickle_name`` in pickle_path."""
        file_name = os.path.join(self.pickle_path, pickle_name)
        with open(file_name, "rb") as f:
            return _pickle.load(f)

    def _dump(self, obj, pickle_name):
        """Pickle ``obj`` as ``pickle_name`` in pickle_path."""
        file_name = os.path.join(self.pickle_path, pickle_name)
        with open(file_name, "wb") as f:
            _pickle.dump(obj, f)

    def process(self, data, save_name):
        """
        Process data.

        Params:
            data - nested list, data = [sample1, sample2, ...],
                sample = [sentence, label], sentence = [word1, word2, ...]
            save_name - name of processed data, such as data_train.pkl
        Returns:
            vocab_size - vocabulary size
            n_classes - number of classes
        """
        self.build_dict(data)
        self.word2id()
        vocab_size = self.id2word()
        self.class2id()
        num_classes = self.id2class()
        self.embedding()
        self.data_generate(data, save_name)

        return vocab_size, num_classes

    def build_dict(self, data):
        """Build vocabulary.

        Fills ``self.word_dict`` and ``self.label_dict``, either from
        existing pickles or from ``data``.
        """
        # just read if word2id.pkl and class2id.pkl exists
        if self.pickle_exist("word2id.pkl") and \
                self.pickle_exist("class2id.pkl"):
            self.word_dict = self._load("word2id.pkl")
            self.label_dict = self._load("class2id.pkl")
            return

        # build vocabulary from scratch if nothing exists
        self.word_dict = {
            DEFAULT_PADDING_LABEL: 0,
            DEFAULT_UNKNOWN_LABEL: 1,
            DEFAULT_RESERVED_LABEL[0]: 2,
            DEFAULT_RESERVED_LABEL[1]: 3,
            DEFAULT_RESERVED_LABEL[2]: 4}
        self.label_dict = {}

        # collect every word and label
        for sent, label in data:
            if len(sent) <= 1:
                continue

            if label not in self.label_dict:
                self.label_dict[label] = len(self.label_dict)

            for word in sent:
                if word not in self.word_dict:
                    # Key by the whole word. Keying by ``word[0]`` stored
                    # only first characters, so data_generate mapped every
                    # multi-character word to the unknown id.
                    self.word_dict[word] = len(self.word_dict)

    def pickle_exist(self, pickle_name):
        """
        Check whether a pickle file exists.

        Creates the ``pickle_path`` directory first if it is missing.

        Params
            pickle_name: the filename of target pickle file
        Return
            True if file exists else False
        """
        if not os.path.exists(self.pickle_path):
            os.makedirs(self.pickle_path)
        file_name = os.path.join(self.pickle_path, pickle_name)
        return os.path.exists(file_name)

    def word2id(self):
        """Save vocabulary of {word:id} mapping format."""
        # nothing will be done if word2id.pkl exists
        if self.pickle_exist("word2id.pkl"):
            return
        self._dump(self.word_dict, "word2id.pkl")

    def id2word(self):
        """Save vocabulary of {id:word} mapping format.

        Return
            number of entries in the (loaded or newly saved) mapping
        """
        # only read the existing mapping if id2word.pkl exists
        if self.pickle_exist("id2word.pkl"):
            return len(self._load("id2word.pkl"))

        id2word_dict = {idx: w for w, idx in self.word_dict.items()}
        self._dump(id2word_dict, "id2word.pkl")
        return len(id2word_dict)

    def class2id(self):
        """Save mapping of {class:id}."""
        # nothing will be done if class2id.pkl exists
        if self.pickle_exist("class2id.pkl"):
            return
        self._dump(self.label_dict, "class2id.pkl")

    def id2class(self):
        """Save mapping of {id:class}.

        Return
            number of entries in the (loaded or newly saved) mapping
        """
        # only read the existing mapping if id2class.pkl exists
        if self.pickle_exist("id2class.pkl"):
            return len(self._load("id2class.pkl"))

        id2class_dict = {idx: c for c, idx in self.label_dict.items()}
        self._dump(id2class_dict, "id2class.pkl")
        return len(id2class_dict)

    def embedding(self):
        """Save embedding lookup table corresponding to vocabulary."""
        # nothing will be done if embedding.pkl exists
        if self.pickle_exist("embedding.pkl"):
            return

        # retrieve vocabulary from pre-trained embedding (not implemented)

    def data_generate(self, data_src, save_name):
        """Convert dataset from text to digit.

        Writes a list of ``[word ids, label id]`` samples to
        ``pickle_path/save_name``. Words missing from the vocabulary get
        the id of DEFAULT_UNKNOWN_LABEL; sentences with at most one word
        are dropped.
        """
        # nothing will be done if file exists
        save_path = os.path.join(self.pickle_path, save_name)
        if os.path.exists(save_path):
            return

        data = []
        # for every sample
        for sent, label in data_src:
            if len(sent) <= 1:
                continue

            label_id = self.label_dict[label]  # label id
            sent_id = []  # sentence ids
            for word in sent:
                if word in self.word_dict:
                    sent_id.append(self.word_dict[word])
                else:
                    sent_id.append(self.word_dict[DEFAULT_UNKNOWN_LABEL])
            data.append([sent_id, label_id])

        # save data
        with open(save_path, "wb") as f:
            _pickle.dump(data, f)
class LMPreprocess(BasePreprocess): | class LMPreprocess(BasePreprocess): | ||||
def __init__(self, data, pickle_path): | def __init__(self, data, pickle_path): | ||||
super(LMPreprocess, self).__init__(data, pickle_path) | super(LMPreprocess, self).__init__(data, pickle_path) |
@@ -0,0 +1,37 @@ | |||||
# python: 3.6 | |||||
# encoding: utf-8 | |||||
import torch.nn as nn | |||||
# import torch.nn.functional as F | |||||
from fastNLP.models.base_model import BaseModel | |||||
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool | |||||
class CNNText(BaseModel):
    """
    Text classification model by character CNN, the implementation of paper
    'Yoon Kim. 2014. Convolution Neural Networks for Sentence
    Classification.'

    Pipeline: embedding lookup -> parallel convolutions with max-pooling
    (``ConvMaxpool``) -> dropout -> linear layer producing one score per
    class.
    """

    def __init__(self, class_num=9,
                 kernel_nums=(100, 100, 100), kernel_sizes=(3, 4, 5),
                 embed_num=1000, embed_dim=300, pretrained_embed=None,
                 drop_prob=0.5):
        """
        :param class_num: number of output classes.
        :param kernel_nums: output channels of each convolution, one entry
            per kernel size.
        :param kernel_sizes: kernel width of each convolution.
        :param embed_num: number of rows in the embedding table.
        :param embed_dim: embedding dimension.
        :param pretrained_embed: currently ignored.
        :param drop_prob: dropout probability applied before the final
            linear layer.
        """
        super(CNNText, self).__init__()

        # Defaults are tuples rather than lists so the default objects cannot
        # be mutated and shared between instances.

        # no support for pre-trained embedding currently
        self.embed = nn.Embedding(embed_num, embed_dim, padding_idx=0)
        self.conv_pool = ConvMaxpool(
            in_channels=embed_dim,
            out_channels=kernel_nums,
            kernel_sizes=kernel_sizes)
        self.dropout = nn.Dropout(drop_prob)
        # the pooled features of all convolutions are concatenated
        self.fc = nn.Linear(sum(kernel_nums), class_num)

    def forward(self, x):
        """Return the class scores computed for the index tensor ``x``."""
        x = self.embed(x)  # [N,L] -> [N,L,C]
        x = self.conv_pool(x)  # [N,L,C] -> [N,C]
        x = self.dropout(x)
        x = self.fc(x)  # [N,C] -> [N, N_class]
        return x
@@ -0,0 +1,53 @@ | |||||
# python: 3.6 | |||||
# encoding: utf-8 | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
class ConvMaxpool(nn.Module):
    """
    Convolution and max-pooling module with multiple kernel sizes.

    One ``nn.Conv1d`` is created per kernel size. In ``forward`` every
    convolution output is passed through the activation, max-pooled over its
    whole length, and the pooled vectors are concatenated.
    """

    def __init__(self, in_channels, out_channels, kernel_sizes,
                 stride=1, padding=0, dilation=1,
                 groups=1, bias=True, activation='relu'):
        """
        :param in_channels: number of input channels of every convolution.
        :param out_channels: int, or list/tuple with one entry per kernel
            size. An int is used for every kernel size.
        :param kernel_sizes: int, or list/tuple of ints.
        :param stride, padding, dilation, groups, bias: forwarded unchanged
            to every ``nn.Conv1d``.
        :param activation: name of the activation; only ``'relu'`` is
            supported.
        :raises ValueError: on an unsupported ``kernel_sizes`` type, on
            ``out_channels``/``kernel_sizes`` of different lengths, or on
            an unknown activation name.
        """
        super(ConvMaxpool, self).__init__()

        # normalise kernel_sizes to a list
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes]
        elif isinstance(kernel_sizes, (list, tuple)):
            kernel_sizes = list(kernel_sizes)
        else:
            raise ValueError(
                'Incorrect kernel sizes: should be list, tuple or int')

        # normalise out_channels to one entry per kernel size
        if isinstance(out_channels, int):
            out_channels = [out_channels] * len(kernel_sizes)
        else:
            out_channels = list(out_channels)
            # zip() would silently drop the surplus entries otherwise
            if len(out_channels) != len(kernel_sizes):
                raise ValueError(
                    'out_channels and kernel_sizes must have the same '
                    'length, got {} and {}'.format(
                        len(out_channels), len(kernel_sizes)))

        # convolution
        self.convs = nn.ModuleList([nn.Conv1d(
            in_channels=in_channels,
            out_channels=oc,
            kernel_size=ks,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
            for oc, ks in zip(out_channels, kernel_sizes)])

        # activation function
        if activation == 'relu':
            self.activation = F.relu
        else:
            raise ValueError(
                "Undefined activation function: choose from: relu")

    def forward(self, x):
        """Convolve, activate, max-pool over length, and concatenate.

        :param x: input tensor laid out as [N,L,C].
        :return: tensor [N, sum of out_channels].
        """
        # [N,L,C] -> [N,C,L]
        x = torch.transpose(x, 1, 2)
        # convolution
        xs = [self.activation(conv(x)) for conv in self.convs]  # [[N,C,L]]
        # max-pooling over the full remaining length
        xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2)
              for i in xs]  # [[N, C]]
        return torch.cat(xs, dim=-1)  # [N,C]