@@ -170,8 +170,8 @@ class BaseTrainer(Action):
[[word_21, word_22, word_23], [label_21, label_22]], # sample 2
...
]
:return batch_x: list. Each entry is a list of features of a sample.
batch_y: list. Each entry is a list of labels of a sample.
:return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
"""
if self.iterator is None:
self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True))
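# Editor's sketch (not part of the diff): the data layout make_batch consumes and the
# batch_x / batch_y lists its docstring describes, with hypothetical sample values.
data = [
    [["we", "like", "apples"], ["PRP", "VBP", "NNS"]],  # [word_list, label_list], sample 1
    [["dogs", "bark"], ["NNS", "VBP"]],                 # sample 2
]
batch = data[:2]                              # one draw of batch_size samples (sampler assumed)
batch_x = [sample[0] for sample in batch]     # per-sample feature lists, padded to max_len later
batch_y = [sample[1] for sample in batch]     # per-sample label lists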
@@ -325,7 +325,6 @@ class POSTrainer(BaseTrainer):
self.num_classes = train_args.num_classes
self.max_len = None
self.mask = None
self.batch_x = None
def prepare_input(self, data_path):
"""
@@ -336,14 +335,18 @@ class POSTrainer(BaseTrainer):
return data_train, data_dev, 0, 1
def data_forward(self, network, x):
"""
:param network: the PyTorch model
:param x: list of list, [batch_size, max_len]
:return y: [batch_size, max_len, num_classes]
""" | |||
seq_len = [len(seq) for seq in x] | |||
x = torch.LongTensor(x) | |||
x = torch.Tensor(x).long() | |||
self.batch_size = x.size(0) | |||
self.max_len = x.size(1) | |||
self.mask = seq_mask(seq_len, self.max_len) | |||
x = network(x) | |||
self.batch_x = x | |||
return x | |||
y = network(x) | |||
return y | |||
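# Editor's sketch: seq_mask is assumed to build a [batch_size, max_len] ByteTensor with 1
# for real tokens and 0 for padding; the helper below illustrates that assumed behaviour.
import torch

def seq_mask_sketch(seq_len, max_len):
    # row i has ones in the first seq_len[i] positions
    return torch.stack([torch.arange(max_len) < n for n in seq_len]).byte()

x_padded = [[4, 7, 2, 0], [9, 3, 0, 0]]       # word indices, [batch_size=2, max_len=4]
mask = seq_mask_sketch([3, 2], 4)             # tensor([[1, 1, 1, 0], [1, 1, 0, 0]])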
def mode(self, test=False):
if test:
@@ -357,8 +360,8 @@ class POSTrainer(BaseTrainer):
def get_loss(self, predict, truth):
"""
Compute loss given prediction and ground truth.
:param predict: prediction label vector
:param truth: ground truth label vector
:param predict: prediction label vector, [batch_size, max_len, num_classes]
:param truth: ground truth label vector, [batch_size, max_len]
:return: a scalar
"""
if self.loss_func is None:
@@ -366,7 +369,7 @@ class POSTrainer(BaseTrainer):
self.loss_func = self.model.loss
else:
self.define_loss()
return self.loss_func(self.batch_x, predict, self.mask, self.batch_size, self.max_len)
return self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
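# Editor's sketch of the corrected call above, with hypothetical tensors: predict is the
# network output from data_forward, truth the padded gold labels, and loss_func is assumed
# to be SeqLabeling.loss(x, y, mask, batch_size, max_len).
# predict = torch.randn(2, 5, 4)              # [batch_size, max_len, num_classes]
# truth = torch.zeros(2, 5).long()            # [batch_size, max_len]
# mask = seq_mask([5, 3], 5)                  # [batch_size, max_len]
# loss = self.loss_func(predict, truth, mask, 2, 5)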
if __name__ == "__main__":
@@ -12,7 +12,7 @@ class SeqLabeling(BaseModel):
"""
def __init__(self, hidden_dim,
rnn_num_layerd,
rnn_num_layer,
num_classes,
vocab_size,
word_emb_dim=100,
@@ -29,7 +29,7 @@ class SeqLabeling(BaseModel):
self.num_classes = num_classes
self.input_dim = word_emb_dim
self.layers = rnn_num_layerd
self.layers = rnn_num_layer
self.hidden_dim = hidden_dim
self.bi_direction = bi_direction
self.dropout = dropout
@@ -55,32 +55,26 @@ class SeqLabeling(BaseModel):
self.crf = ContionalRandomField(num_classes)
def forward(self, x):
x = self.embedding(x)
x, hidden = self.encode(x)
x = self.aggregate(x)
x = self.decode(x)
return x
def embedding(self, x):
return self.Emb(x)
def encode(self, x):
return self.rnn(x)
def aggregate(self, x):
return x
def decode(self, x):
x = self.linear(x)
return x
"""
:param x: LongTensor, [batch_size, max_len]
:return y: [batch_size, max_len, num_classes]
""" | |||
x = self.Emb(x) | |||
# [batch_size, max_len, word_emb_dim] | |||
x, hidden = self.rnn(x) | |||
# [batch_size, max_len, hidden_size * direction] | |||
y = self.linear(x) | |||
# [batch_size, max_len, num_classes] | |||
return y | |||
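# Editor's sketch: hypothetical usage of the rewritten forward, assuming the constructor
# arguments shown in the __init__ diff above and default values for the remaining ones.
# model = SeqLabeling(hidden_dim=128, rnn_num_layer=1, num_classes=4,
#                     vocab_size=10000, word_emb_dim=100)
# x = torch.zeros(2, 5).long()                # [batch_size, max_len] word indices
# y = model(x)                                # [batch_size, max_len, num_classes]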
def loss(self, x, y, mask, batch_size, max_len):
"""
Negative log likelihood loss.
:param x:
:param y:
:param seq_len:
:param x: FloatTensor, [batch_size, max_len, tag_size]
:param y: LongTensor, [batch_size, max_len]
:param mask: ByteTensor, [batch_size, max_len]
:param batch_size: int
:param max_len: int
:return loss:
prediction:
"""
@@ -82,7 +82,7 @@ class ContionalRandomField(nn.Module):
def _glod_score(self, feats, tags, masks):
"""
Compute the score for the gold path.
:param feats: FloatTensor, batch_size x tag_size x tag_size
:param feats: FloatTensor, batch_size x max_len x tag_size
:param tags: LongTensor, batch_size x max_len
:param masks: ByteTensor, batch_size x max_len
:return:FloatTensor, batch_size
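# Editor's sketch (not this repo's code) of the gold-path score a linear-chain CRF
# typically computes: emission scores at the gold tags plus transition scores between
# consecutive gold tags, summed over non-padding positions; "transitions" is an assumed
# [tag_size, tag_size] parameter of the CRF.
import torch

def gold_score_sketch(feats, tags, masks, transitions):
    # feats: [batch, max_len, tag_size]; tags, masks: [batch, max_len]
    emit = feats.gather(2, tags.unsqueeze(2)).squeeze(2)       # emission score of each gold tag
    trans = transitions[tags[:, :-1], tags[:, 1:]]             # score of each gold transition
    score = (emit * masks.float()).sum(1)
    score = score + (trans * masks[:, 1:].float()).sum(1)
    return score                                               # [batch]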
@@ -118,7 +118,7 @@ class ContionalRandomField(nn.Module):
def forward(self, feats, tags, masks):
"""
Calculate the neg log likelihood
:param feats:FloatTensor, batch_size x tag_size x tag_size
:param feats:FloatTensor, batch_size x max_len x tag_size
:param tags:LongTensor, batch_size x max_len
:param masks:ByteTensor batch_size x max_len
:return:FloatTensor, batch_size
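# Editor's sketch: the returned negative log likelihood is usually the log-partition
# function (forward algorithm over all tag paths) minus the gold-path score; the helper
# name _forward_alg below is assumed, while _glod_score is the method documented above.
# def forward(self, feats, tags, masks):
#     all_path_score = self._forward_alg(feats, masks)    # log Z, [batch_size]
#     gold_path_score = self._glod_score(feats, tags, masks)
#     return all_path_score - gold_path_score             # NLL per sample, [batch_size]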