@@ -170,8 +170,8 @@ class BaseTrainer(Action):
                 [[word_21, word_22, word_23], [label_21, label_22]],  # sample 2
                 ...
             ]
-        :return batch_x: list. Each entry is a list of features of a sample.
-                batch_y: list. Each entry is a list of labels of a sample.
+        :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
+                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
         """
         if self.iterator is None:
             self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True))
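A minimal sketch of the batching contract the updated docstring pins down. Batchifier and RandomSampler are the repo's own helpers; the split below only illustrates the yielded shapes ([batch_size, max_len] features, [batch_size, num_labels] labels) and is not the library's implementation.

data = [
    [["word_11", "word_12", "word_13"], ["label_11", "label_12"]],  # sample 1
    [["word_21", "word_22", "word_23"], ["label_21", "label_22"]],  # sample 2
]

def split_batch(batch):
    # Separate [features, labels] pairs into the two lists the trainer returns.
    batch_x = [sample[0] for sample in batch]  # [batch_size, max_len]
    batch_y = [sample[1] for sample in batch]  # [batch_size, num_labels]
    return batch_x, batch_y

batch_x, batch_y = split_batch(data)  # 2 x 3 features, 2 x 2 labels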
@@ -325,7 +325,6 @@ class POSTrainer(BaseTrainer):
         self.num_classes = train_args.num_classes
         self.max_len = None
         self.mask = None
-        self.batch_x = None
 
     def prepare_input(self, data_path):
         """
@@ -336,14 +335,18 @@ class POSTrainer(BaseTrainer):
         return data_train, data_dev, 0, 1
 
     def data_forward(self, network, x):
+        """
+        :param network: the PyTorch model
+        :param x: list of list, [batch_size, max_len]
+        :return y: [batch_size, max_len, num_classes]
+        """
         seq_len = [len(seq) for seq in x]
-        x = torch.LongTensor(x)
+        x = torch.Tensor(x).long()
         self.batch_size = x.size(0)
         self.max_len = x.size(1)
         self.mask = seq_mask(seq_len, self.max_len)
-        x = network(x)
-        self.batch_x = x
-        return x
+        y = network(x)
+        return y
 
     def mode(self, test=False):
         if test:
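The reworked data_forward leans on seq_mask to mark real tokens versus padding. The repo's implementation is not shown in this diff; here is a plausible sketch consistent with the call site above and with the ByteTensor [batch_size, max_len] mask the loss expects.

import torch

def seq_mask(seq_len, max_len):
    # mask[i, j] = 1 while j < seq_len[i], else 0 (padding position).
    mask = torch.zeros(len(seq_len), max_len, dtype=torch.uint8)
    for i, length in enumerate(seq_len):
        mask[i, :length] = 1
    return mask

print(seq_mask([3, 2], 4))
# tensor([[1, 1, 1, 0],
#         [1, 1, 0, 0]], dtype=torch.uint8)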
@@ -357,8 +360,8 @@ class POSTrainer(BaseTrainer):
     def get_loss(self, predict, truth):
         """
         Compute loss given prediction and ground truth.
-        :param predict: prediction label vector
-        :param truth: ground truth label vector
+        :param predict: prediction label vector, [batch_size, max_len, num_classes]
+        :param truth: ground truth label vector, [batch_size, max_len]
         :return: a scalar
         """
         if self.loss_func is None:
@@ -366,7 +369,7 @@ class POSTrainer(BaseTrainer):
             self.loss_func = self.model.loss
         else:
             self.define_loss()
-        return self.loss_func(self.batch_x, predict, self.mask, self.batch_size, self.max_len)
+        return self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
 
 
 if __name__ == "__main__":
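The one-line fix above matters: the old call passed the cached network output and the predictions (which were the same tensor), so the gold labels never reached the loss. A self-contained illustration of the corrected argument order, using a masked cross-entropy as a hypothetical stand-in for self.loss_func (the real one resolves to SeqLabeling.loss, shown below):

import torch
import torch.nn.functional as F

def loss_func(predict, truth, mask, batch_size, max_len):
    # Stand-in only: per-token NLL, with padding positions masked out.
    nll = F.cross_entropy(predict.view(batch_size * max_len, -1),
                          truth.view(-1), reduction="none")
    return (nll * mask.view(-1).float()).sum() / mask.float().sum()

batch_size, max_len, num_classes = 2, 4, 5
predict = torch.randn(batch_size, max_len, num_classes)       # emissions
truth = torch.randint(0, num_classes, (batch_size, max_len))  # gold tags
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.uint8)
print(loss_func(predict, truth, mask, batch_size, max_len))   # a scalar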
@@ -12,7 +12,7 @@ class SeqLabeling(BaseModel):
     """
 
     def __init__(self, hidden_dim,
-                 rnn_num_layerd,
+                 rnn_num_layer,
                  num_classes,
                  vocab_size,
                  word_emb_dim=100,
@@ -29,7 +29,7 @@ class SeqLabeling(BaseModel):
         self.num_classes = num_classes
         self.input_dim = word_emb_dim
-        self.layers = rnn_num_layerd
+        self.layers = rnn_num_layer
         self.hidden_dim = hidden_dim
         self.bi_direction = bi_direction
         self.dropout = dropout
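With the typo gone, constructing the model reads as below (assuming SeqLabeling is in scope). Only hidden_dim, rnn_num_layer, num_classes, vocab_size and the word_emb_dim=100 default appear in this diff; the concrete values are illustrative.

model = SeqLabeling(hidden_dim=128,
                    rnn_num_layer=1,   # formerly misspelled rnn_num_layerd
                    num_classes=10,
                    vocab_size=5000,
                    word_emb_dim=100)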
@@ -55,32 +55,26 @@ class SeqLabeling(BaseModel):
         self.crf = ContionalRandomField(num_classes)
 
     def forward(self, x):
-        x = self.embedding(x)
-        x, hidden = self.encode(x)
-        x = self.aggregate(x)
-        x = self.decode(x)
-        return x
-
-    def embedding(self, x):
-        return self.Emb(x)
-
-    def encode(self, x):
-        return self.rnn(x)
-
-    def aggregate(self, x):
-        return x
-
-    def decode(self, x):
-        x = self.linear(x)
-        return x
+        """
+        :param x: LongTensor, [batch_size, max_len]
+        :return y: FloatTensor, [batch_size, max_len, num_classes]
+        """
+        x = self.Emb(x)
+        # [batch_size, max_len, word_emb_dim]
+        x, hidden = self.rnn(x)
+        # [batch_size, max_len, hidden_size * direction]
+        y = self.linear(x)
+        # [batch_size, max_len, num_classes]
+        return y
 
     def loss(self, x, y, mask, batch_size, max_len):
         """
         Negative log likelihood loss.
-        :param x:
-        :param y:
-        :param seq_len:
+        :param x: FloatTensor, [batch_size, max_len, tag_size]
+        :param y: LongTensor, [batch_size, max_len]
+        :param mask: ByteTensor, [batch_size, max_len]
+        :param batch_size: int
+        :param max_len: int
         :return loss:
                 prediction:
         """
@@ -82,7 +82,7 @@ class ContionalRandomField(nn.Module):
     def _glod_score(self, feats, tags, masks):
         """
         Compute the score for the gold path.
-        :param feats: FloatTensor, batch_size x tag_size x tag_size
+        :param feats: FloatTensor, batch_size x max_len x tag_size
         :param tags: LongTensor, batch_size x max_len
         :param masks: ByteTensor, batch_size x max_len
         :return: FloatTensor, batch_size
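For context on what the corrected shape means: the gold-path score sums, per sample, the emission score of each gold tag plus the transition score between consecutive gold tags, skipping padded positions. A schematic version, not the repo's code (the transition matrix is passed explicitly here, while the real module stores it as a parameter):

import torch

def gold_score(feats, tags, masks, transitions):
    # feats: [B, L, T] emissions, tags: [B, L], masks: [B, L], transitions: [T, T]
    emit = feats.gather(2, tags.unsqueeze(2)).squeeze(2)   # [B, L] gold emissions
    trans = transitions[tags[:, :-1], tags[:, 1:]]         # [B, L-1] gold transitions
    m = masks.float()
    return (emit * m).sum(1) + (trans * m[:, 1:]).sum(1)   # [B]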
@@ -118,7 +118,7 @@ class ContionalRandomField(nn.Module):
     def forward(self, feats, tags, masks):
         """
         Calculate the neg log likelihood
-        :param feats: FloatTensor, batch_size x tag_size x tag_size
+        :param feats: FloatTensor, batch_size x max_len x tag_size
         :param tags: LongTensor, batch_size x max_len
         :param masks: ByteTensor, batch_size x max_len
         :return: FloatTensor, batch_size
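Tying the two docstrings together: the negative log likelihood this forward returns is the log partition function (the forward algorithm's sum over all tag paths) minus the gold-path score above. A schematic of the normalizer, under the same assumptions as the previous sketch:

import torch

def log_partition(feats, masks, transitions):
    # feats: [B, L, T], masks: [B, L], transitions: [T, T] -> [B]
    batch_size, max_len, tag_size = feats.size()
    alpha = feats[:, 0]                                   # [B, T] scores at step 0
    for t in range(1, max_len):
        # scores[b, i, j] = alpha[b, i] + transitions[i, j] + feats[b, t, j]
        scores = (alpha.unsqueeze(2) + transitions.unsqueeze(0)
                  + feats[:, t].unsqueeze(1))
        new_alpha = torch.logsumexp(scores, dim=1)        # [B, T]
        m = masks[:, t].float().unsqueeze(1)              # [B, 1]
        alpha = m * new_alpha + (1 - m) * alpha           # freeze padded steps
    return torch.logsumexp(alpha, dim=1)                  # [B]

# Per-sample neg log likelihood, matching the docstring's [batch_size] return:
#   nll = log_partition(feats, masks, transitions) - gold_score(feats, tags, masks, transitions)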