@@ -98,7 +98,7 @@ class SeqLabelField(Field): | |||||
super(SeqLabelField, self).__init__(is_target) | super(SeqLabelField, self).__init__(is_target) | ||||
self.label_seq = label_seq | self.label_seq = label_seq | ||||
self._index = None | self._index = None | ||||
def get_length(self): | def get_length(self): | ||||
return len(self.label_seq) | return len(self.label_seq) | ||||
@@ -111,7 +111,7 @@ class SeqLabelField(Field): | |||||
pads = [0] * (padding_length - self.get_length()) | pads = [0] * (padding_length - self.get_length()) | ||||
if self._index is None: | if self._index is None: | ||||
if self.get_length() == 0: | if self.get_length() == 0: | ||||
return pads | |||||
return torch.LongTensor(pads) | |||||
elif isinstance(self.label_seq[0], int): | elif isinstance(self.label_seq[0], int): | ||||
return torch.LongTensor(self.label_seq + pads) | return torch.LongTensor(self.label_seq + pads) | ||||
elif isinstance(self.label_seq[0], str): | elif isinstance(self.label_seq[0], str): | ||||
@@ -8,7 +8,7 @@ from fastNLP.loader.base_loader import BaseLoader | |||||
class ConfigLoader(BaseLoader): | class ConfigLoader(BaseLoader): | ||||
"""loader for configuration files""" | """loader for configuration files""" | ||||
def __int__(self, data_path): | |||||
def __init__(self, data_path): | |||||
super(ConfigLoader, self).__init__() | super(ConfigLoader, self).__init__() | ||||
self.config = self.parse(super(ConfigLoader, self).load(data_path)) | self.config = self.parse(super(ConfigLoader, self).load(data_path)) | ||||
@@ -0,0 +1,364 @@ | |||||
import sys, os | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
import copy | |||||
import numpy as np | |||||
import torch | |||||
from collections import defaultdict | |||||
from torch import nn | |||||
from torch.nn import functional as F | |||||
from fastNLP.modules.utils import initial_parameter | |||||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||||
from fastNLP.modules.dropout import TimestepDropout | |||||
def mst(scores): | |||||
""" | |||||
with some modification to support parser output for MST decoding | |||||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | |||||
""" | |||||
length = scores.shape[0] | |||||
min_score = -np.inf | |||||
mask = np.zeros((length, length)) | |||||
np.fill_diagonal(mask, -np.inf) | |||||
scores = scores + mask | |||||
heads = np.argmax(scores, axis=1) | |||||
heads[0] = 0 | |||||
tokens = np.arange(1, length) | |||||
roots = np.where(heads[tokens] == 0)[0] + 1 | |||||
if len(roots) < 1: | |||||
root_scores = scores[tokens, 0] | |||||
head_scores = scores[tokens, heads[tokens]] | |||||
new_root = tokens[np.argmax(root_scores / head_scores)] | |||||
heads[new_root] = 0 | |||||
elif len(roots) > 1: | |||||
root_scores = scores[roots, 0] | |||||
scores[roots, 0] = 0 | |||||
new_heads = np.argmax(scores[roots][:, tokens], axis=1) + 1 | |||||
new_root = roots[np.argmin( | |||||
scores[roots, new_heads] / root_scores)] | |||||
heads[roots] = new_heads | |||||
heads[new_root] = 0 | |||||
edges = defaultdict(set) | |||||
vertices = set((0,)) | |||||
for dep, head in enumerate(heads[tokens]): | |||||
vertices.add(dep + 1) | |||||
edges[head].add(dep + 1) | |||||
for cycle in _find_cycle(vertices, edges): | |||||
dependents = set() | |||||
to_visit = set(cycle) | |||||
while len(to_visit) > 0: | |||||
node = to_visit.pop() | |||||
if node not in dependents: | |||||
dependents.add(node) | |||||
to_visit.update(edges[node]) | |||||
cycle = np.array(list(cycle)) | |||||
old_heads = heads[cycle] | |||||
old_scores = scores[cycle, old_heads] | |||||
non_heads = np.array(list(dependents)) | |||||
scores[np.repeat(cycle, len(non_heads)), | |||||
np.repeat([non_heads], len(cycle), axis=0).flatten()] = min_score | |||||
new_heads = np.argmax(scores[cycle][:, tokens], axis=1) + 1 | |||||
new_scores = scores[cycle, new_heads] / old_scores | |||||
change = np.argmax(new_scores) | |||||
changed_cycle = cycle[change] | |||||
old_head = old_heads[change] | |||||
new_head = new_heads[change] | |||||
heads[changed_cycle] = new_head | |||||
edges[new_head].add(changed_cycle) | |||||
edges[old_head].remove(changed_cycle) | |||||
return heads | |||||
def _find_cycle(vertices, edges): | |||||
""" | |||||
https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm | |||||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py | |||||
""" | |||||
_index = 0 | |||||
_stack = [] | |||||
_indices = {} | |||||
_lowlinks = {} | |||||
_onstack = defaultdict(lambda: False) | |||||
_SCCs = [] | |||||
def _strongconnect(v): | |||||
nonlocal _index | |||||
_indices[v] = _index | |||||
_lowlinks[v] = _index | |||||
_index += 1 | |||||
_stack.append(v) | |||||
_onstack[v] = True | |||||
for w in edges[v]: | |||||
if w not in _indices: | |||||
_strongconnect(w) | |||||
_lowlinks[v] = min(_lowlinks[v], _lowlinks[w]) | |||||
elif _onstack[w]: | |||||
_lowlinks[v] = min(_lowlinks[v], _indices[w]) | |||||
if _lowlinks[v] == _indices[v]: | |||||
SCC = set() | |||||
while True: | |||||
w = _stack.pop() | |||||
_onstack[w] = False | |||||
SCC.add(w) | |||||
if not(w != v): | |||||
break | |||||
_SCCs.append(SCC) | |||||
for v in vertices: | |||||
if v not in _indices: | |||||
_strongconnect(v) | |||||
return [SCC for SCC in _SCCs if len(SCC) > 1] | |||||
class GraphParser(nn.Module): | |||||
"""Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding | |||||
""" | |||||
def __init__(self): | |||||
super(GraphParser, self).__init__() | |||||
def forward(self, x): | |||||
raise NotImplementedError | |||||
def _greedy_decoder(self, arc_matrix, seq_mask=None): | |||||
_, seq_len, _ = arc_matrix.shape | |||||
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) | |||||
_, heads = torch.max(matrix, dim=2) | |||||
if seq_mask is not None: | |||||
heads *= seq_mask.long() | |||||
return heads | |||||
def _mst_decoder(self, arc_matrix, seq_mask=None): | |||||
batch_size, seq_len, _ = arc_matrix.shape | |||||
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) | |||||
ans = matrix.new_zeros(batch_size, seq_len).long() | |||||
for i, graph in enumerate(matrix): | |||||
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) | |||||
if seq_mask is not None: | |||||
ans *= seq_mask.long() | |||||
return ans | |||||
class ArcBiaffine(nn.Module): | |||||
"""helper module for Biaffine Dependency Parser predicting arc | |||||
""" | |||||
def __init__(self, hidden_size, bias=True): | |||||
super(ArcBiaffine, self).__init__() | |||||
self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=True) | |||||
self.has_bias = bias | |||||
if self.has_bias: | |||||
self.bias = nn.Parameter(torch.Tensor(hidden_size), requires_grad=True) | |||||
else: | |||||
self.register_parameter("bias", None) | |||||
initial_parameter(self) | |||||
def forward(self, head, dep): | |||||
""" | |||||
:param head arc-head tensor = [batch, length, emb_dim] | |||||
:param dep arc-dependent tensor = [batch, length, emb_dim] | |||||
:return output tensor = [bacth, length, length] | |||||
""" | |||||
output = dep.matmul(self.U) | |||||
output = output.bmm(head.transpose(-1, -2)) | |||||
if self.has_bias: | |||||
output += head.matmul(self.bias).unsqueeze(1) | |||||
return output | |||||
class LabelBilinear(nn.Module): | |||||
"""helper module for Biaffine Dependency Parser predicting label | |||||
""" | |||||
def __init__(self, in1_features, in2_features, num_label, bias=True): | |||||
super(LabelBilinear, self).__init__() | |||||
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) | |||||
self.lin1 = nn.Linear(in1_features, num_label, bias=False) | |||||
self.lin2 = nn.Linear(in2_features, num_label, bias=False) | |||||
def forward(self, x1, x2): | |||||
output = self.bilinear(x1, x2) | |||||
output += self.lin1(x1) + self.lin2(x2) | |||||
return output | |||||
class BiaffineParser(GraphParser): | |||||
"""Biaffine Dependency Parser implemantation. | |||||
refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) | |||||
<https://arxiv.org/abs/1611.01734>`_ . | |||||
""" | |||||
def __init__(self, | |||||
word_vocab_size, | |||||
word_emb_dim, | |||||
pos_vocab_size, | |||||
pos_emb_dim, | |||||
rnn_layers, | |||||
rnn_hidden_size, | |||||
arc_mlp_size, | |||||
label_mlp_size, | |||||
num_label, | |||||
dropout, | |||||
use_var_lstm=False, | |||||
use_greedy_infer=False): | |||||
super(BiaffineParser, self).__init__() | |||||
self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) | |||||
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | |||||
if use_var_lstm: | |||||
self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
input_dropout=dropout, | |||||
hidden_dropout=dropout, | |||||
bidirectional=True) | |||||
else: | |||||
self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
dropout=dropout, | |||||
bidirectional=True) | |||||
rnn_out_size = 2 * rnn_hidden_size | |||||
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), | |||||
nn.ELU()) | |||||
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) | |||||
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), | |||||
nn.ELU()) | |||||
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | |||||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | |||||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | |||||
self.normal_dropout = nn.Dropout(p=dropout) | |||||
self.timestep_dropout = TimestepDropout(p=dropout) | |||||
self.use_greedy_infer = use_greedy_infer | |||||
initial_parameter(self) | |||||
def forward(self, word_seq, pos_seq, seq_mask, gold_heads=None, **_): | |||||
""" | |||||
:param word_seq: [batch_size, seq_len] sequence of word's indices | |||||
:param pos_seq: [batch_size, seq_len] sequence of word's indices | |||||
:param seq_mask: [batch_size, seq_len] sequence of length masks | |||||
:param gold_heads: [batch_size, seq_len] sequence of golden heads | |||||
:return dict: parsing results | |||||
arc_pred: [batch_size, seq_len, seq_len] | |||||
label_pred: [batch_size, seq_len, seq_len] | |||||
seq_mask: [batch_size, seq_len] | |||||
head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads | |||||
""" | |||||
# prepare embeddings | |||||
batch_size, seq_len = word_seq.shape | |||||
# print('forward {} {}'.format(batch_size, seq_len)) | |||||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) | |||||
# get sequence mask | |||||
seq_mask = seq_mask.long() | |||||
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] | |||||
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] | |||||
x = torch.cat([word, pos], dim=2) # -> [N,L,C] | |||||
# lstm, extract features | |||||
feat, _ = self.lstm(x) # -> [N,L,C] | |||||
# for arc biaffine | |||||
# mlp, reduce dim | |||||
arc_dep = self.timestep_dropout(self.arc_dep_mlp(feat)) | |||||
arc_head = self.timestep_dropout(self.arc_head_mlp(feat)) | |||||
label_dep = self.timestep_dropout(self.label_dep_mlp(feat)) | |||||
label_head = self.timestep_dropout(self.label_head_mlp(feat)) | |||||
# biaffine arc classifier | |||||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | |||||
flip_mask = (seq_mask == 0) | |||||
arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||||
# use gold or predicted arc to predict label | |||||
if gold_heads is None: | |||||
# use greedy decoding in training | |||||
if self.training or self.use_greedy_infer: | |||||
heads = self._greedy_decoder(arc_pred, seq_mask) | |||||
else: | |||||
heads = self._mst_decoder(arc_pred, seq_mask) | |||||
head_pred = heads | |||||
else: | |||||
head_pred = None | |||||
heads = gold_heads | |||||
label_head = label_head[batch_range, heads].contiguous() | |||||
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] | |||||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask} | |||||
if head_pred is not None: | |||||
res_dict['head_pred'] = head_pred | |||||
return res_dict | |||||
def loss(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **_): | |||||
""" | |||||
Compute loss. | |||||
:param arc_pred: [batch_size, seq_len, seq_len] | |||||
:param label_pred: [batch_size, seq_len, seq_len] | |||||
:param head_indices: [batch_size, seq_len] | |||||
:param head_labels: [batch_size, seq_len] | |||||
:param seq_mask: [batch_size, seq_len] | |||||
:return: loss value | |||||
""" | |||||
batch_size, seq_len, _ = arc_pred.shape | |||||
arc_logits = F.log_softmax(arc_pred, dim=2) | |||||
label_logits = F.log_softmax(label_pred, dim=2) | |||||
batch_index = torch.arange(start=0, end=batch_size, device=arc_logits.device).long().unsqueeze(1) | |||||
child_index = torch.arange(start=0, end=seq_len, device=arc_logits.device).long().unsqueeze(0) | |||||
arc_loss = arc_logits[batch_index, child_index, head_indices] | |||||
label_loss = label_logits[batch_index, child_index, head_labels] | |||||
arc_loss = arc_loss[:, 1:] | |||||
label_loss = label_loss[:, 1:] | |||||
float_mask = seq_mask[:, 1:].float() | |||||
length = (seq_mask.sum() - batch_size).float() | |||||
arc_nll = -(arc_loss*float_mask).sum() / length | |||||
label_nll = -(label_loss*float_mask).sum() / length | |||||
return arc_nll + label_nll | |||||
def evaluate(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **kwargs): | |||||
""" | |||||
Evaluate the performance of prediction. | |||||
:return dict: performance results. | |||||
head_pred_corrct: number of correct predicted heads. | |||||
label_pred_correct: number of correct predicted labels. | |||||
total_tokens: number of predicted tokens | |||||
""" | |||||
if 'head_pred' in kwargs: | |||||
head_pred = kwargs['head_pred'] | |||||
elif self.use_greedy_infer: | |||||
head_pred = self._greedy_decoder(arc_pred, seq_mask) | |||||
else: | |||||
head_pred = self._mst_decoder(arc_pred, seq_mask) | |||||
head_pred_correct = (head_pred == head_indices).long() * seq_mask | |||||
_, label_preds = torch.max(label_pred, dim=2) | |||||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct | |||||
return {"head_pred_correct": head_pred_correct.sum(dim=1), | |||||
"label_pred_correct": label_pred_correct.sum(dim=1), | |||||
"total_tokens": seq_mask.sum(dim=1)} | |||||
def metrics(self, head_pred_correct, label_pred_correct, total_tokens, **_): | |||||
""" | |||||
Compute the metrics of model | |||||
:param head_pred_corrct: number of correct predicted heads. | |||||
:param label_pred_correct: number of correct predicted labels. | |||||
:param total_tokens: number of predicted tokens | |||||
:return dict: the metrics results | |||||
UAS: the head predicted accuracy | |||||
LAS: the label predicted accuracy | |||||
""" | |||||
return {"UAS": head_pred_correct.sum().float() / total_tokens.sum().float() * 100, | |||||
"LAS": label_pred_correct.sum().float() / total_tokens.sum().float() * 100} | |||||
@@ -0,0 +1,15 @@ | |||||
import torch | |||||
class TimestepDropout(torch.nn.Dropout): | |||||
"""This module accepts a `[batch_size, num_timesteps, embedding_dim)]` and use a single | |||||
dropout mask of shape `(batch_size, embedding_dim)` to apply on every time step. | |||||
""" | |||||
def forward(self, x): | |||||
dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) | |||||
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) | |||||
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim] | |||||
if self.inplace: | |||||
x *= dropout_mask | |||||
return | |||||
else: | |||||
return x * dropout_mask |
@@ -2,391 +2,14 @@ import math | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch.nn.functional as F | |||||
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend | |||||
from torch.nn.parameter import Parameter | |||||
from torch.nn.utils.rnn import PackedSequence | from torch.nn.utils.rnn import PackedSequence | ||||
# from fastNLP.modules.utils import initial_parameter | |||||
def default_initializer(hidden_size): | |||||
stdv = 1.0 / math.sqrt(hidden_size) | |||||
def forward(tensor): | |||||
nn.init.uniform_(tensor, -stdv, stdv) | |||||
return forward | |||||
def VarMaskedRecurrent(reverse=False): | |||||
def forward(input, hidden, cell, mask): | |||||
output = [] | |||||
steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0)) | |||||
for i in steps: | |||||
if mask is None or mask[i].data.min() > 0.5: | |||||
hidden = cell(input[i], hidden) | |||||
elif mask[i].data.max() > 0.5: | |||||
hidden_next = cell(input[i], hidden) | |||||
# hack to handle LSTM | |||||
if isinstance(hidden, tuple): | |||||
hx, cx = hidden | |||||
hp1, cp1 = hidden_next | |||||
hidden = (hx + (hp1 - hx) * mask[i], cx + (cp1 - cx) * mask[i]) | |||||
else: | |||||
hidden = hidden + (hidden_next - hidden) * mask[i] | |||||
# hack to handle LSTM | |||||
output.append(hidden[0] if isinstance(hidden, tuple) else hidden) | |||||
if reverse: | |||||
output.reverse() | |||||
output = torch.cat(output, 0).view(input.size(0), *output[0].size()) | |||||
return hidden, output | |||||
return forward | |||||
def StackedRNN(inners, num_layers, lstm=False): | |||||
num_directions = len(inners) | |||||
total_layers = num_layers * num_directions | |||||
def forward(input, hidden, cells, mask): | |||||
assert (len(cells) == total_layers) | |||||
next_hidden = [] | |||||
if lstm: | |||||
hidden = list(zip(*hidden)) | |||||
for i in range(num_layers): | |||||
all_output = [] | |||||
for j, inner in enumerate(inners): | |||||
l = i * num_directions + j | |||||
hy, output = inner(input, hidden[l], cells[l], mask) | |||||
next_hidden.append(hy) | |||||
all_output.append(output) | |||||
input = torch.cat(all_output, input.dim() - 1) | |||||
if lstm: | |||||
next_h, next_c = zip(*next_hidden) | |||||
next_hidden = ( | |||||
torch.cat(next_h, 0).view(total_layers, *next_h[0].size()), | |||||
torch.cat(next_c, 0).view(total_layers, *next_c[0].size()) | |||||
) | |||||
else: | |||||
next_hidden = torch.cat(next_hidden, 0).view(total_layers, *next_hidden[0].size()) | |||||
return next_hidden, input | |||||
return forward | |||||
def AutogradVarMaskedRNN(num_layers=1, batch_first=False, bidirectional=False, lstm=False): | |||||
rec_factory = VarMaskedRecurrent | |||||
if bidirectional: | |||||
layer = (rec_factory(), rec_factory(reverse=True)) | |||||
else: | |||||
layer = (rec_factory(),) | |||||
func = StackedRNN(layer, | |||||
num_layers, | |||||
lstm=lstm) | |||||
def forward(input, cells, hidden, mask): | |||||
if batch_first: | |||||
input = input.transpose(0, 1) | |||||
if mask is not None: | |||||
mask = mask.transpose(0, 1) | |||||
nexth, output = func(input, hidden, cells, mask) | |||||
if batch_first: | |||||
output = output.transpose(0, 1) | |||||
return output, nexth | |||||
return forward | |||||
def VarMaskedStep(): | |||||
def forward(input, hidden, cell, mask): | |||||
if mask is None or mask.data.min() > 0.5: | |||||
hidden = cell(input, hidden) | |||||
elif mask.data.max() > 0.5: | |||||
hidden_next = cell(input, hidden) | |||||
# hack to handle LSTM | |||||
if isinstance(hidden, tuple): | |||||
hx, cx = hidden | |||||
hp1, cp1 = hidden_next | |||||
hidden = (hx + (hp1 - hx) * mask, cx + (cp1 - cx) * mask) | |||||
else: | |||||
hidden = hidden + (hidden_next - hidden) * mask | |||||
# hack to handle LSTM | |||||
output = hidden[0] if isinstance(hidden, tuple) else hidden | |||||
return hidden, output | |||||
return forward | |||||
def StackedStep(layer, num_layers, lstm=False): | |||||
def forward(input, hidden, cells, mask): | |||||
assert (len(cells) == num_layers) | |||||
next_hidden = [] | |||||
if lstm: | |||||
hidden = list(zip(*hidden)) | |||||
for l in range(num_layers): | |||||
hy, output = layer(input, hidden[l], cells[l], mask) | |||||
next_hidden.append(hy) | |||||
input = output | |||||
if lstm: | |||||
next_h, next_c = zip(*next_hidden) | |||||
next_hidden = ( | |||||
torch.cat(next_h, 0).view(num_layers, *next_h[0].size()), | |||||
torch.cat(next_c, 0).view(num_layers, *next_c[0].size()) | |||||
) | |||||
else: | |||||
next_hidden = torch.cat(next_hidden, 0).view(num_layers, *next_hidden[0].size()) | |||||
return next_hidden, input | |||||
return forward | |||||
def AutogradVarMaskedStep(num_layers=1, lstm=False): | |||||
layer = VarMaskedStep() | |||||
func = StackedStep(layer, | |||||
num_layers, | |||||
lstm=lstm) | |||||
def forward(input, cells, hidden, mask): | |||||
nexth, output = func(input, hidden, cells, mask) | |||||
return output, nexth | |||||
return forward | |||||
class VarMaskedRNNBase(nn.Module): | |||||
def __init__(self, Cell, input_size, hidden_size, | |||||
num_layers=1, bias=True, batch_first=False, | |||||
dropout=(0, 0), bidirectional=False, initializer=None,initial_method = None, **kwargs): | |||||
super(VarMaskedRNNBase, self).__init__() | |||||
self.Cell = Cell | |||||
self.input_size = input_size | |||||
self.hidden_size = hidden_size | |||||
self.num_layers = num_layers | |||||
self.bias = bias | |||||
self.batch_first = batch_first | |||||
self.bidirectional = bidirectional | |||||
self.lstm = False | |||||
num_directions = 2 if bidirectional else 1 | |||||
self.all_cells = [] | |||||
for layer in range(num_layers): | |||||
for direction in range(num_directions): | |||||
layer_input_size = input_size if layer == 0 else hidden_size * num_directions | |||||
cell = self.Cell(layer_input_size, hidden_size, self.bias, p=dropout, initializer=initializer, **kwargs) | |||||
self.all_cells.append(cell) | |||||
self.add_module('cell%d' % (layer * num_directions + direction), cell) | |||||
initial_parameter(self, initial_method) | |||||
def reset_parameters(self): | |||||
for cell in self.all_cells: | |||||
cell.reset_parameters() | |||||
def reset_noise(self, batch_size): | |||||
for cell in self.all_cells: | |||||
cell.reset_noise(batch_size) | |||||
def forward(self, input, mask=None, hx=None): | |||||
batch_size = input.size(0) if self.batch_first else input.size(1) | |||||
if hx is None: | |||||
num_directions = 2 if self.bidirectional else 1 | |||||
hx = torch.tensor(input.data.new(self.num_layers * num_directions, batch_size, self.hidden_size).zero_(), | |||||
requires_grad=True) | |||||
if self.lstm: | |||||
hx = (hx, hx) | |||||
func = AutogradVarMaskedRNN(num_layers=self.num_layers, | |||||
batch_first=self.batch_first, | |||||
bidirectional=self.bidirectional, | |||||
lstm=self.lstm) | |||||
self.reset_noise(batch_size) | |||||
output, hidden = func(input, self.all_cells, hx, None if mask is None else mask.view(mask.size() + (1,))) | |||||
return output, hidden | |||||
def step(self, input, hx=None, mask=None): | |||||
''' | |||||
execute one step forward (only for one-directional RNN). | |||||
Args: | |||||
input (batch, input_size): input tensor of this step. | |||||
hx (num_layers, batch, hidden_size): the hidden state of last step. | |||||
mask (batch): the mask tensor of this step. | |||||
Returns: | |||||
output (batch, hidden_size): tensor containing the output of this step from the last layer of RNN. | |||||
hn (num_layers, batch, hidden_size): tensor containing the hidden state of this step | |||||
''' | |||||
assert not self.bidirectional, "step only cannot be applied to bidirectional RNN." | |||||
batch_size = input.size(0) | |||||
if hx is None: | |||||
hx = torch.tensor(input.data.new(self.num_layers, batch_size, self.hidden_size).zero_(), requires_grad=True) | |||||
if self.lstm: | |||||
hx = (hx, hx) | |||||
func = AutogradVarMaskedStep(num_layers=self.num_layers, lstm=self.lstm) | |||||
output, hidden = func(input, self.all_cells, hx, mask) | |||||
return output, hidden | |||||
class VarMaskedFastLSTM(VarMaskedRNNBase): | |||||
def __init__(self, *args, **kwargs): | |||||
super(VarMaskedFastLSTM, self).__init__(VarFastLSTMCell, *args, **kwargs) | |||||
self.lstm = True | |||||
class VarRNNCellBase(nn.Module): | |||||
def __repr__(self): | |||||
s = '{name}({input_size}, {hidden_size}' | |||||
if 'bias' in self.__dict__ and self.bias is not True: | |||||
s += ', bias={bias}' | |||||
if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh": | |||||
s += ', nonlinearity={nonlinearity}' | |||||
s += ')' | |||||
return s.format(name=self.__class__.__name__, **self.__dict__) | |||||
def reset_noise(self, batch_size): | |||||
""" | |||||
Should be overriden by all subclasses. | |||||
Args: | |||||
batch_size: (int) batch size of input. | |||||
""" | |||||
raise NotImplementedError | |||||
class VarFastLSTMCell(VarRNNCellBase): | |||||
""" | |||||
A long short-term memory (LSTM) cell with variational dropout. | |||||
.. math:: | |||||
\begin{array}{ll} | |||||
i = \mathrm{sigmoid}(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ | |||||
f = \mathrm{sigmoid}(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ | |||||
g = \tanh(W_{ig} x + b_{ig} + W_{hc} h + b_{hg}) \\ | |||||
o = \mathrm{sigmoid}(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ | |||||
c' = f * c + i * g \\ | |||||
h' = o * \tanh(c') \\ | |||||
\end{array} | |||||
""" | |||||
def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None,initial_method =None): | |||||
super(VarFastLSTMCell, self).__init__() | |||||
self.input_size = input_size | |||||
self.hidden_size = hidden_size | |||||
self.bias = bias | |||||
self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size)) | |||||
self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size)) | |||||
if bias: | |||||
self.bias_ih = Parameter(torch.Tensor(4 * hidden_size)) | |||||
self.bias_hh = Parameter(torch.Tensor(4 * hidden_size)) | |||||
else: | |||||
self.register_parameter('bias_ih', None) | |||||
self.register_parameter('bias_hh', None) | |||||
self.initializer = default_initializer(self.hidden_size) if initializer is None else initializer | |||||
self.reset_parameters() | |||||
p_in, p_hidden = p | |||||
if p_in < 0 or p_in > 1: | |||||
raise ValueError("input dropout probability has to be between 0 and 1, " | |||||
"but got {}".format(p_in)) | |||||
if p_hidden < 0 or p_hidden > 1: | |||||
raise ValueError("hidden state dropout probability has to be between 0 and 1, " | |||||
"but got {}".format(p_hidden)) | |||||
self.p_in = p_in | |||||
self.p_hidden = p_hidden | |||||
self.noise_in = None | |||||
self.noise_hidden = None | |||||
initial_parameter(self, initial_method) | |||||
def reset_parameters(self): | |||||
for weight in self.parameters(): | |||||
if weight.dim() == 1: | |||||
weight.data.zero_() | |||||
else: | |||||
self.initializer(weight.data) | |||||
def reset_noise(self, batch_size): | |||||
if self.training: | |||||
if self.p_in: | |||||
noise = self.weight_ih.data.new(batch_size, self.input_size) | |||||
self.noise_in = torch.tensor(noise.bernoulli_(1.0 - self.p_in) / (1.0 - self.p_in)) | |||||
else: | |||||
self.noise_in = None | |||||
if self.p_hidden: | |||||
noise = self.weight_hh.data.new(batch_size, self.hidden_size) | |||||
self.noise_hidden = torch.tensor(noise.bernoulli_(1.0 - self.p_hidden) / (1.0 - self.p_hidden)) | |||||
else: | |||||
self.noise_hidden = None | |||||
else: | |||||
self.noise_in = None | |||||
self.noise_hidden = None | |||||
def forward(self, input, hx): | |||||
return self.__forward( | |||||
input, hx, | |||||
self.weight_ih, self.weight_hh, | |||||
self.bias_ih, self.bias_hh, | |||||
self.noise_in, self.noise_hidden, | |||||
) | |||||
@staticmethod | |||||
def __forward(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None): | |||||
if noise_in is not None: | |||||
if input.is_cuda: | |||||
input = input * noise_in.cuda(input.get_device()) | |||||
else: | |||||
input = input * noise_in | |||||
if input.is_cuda: | |||||
w_ih = w_ih.cuda(input.get_device()) | |||||
w_hh = w_hh.cuda(input.get_device()) | |||||
hidden = [h.cuda(input.get_device()) for h in hidden] | |||||
b_ih = b_ih.cuda(input.get_device()) | |||||
b_hh = b_hh.cuda(input.get_device()) | |||||
igates = F.linear(input, w_ih.cuda(input.get_device())) | |||||
hgates = F.linear(hidden[0], w_hh) if noise_hidden is None \ | |||||
else F.linear(hidden[0] * noise_hidden.cuda(input.get_device()), w_hh) | |||||
state = fusedBackend.LSTMFused.apply | |||||
# print("use backend") | |||||
# use some magic function | |||||
return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh) | |||||
hx, cx = hidden | |||||
if noise_hidden is not None: | |||||
hx = hx * noise_hidden | |||||
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh) | |||||
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) | |||||
ingate = F.sigmoid(ingate) | |||||
forgetgate = F.sigmoid(forgetgate) | |||||
cellgate = F.tanh(cellgate) | |||||
outgate = F.sigmoid(outgate) | |||||
cy = (forgetgate * cx) + (ingate * cellgate) | |||||
hy = outgate * F.tanh(cy) | |||||
return hy, cy | |||||
from fastNLP.modules.utils import initial_parameter | |||||
class VarRnnCellWrapper(nn.Module): | class VarRnnCellWrapper(nn.Module): | ||||
"""Wrapper for normal RNN Cells, make it support variational dropout | |||||
""" | |||||
def __init__(self, cell, hidden_size, input_p, hidden_p): | def __init__(self, cell, hidden_size, input_p, hidden_p): | ||||
super(VarRnnCellWrapper, self).__init__() | super(VarRnnCellWrapper, self).__init__() | ||||
self.cell = cell | self.cell = cell | ||||
@@ -394,31 +17,26 @@ class VarRnnCellWrapper(nn.Module): | |||||
self.input_p = input_p | self.input_p = input_p | ||||
self.hidden_p = hidden_p | self.hidden_p = hidden_p | ||||
def forward(self, input, hidden): | |||||
def forward(self, input, hidden, mask_x=None, mask_h=None): | |||||
""" | """ | ||||
:param input: [seq_len, batch_size, input_size] | :param input: [seq_len, batch_size, input_size] | ||||
:param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | :param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | ||||
for other RNN, h_0, [batch_size, hidden_size] | for other RNN, h_0, [batch_size, hidden_size] | ||||
:param mask_x: [batch_size, input_size] dropout mask for input | |||||
:param mask_h: [batch_size, hidden_size] dropout mask for hidden | |||||
:return output: [seq_len, bacth_size, hidden_size] | :return output: [seq_len, bacth_size, hidden_size] | ||||
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | ||||
for other RNN, h_n, [batch_size, hidden_size] | for other RNN, h_n, [batch_size, hidden_size] | ||||
""" | """ | ||||
is_lstm = isinstance(hidden, tuple) | is_lstm = isinstance(hidden, tuple) | ||||
_, batch_size, input_size = input.shape | |||||
mask_x = input.new_ones((batch_size, input_size)) | |||||
mask_h = input.new_ones((batch_size, self.hidden_size)) | |||||
nn.functional.dropout(mask_x, p=self.input_p, training=self.training, inplace=True) | |||||
nn.functional.dropout(mask_h, p=self.hidden_p, training=self.training, inplace=True) | |||||
input_x = input * mask_x.unsqueeze(0) | |||||
input = input * mask_x.unsqueeze(0) if mask_x is not None else input | |||||
output_list = [] | output_list = [] | ||||
for x in input_x: | |||||
for x in input: | |||||
if is_lstm: | if is_lstm: | ||||
hx, cx = hidden | hx, cx = hidden | ||||
hidden = (hx * mask_h, cx) | |||||
hidden = (hx * mask_h, cx) if mask_h is not None else (hx, cx) | |||||
else: | else: | ||||
hidden *= mask_h | |||||
hidden *= mask_h if mask_h is not None else hidden | |||||
hidden = self.cell(x, hidden) | hidden = self.cell(x, hidden) | ||||
output_list.append(hidden[0] if is_lstm else hidden) | output_list.append(hidden[0] if is_lstm else hidden) | ||||
output = torch.stack(output_list, dim=0) | output = torch.stack(output_list, dim=0) | ||||
@@ -426,6 +44,10 @@ class VarRnnCellWrapper(nn.Module): | |||||
class VarRNNBase(nn.Module): | class VarRNNBase(nn.Module): | ||||
"""Implementation of Variational Dropout RNN network. | |||||
refer to `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) | |||||
https://arxiv.org/abs/1512.05287`. | |||||
""" | |||||
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | ||||
bias=True, batch_first=False, | bias=True, batch_first=False, | ||||
input_dropout=0, hidden_dropout=0, bidirectional=False): | input_dropout=0, hidden_dropout=0, bidirectional=False): | ||||
@@ -446,6 +68,7 @@ class VarRNNBase(nn.Module): | |||||
input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions | input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions | ||||
cell = Cell(input_size, self.hidden_size, bias) | cell = Cell(input_size, self.hidden_size, bias) | ||||
self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout)) | self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout)) | ||||
initial_parameter(self) | |||||
def forward(self, input, hx=None): | def forward(self, input, hx=None): | ||||
is_packed = isinstance(input, PackedSequence) | is_packed = isinstance(input, PackedSequence) | ||||
@@ -466,6 +89,14 @@ class VarRNNBase(nn.Module): | |||||
if self.batch_first: | if self.batch_first: | ||||
input = input.transpose(0, 1) | input = input.transpose(0, 1) | ||||
batch_size = input.shape[1] | |||||
mask_x = input.new_ones((batch_size, self.input_size)) | |||||
mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions)) | |||||
mask_h = input.new_ones((batch_size, self.hidden_size)) | |||||
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | |||||
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | |||||
nn.functional.dropout(mask_h, p=self.hidden_dropout, training=self.training, inplace=True) | |||||
hidden_list = [] | hidden_list = [] | ||||
for layer in range(self.num_layers): | for layer in range(self.num_layers): | ||||
@@ -474,11 +105,13 @@ class VarRNNBase(nn.Module): | |||||
input_x = input if direction == 0 else input.flip(0) | input_x = input if direction == 0 else input.flip(0) | ||||
idx = self.num_directions * layer + direction | idx = self.num_directions * layer + direction | ||||
cell = self._all_cells[idx] | cell = self._all_cells[idx] | ||||
output_x, hidden_x = cell(input_x, (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx]) | |||||
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] | |||||
mask_xi = mask_x if layer == 0 else mask_out | |||||
output_x, hidden_x = cell(input_x, hi, mask_xi, mask_h) | |||||
output_list.append(output_x if direction == 0 else output_x.flip(0)) | output_list.append(output_x if direction == 0 else output_x.flip(0)) | ||||
hidden_list.append(hidden_x) | hidden_list.append(hidden_x) | ||||
input = torch.cat(output_list, dim=-1) | input = torch.cat(output_list, dim=-1) | ||||
output = input.transpose(0, 1) if self.batch_first else input | output = input.transpose(0, 1) if self.batch_first else input | ||||
if is_lstm: | if is_lstm: | ||||
h_list, c_list = zip(*hidden_list) | h_list, c_list = zip(*hidden_list) | ||||
@@ -487,29 +120,27 @@ class VarRNNBase(nn.Module): | |||||
hidden = (hn, cn) | hidden = (hn, cn) | ||||
else: | else: | ||||
hidden = torch.stack(hidden_list, dim=0) | hidden = torch.stack(hidden_list, dim=0) | ||||
if is_packed: | if is_packed: | ||||
output = PackedSequence(output, batch_sizes) | output = PackedSequence(output, batch_sizes) | ||||
return output, hidden | |||||
return output, hidden | |||||
class VarLSTM(VarRNNBase): | class VarLSTM(VarRNNBase): | ||||
"""Variational Dropout LSTM. | |||||
""" | |||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs) | super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs) | ||||
class VarRNN(VarRNNBase): | |||||
"""Variational Dropout RNN. | |||||
""" | |||||
def __init__(self, *args, **kwargs): | |||||
super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs) | |||||
if __name__ == '__main__': | |||||
net = VarLSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True, bidirectional=True, input_dropout=0.33, hidden_dropout=0.33) | |||||
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True, bidirectional=True) | |||||
x = torch.randn(2, 8, 10) | |||||
y, hidden = net(x) | |||||
y0, h0 = lstm(x) | |||||
print(y.shape) | |||||
print(y0.shape) | |||||
print(y) | |||||
print(hidden[0]) | |||||
print(hidden[0].shape) | |||||
print(y0) | |||||
print(h0[0]) | |||||
print(h0[0].shape) | |||||
class VarGRU(VarRNNBase): | |||||
"""Variational Dropout GRU. | |||||
""" | |||||
def __init__(self, *args, **kwargs): | |||||
super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs) |
@@ -0,0 +1,37 @@ | |||||
[train] | |||||
epochs = 50 | |||||
batch_size = 16 | |||||
pickle_path = "./save/" | |||||
validate = true | |||||
save_best_dev = false | |||||
use_cuda = true | |||||
model_saved_path = "./save/" | |||||
task = "parse" | |||||
[test] | |||||
save_output = true | |||||
validate_in_training = true | |||||
save_dev_input = false | |||||
save_loss = true | |||||
batch_size = 16 | |||||
pickle_path = "./save/" | |||||
use_cuda = true | |||||
task = "parse" | |||||
[model] | |||||
word_vocab_size = -1 | |||||
word_emb_dim = 100 | |||||
pos_vocab_size = -1 | |||||
pos_emb_dim = 100 | |||||
rnn_layers = 3 | |||||
rnn_hidden_size = 400 | |||||
arc_mlp_size = 500 | |||||
label_mlp_size = 100 | |||||
num_label = -1 | |||||
dropout = 0.33 | |||||
use_var_lstm=true | |||||
use_greedy_infer=false | |||||
[optim] | |||||
lr = 2e-3 |
@@ -0,0 +1,260 @@ | |||||
import os | |||||
import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
from collections import defaultdict | |||||
import math | |||||
import torch | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.field import TextField, SeqLabelField | |||||
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | |||||
from fastNLP.core.tester import Tester | |||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.loader.model_loader import ModelLoader | |||||
from fastNLP.loader.embed_loader import EmbedLoader | |||||
from fastNLP.models.biaffine_parser import BiaffineParser | |||||
from fastNLP.saver.model_saver import ModelSaver | |||||
# not in the file's dir | |||||
if len(os.path.dirname(__file__)) != 0: | |||||
os.chdir(os.path.dirname(__file__)) | |||||
class MyDataLoader(object): | |||||
def __init__(self, pickle_path): | |||||
self.pickle_path = pickle_path | |||||
def load(self, path, word_v=None, pos_v=None, headtag_v=None): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet(name='conll') | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if word_v is not None: | |||||
word_v.update(res[0]) | |||||
pos_v.update(res[1]) | |||||
headtag_v.update(res[3]) | |||||
ds.append(Instance(word_seq=TextField(res[0], is_target=False), | |||||
pos_seq=TextField(res[1], is_target=False), | |||||
head_indices=SeqLabelField(res[2], is_target=True), | |||||
head_labels=TextField(res[3], is_target=True), | |||||
seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False))) | |||||
return ds | |||||
def get_one(self, sample): | |||||
text = ['<root>'] | |||||
pos_tags = ['<root>'] | |||||
heads = [0] | |||||
head_tags = ['root'] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
continue | |||||
text.append(t1) | |||||
pos_tags.append(t2) | |||||
heads.append(int(t3)) | |||||
head_tags.append(t4) | |||||
return (text, pos_tags, heads, head_tags) | |||||
def index_data(self, dataset, word_v, pos_v, tag_v): | |||||
dataset.index_field('word_seq', word_v) | |||||
dataset.index_field('pos_seq', pos_v) | |||||
dataset.index_field('head_labels', tag_v) | |||||
# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | |||||
datadir = "/home/yfshao/UD_English-EWT" | |||||
cfgfile = './cfg.cfg' | |||||
train_data_name = "en_ewt-ud-train.conllu" | |||||
dev_data_name = "en_ewt-ud-dev.conllu" | |||||
emb_file_name = '/home/yfshao/glove.6B.100d.txt' | |||||
processed_datadir = './save' | |||||
# Config Loader | |||||
train_args = ConfigSection() | |||||
test_args = ConfigSection() | |||||
model_args = ConfigSection() | |||||
optim_args = ConfigSection() | |||||
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args}) | |||||
# Data Loader | |||||
def save_data(dirpath, **kwargs): | |||||
import _pickle | |||||
if not os.path.exists(dirpath): | |||||
os.mkdir(dirpath) | |||||
for name, data in kwargs.items(): | |||||
with open(os.path.join(dirpath, name+'.pkl'), 'wb') as f: | |||||
_pickle.dump(data, f) | |||||
def load_data(dirpath): | |||||
import _pickle | |||||
datas = {} | |||||
for f_name in os.listdir(dirpath): | |||||
if not f_name.endswith('.pkl'): | |||||
continue | |||||
name = f_name[:-4] | |||||
with open(os.path.join(dirpath, f_name), 'rb') as f: | |||||
datas[name] = _pickle.load(f) | |||||
return datas | |||||
class MyTester(object): | |||||
def __init__(self, batch_size, use_cuda=False, **kwagrs): | |||||
self.batch_size = batch_size | |||||
self.use_cuda = use_cuda | |||||
def test(self, model, dataset): | |||||
self.model = model.cuda() if self.use_cuda else model | |||||
self.model.eval() | |||||
batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda) | |||||
eval_res = defaultdict(list) | |||||
i = 0 | |||||
for batch_x, batch_y in batchiter: | |||||
with torch.no_grad(): | |||||
pred_y = self.model(**batch_x) | |||||
eval_one = self.model.evaluate(**pred_y, **batch_y) | |||||
i += self.batch_size | |||||
for eval_name, tensor in eval_one.items(): | |||||
eval_res[eval_name].append(tensor) | |||||
tmp = {} | |||||
for eval_name, tensorlist in eval_res.items(): | |||||
tmp[eval_name] = torch.cat(tensorlist, dim=0) | |||||
self.res = self.model.metrics(**tmp) | |||||
def show_metrics(self): | |||||
s = "" | |||||
for name, val in self.res.items(): | |||||
s += '{}: {:.2f}\t'.format(name, val) | |||||
return s | |||||
loader = MyDataLoader('') | |||||
try: | |||||
data_dict = load_data(processed_datadir) | |||||
word_v = data_dict['word_v'] | |||||
pos_v = data_dict['pos_v'] | |||||
tag_v = data_dict['tag_v'] | |||||
train_data = data_dict['train_data'] | |||||
dev_data = data_dict['dev_data'] | |||||
print('use saved pickles') | |||||
except Exception as _: | |||||
print('load raw data and preprocess') | |||||
word_v = Vocabulary(need_default=True, min_freq=2) | |||||
pos_v = Vocabulary(need_default=True) | |||||
tag_v = Vocabulary(need_default=False) | |||||
train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v) | |||||
dev_data = loader.load(os.path.join(datadir, dev_data_name)) | |||||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data) | |||||
loader.index_data(train_data, word_v, pos_v, tag_v) | |||||
loader.index_data(dev_data, word_v, pos_v, tag_v) | |||||
print(len(train_data)) | |||||
print(len(dev_data)) | |||||
ep = train_args['epochs'] | |||||
train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep | |||||
model_args['word_vocab_size'] = len(word_v) | |||||
model_args['pos_vocab_size'] = len(pos_v) | |||||
model_args['num_label'] = len(tag_v) | |||||
def train(): | |||||
# Trainer | |||||
trainer = Trainer(**train_args.data) | |||||
def _define_optim(obj): | |||||
obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data) | |||||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4)) | |||||
def _update(obj): | |||||
obj._scheduler.step() | |||||
obj._optimizer.step() | |||||
trainer.define_optimizer = lambda: _define_optim(trainer) | |||||
trainer.update = lambda: _update(trainer) | |||||
trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth) | |||||
trainer._create_validator = lambda x: MyTester(**test_args.data) | |||||
# Model | |||||
model = BiaffineParser(**model_args.data) | |||||
# use pretrain embedding | |||||
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | |||||
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | |||||
model.word_embedding.padding_idx = word_v.padding_idx | |||||
model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | |||||
model.pos_embedding.padding_idx = pos_v.padding_idx | |||||
model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0) | |||||
try: | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print('model parameter loaded!') | |||||
except Exception as _: | |||||
print("No saved model. Continue.") | |||||
pass | |||||
# Start training | |||||
trainer.train(model, train_data, dev_data) | |||||
print("Training finished!") | |||||
# Saver | |||||
saver = ModelSaver("./save/saved_model.pkl") | |||||
saver.save_pytorch(model) | |||||
print("Model saved!") | |||||
def test(): | |||||
# Tester | |||||
tester = MyTester(**test_args.data) | |||||
# Model | |||||
model = BiaffineParser(**model_args.data) | |||||
try: | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print('model parameter loaded!') | |||||
except Exception as _: | |||||
print("No saved model. Abort test.") | |||||
raise | |||||
# Start training | |||||
tester.test(model, dev_data) | |||||
print(tester.show_metrics()) | |||||
print("Testing finished!") | |||||
if __name__ == "__main__": | |||||
import argparse | |||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | |||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | |||||
args = parser.parse_args() | |||||
if args.mode == 'train': | |||||
train() | |||||
elif args.mode == 'test': | |||||
test() | |||||
elif args.mode == 'infer': | |||||
infer() | |||||
else: | |||||
print('no mode specified for model!') | |||||
parser.print_help() |