add new model, new module, fix bugs

tags/v0.2.0
yunfan 6 years ago
parent
commit
637c37d62b
7 changed files with 720 additions and 413 deletions
  1. +2 -2 fastNLP/core/field.py
  2. +1 -1 fastNLP/loader/config_loader.py
  3. +364 -0 fastNLP/models/biaffine_parser.py
  4. +15 -0 fastNLP/modules/dropout.py
  5. +41 -410 fastNLP/modules/encoder/variational_rnn.py
  6. +37 -0 reproduction/Biaffine_parser/cfg.cfg
  7. +260 -0 reproduction/Biaffine_parser/run.py

+2 -2 fastNLP/core/field.py

@@ -98,7 +98,7 @@ class SeqLabelField(Field):
super(SeqLabelField, self).__init__(is_target)
self.label_seq = label_seq
self._index = None
def get_length(self):
return len(self.label_seq)

@@ -111,7 +111,7 @@ class SeqLabelField(Field):
pads = [0] * (padding_length - self.get_length())
if self._index is None:
if self.get_length() == 0:
return pads
return torch.LongTensor(pads)
elif isinstance(self.label_seq[0], int):
return torch.LongTensor(self.label_seq + pads)
elif isinstance(self.label_seq[0], str):

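A minimal sketch of the change above, assuming the method containing this hunk is SeqLabelField.to_tensor (the method name is not visible in the diff): an empty label sequence now pads to a LongTensor instead of a plain Python list, so downstream batching always receives a tensor.

import torch
from fastNLP.core.field import SeqLabelField

field = SeqLabelField([], is_target=True)     # empty label sequence, not yet indexed
print(field.to_tensor(5))                     # tensor([0, 0, 0, 0, 0]); previously a plain list [0, 0, 0, 0, 0]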

+1 -1 fastNLP/loader/config_loader.py

@@ -8,7 +8,7 @@ from fastNLP.loader.base_loader import BaseLoader
class ConfigLoader(BaseLoader):
"""loader for configuration files"""

def __int__(self, data_path):
def __init__(self, data_path):
super(ConfigLoader, self).__init__()
self.config = self.parse(super(ConfigLoader, self).load(data_path))



+364 -0 fastNLP/models/biaffine_parser.py

@@ -0,0 +1,364 @@
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
import copy
import numpy as np
import torch
from collections import defaultdict
from torch import nn
from torch.nn import functional as F
from fastNLP.modules.utils import initial_parameter
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP.modules.dropout import TimestepDropout

def mst(scores):
"""
Maximum spanning tree decoding for parser output, adapted with some modifications from:
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692
"""
length = scores.shape[0]
min_score = -np.inf
mask = np.zeros((length, length))
np.fill_diagonal(mask, -np.inf)
scores = scores + mask
heads = np.argmax(scores, axis=1)
heads[0] = 0
tokens = np.arange(1, length)
roots = np.where(heads[tokens] == 0)[0] + 1
if len(roots) < 1:
root_scores = scores[tokens, 0]
head_scores = scores[tokens, heads[tokens]]
new_root = tokens[np.argmax(root_scores / head_scores)]
heads[new_root] = 0
elif len(roots) > 1:
root_scores = scores[roots, 0]
scores[roots, 0] = 0
new_heads = np.argmax(scores[roots][:, tokens], axis=1) + 1
new_root = roots[np.argmin(
scores[roots, new_heads] / root_scores)]
heads[roots] = new_heads
heads[new_root] = 0

edges = defaultdict(set)
vertices = set((0,))
for dep, head in enumerate(heads[tokens]):
vertices.add(dep + 1)
edges[head].add(dep + 1)
for cycle in _find_cycle(vertices, edges):
dependents = set()
to_visit = set(cycle)
while len(to_visit) > 0:
node = to_visit.pop()
if node not in dependents:
dependents.add(node)
to_visit.update(edges[node])
cycle = np.array(list(cycle))
old_heads = heads[cycle]
old_scores = scores[cycle, old_heads]
non_heads = np.array(list(dependents))
scores[np.repeat(cycle, len(non_heads)),
np.repeat([non_heads], len(cycle), axis=0).flatten()] = min_score
new_heads = np.argmax(scores[cycle][:, tokens], axis=1) + 1
new_scores = scores[cycle, new_heads] / old_scores
change = np.argmax(new_scores)
changed_cycle = cycle[change]
old_head = old_heads[change]
new_head = new_heads[change]
heads[changed_cycle] = new_head
edges[new_head].add(changed_cycle)
edges[old_head].remove(changed_cycle)

return heads


def _find_cycle(vertices, edges):
"""
https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py
"""
_index = 0
_stack = []
_indices = {}
_lowlinks = {}
_onstack = defaultdict(lambda: False)
_SCCs = []

def _strongconnect(v):
nonlocal _index
_indices[v] = _index
_lowlinks[v] = _index
_index += 1
_stack.append(v)
_onstack[v] = True

for w in edges[v]:
if w not in _indices:
_strongconnect(w)
_lowlinks[v] = min(_lowlinks[v], _lowlinks[w])
elif _onstack[w]:
_lowlinks[v] = min(_lowlinks[v], _indices[w])

if _lowlinks[v] == _indices[v]:
SCC = set()
while True:
w = _stack.pop()
_onstack[w] = False
SCC.add(w)
if w == v:
break
_SCCs.append(SCC)

for v in vertices:
if v not in _indices:
_strongconnect(v)

return [SCC for SCC in _SCCs if len(SCC) > 1]


class GraphParser(nn.Module):
"""Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding
"""
def __init__(self):
super(GraphParser, self).__init__()

def forward(self, x):
raise NotImplementedError

def _greedy_decoder(self, arc_matrix, seq_mask=None):
_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
_, heads = torch.max(matrix, dim=2)
if seq_mask is not None:
heads *= seq_mask.long()
return heads

def _mst_decoder(self, arc_matrix, seq_mask=None):
batch_size, seq_len, _ = arc_matrix.shape
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix)
ans = matrix.new_zeros(batch_size, seq_len).long()
for i, graph in enumerate(matrix):
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
if seq_mask is not None:
ans *= seq_mask.long()
return ans


class ArcBiaffine(nn.Module):
"""helper module for Biaffine Dependency Parser predicting arc
"""
def __init__(self, hidden_size, bias=True):
super(ArcBiaffine, self).__init__()
self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=True)
self.has_bias = bias
if self.has_bias:
self.bias = nn.Parameter(torch.Tensor(hidden_size), requires_grad=True)
else:
self.register_parameter("bias", None)
initial_parameter(self)

def forward(self, head, dep):
"""
:param head: arc-head tensor, shape [batch, length, emb_dim]
:param dep: arc-dependent tensor, shape [batch, length, emb_dim]

:return: output tensor, shape [batch, length, length]
"""
output = dep.matmul(self.U)
output = output.bmm(head.transpose(-1, -2))
if self.has_bias:
output += head.matmul(self.bias).unsqueeze(1)
return output


class LabelBilinear(nn.Module):
"""helper module for Biaffine Dependency Parser predicting label
"""
def __init__(self, in1_features, in2_features, num_label, bias=True):
super(LabelBilinear, self).__init__()
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias)
self.lin1 = nn.Linear(in1_features, num_label, bias=False)
self.lin2 = nn.Linear(in2_features, num_label, bias=False)

def forward(self, x1, x2):
output = self.bilinear(x1, x2)
output += self.lin1(x1) + self.lin2(x2)
return output


class BiaffineParser(GraphParser):
"""Biaffine Dependency Parser implemantation.
refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
<https://arxiv.org/abs/1611.01734>`_ .
"""
def __init__(self,
word_vocab_size,
word_emb_dim,
pos_vocab_size,
pos_emb_dim,
rnn_layers,
rnn_hidden_size,
arc_mlp_size,
label_mlp_size,
num_label,
dropout,
use_var_lstm=False,
use_greedy_infer=False):

super(BiaffineParser, self).__init__()
self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim)
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
if use_var_lstm:
self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
batch_first=True,
input_dropout=dropout,
hidden_dropout=dropout,
bidirectional=True)
else:
self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
batch_first=True,
dropout=dropout,
bidirectional=True)

rnn_out_size = 2 * rnn_hidden_size
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size),
nn.ELU())
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp)
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size),
nn.ELU())
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp)
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
self.normal_dropout = nn.Dropout(p=dropout)
self.timestep_dropout = TimestepDropout(p=dropout)
self.use_greedy_infer = use_greedy_infer
initial_parameter(self)

def forward(self, word_seq, pos_seq, seq_mask, gold_heads=None, **_):
"""
:param word_seq: [batch_size, seq_len] sequence of word indices
:param pos_seq: [batch_size, seq_len] sequence of POS tag indices
:param seq_mask: [batch_size, seq_len] sequence length mask
:param gold_heads: [batch_size, seq_len] sequence of gold heads
:return dict: parsing results
arc_pred: [batch_size, seq_len, seq_len]
label_pred: [batch_size, seq_len, num_label]
seq_mask: [batch_size, seq_len]
head_pred: [batch_size, seq_len], predicted heads (only returned when gold_heads is not provided)
"""
# prepare embeddings
batch_size, seq_len = word_seq.shape
# print('forward {} {}'.format(batch_size, seq_len))
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)

# get sequence mask
seq_mask = seq_mask.long()

word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0]
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1]
x = torch.cat([word, pos], dim=2) # -> [N,L,C]

# lstm, extract features
feat, _ = self.lstm(x) # -> [N,L,C]

# for arc biaffine
# mlp, reduce dim
arc_dep = self.timestep_dropout(self.arc_dep_mlp(feat))
arc_head = self.timestep_dropout(self.arc_head_mlp(feat))
label_dep = self.timestep_dropout(self.label_dep_mlp(feat))
label_head = self.timestep_dropout(self.label_head_mlp(feat))

# biaffine arc classifier
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]
flip_mask = (seq_mask == 0)
arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)

# use gold or predicted arc to predict label
if gold_heads is None:
# use greedy decoding in training
if self.training or self.use_greedy_infer:
heads = self._greedy_decoder(arc_pred, seq_mask)
else:
heads = self._mst_decoder(arc_pred, seq_mask)
head_pred = heads
else:
head_pred = None
heads = gold_heads

label_head = label_head[batch_range, heads].contiguous()
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label]
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask}
if head_pred is not None:
res_dict['head_pred'] = head_pred
return res_dict

def loss(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **_):
"""
Compute loss.

:param arc_pred: [batch_size, seq_len, seq_len]
:param label_pred: [batch_size, seq_len, num_label]
:param head_indices: [batch_size, seq_len]
:param head_labels: [batch_size, seq_len]
:param seq_mask: [batch_size, seq_len]
:return: loss value
"""

batch_size, seq_len, _ = arc_pred.shape
arc_logits = F.log_softmax(arc_pred, dim=2)
label_logits = F.log_softmax(label_pred, dim=2)
batch_index = torch.arange(start=0, end=batch_size, device=arc_logits.device).long().unsqueeze(1)
child_index = torch.arange(start=0, end=seq_len, device=arc_logits.device).long().unsqueeze(0)
arc_loss = arc_logits[batch_index, child_index, head_indices]
label_loss = label_logits[batch_index, child_index, head_labels]

arc_loss = arc_loss[:, 1:]
label_loss = label_loss[:, 1:]

float_mask = seq_mask[:, 1:].float()
length = (seq_mask.sum() - batch_size).float()
arc_nll = -(arc_loss*float_mask).sum() / length
label_nll = -(label_loss*float_mask).sum() / length
return arc_nll + label_nll

def evaluate(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **kwargs):
"""
Evaluate the performance of prediction.

:return dict: performance results.
head_pred_correct: number of correctly predicted heads.
label_pred_correct: number of correctly predicted labels.
total_tokens: number of predicted tokens.
"""
if 'head_pred' in kwargs:
head_pred = kwargs['head_pred']
elif self.use_greedy_infer:
head_pred = self._greedy_decoder(arc_pred, seq_mask)
else:
head_pred = self._mst_decoder(arc_pred, seq_mask)

head_pred_correct = (head_pred == head_indices).long() * seq_mask
_, label_preds = torch.max(label_pred, dim=2)
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
return {"head_pred_correct": head_pred_correct.sum(dim=1),
"label_pred_correct": label_pred_correct.sum(dim=1),
"total_tokens": seq_mask.sum(dim=1)}

def metrics(self, head_pred_correct, label_pred_correct, total_tokens, **_):
"""
Compute the metrics of the model.

:param head_pred_correct: number of correctly predicted heads.
:param label_pred_correct: number of correctly predicted labels.
:param total_tokens: number of predicted tokens.
:return dict: the metric results
UAS: unlabeled attachment score (head prediction accuracy)
LAS: labeled attachment score (label prediction accuracy)
"""
return {"UAS": head_pred_correct.sum().float() / total_tokens.sum().float() * 100,
"LAS": label_pred_correct.sum().float() / total_tokens.sum().float() * 100}

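A minimal usage sketch for the new model (the sizes below are illustrative, not taken from any config in this commit): construct a BiaffineParser and run one forward pass on random word/POS indices. When gold_heads is not given, the returned dict also contains the decoded heads.

import torch
from fastNLP.models.biaffine_parser import BiaffineParser

model = BiaffineParser(word_vocab_size=1000, word_emb_dim=100,
                       pos_vocab_size=30, pos_emb_dim=100,
                       rnn_layers=3, rnn_hidden_size=400,
                       arc_mlp_size=500, label_mlp_size=100,
                       num_label=40, dropout=0.33,
                       use_var_lstm=False, use_greedy_infer=False)
word_seq = torch.randint(0, 1000, (2, 10))        # [batch_size, seq_len]
pos_seq = torch.randint(0, 30, (2, 10))
seq_mask = torch.ones(2, 10, dtype=torch.long)    # all positions valid
res = model(word_seq, pos_seq, seq_mask)          # no gold_heads -> heads are decoded
print(res['arc_pred'].shape)                      # [2, 10, 10]
print(res['label_pred'].shape)                    # [2, 10, num_label]
print(res['head_pred'].shape)                     # [2, 10]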

+15 -0 fastNLP/modules/dropout.py

@@ -0,0 +1,15 @@
import torch

class TimestepDropout(torch.nn.Dropout):
"""This module accepts a `[batch_size, num_timesteps, embedding_dim)]` and use a single
dropout mask of shape `(batch_size, embedding_dim)` to apply on every time step.
"""
def forward(self, x):
dropout_mask = x.new_ones(x.shape[0], x.shape[-1])
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True)
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim]
if self.inplace:
x *= dropout_mask
return
else:
return x * dropout_mask
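
A quick sketch of what TimestepDropout does differently from ordinary dropout (illustrative only): because the mask has no time dimension, a feature channel that is dropped at one time step is dropped at every time step of that sequence.

import torch
from fastNLP.modules.dropout import TimestepDropout

drop = TimestepDropout(p=0.5)
drop.train()                               # masks are only sampled in training mode
x = torch.ones(2, 4, 8)                    # [batch_size, num_timesteps, embedding_dim]
y = drop(x)
# every time step shares the same mask, so all rows of y[0] are identical
print(torch.equal(y[0, 0], y[0, 1]), torch.equal(y[0, 1], y[0, 2]))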

+41 -410 fastNLP/modules/encoder/variational_rnn.py

@@ -2,391 +2,14 @@ import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import PackedSequence

# from fastNLP.modules.utils import initial_parameter

def default_initializer(hidden_size):
stdv = 1.0 / math.sqrt(hidden_size)

def forward(tensor):
nn.init.uniform_(tensor, -stdv, stdv)

return forward


def VarMaskedRecurrent(reverse=False):
def forward(input, hidden, cell, mask):
output = []
steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
for i in steps:
if mask is None or mask[i].data.min() > 0.5:
hidden = cell(input[i], hidden)
elif mask[i].data.max() > 0.5:
hidden_next = cell(input[i], hidden)
# hack to handle LSTM
if isinstance(hidden, tuple):
hx, cx = hidden
hp1, cp1 = hidden_next
hidden = (hx + (hp1 - hx) * mask[i], cx + (cp1 - cx) * mask[i])
else:
hidden = hidden + (hidden_next - hidden) * mask[i]
# hack to handle LSTM
output.append(hidden[0] if isinstance(hidden, tuple) else hidden)

if reverse:
output.reverse()
output = torch.cat(output, 0).view(input.size(0), *output[0].size())

return hidden, output

return forward


def StackedRNN(inners, num_layers, lstm=False):
num_directions = len(inners)
total_layers = num_layers * num_directions

def forward(input, hidden, cells, mask):
assert (len(cells) == total_layers)
next_hidden = []

if lstm:
hidden = list(zip(*hidden))

for i in range(num_layers):
all_output = []
for j, inner in enumerate(inners):
l = i * num_directions + j
hy, output = inner(input, hidden[l], cells[l], mask)
next_hidden.append(hy)
all_output.append(output)

input = torch.cat(all_output, input.dim() - 1)

if lstm:
next_h, next_c = zip(*next_hidden)
next_hidden = (
torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
torch.cat(next_c, 0).view(total_layers, *next_c[0].size())
)
else:
next_hidden = torch.cat(next_hidden, 0).view(total_layers, *next_hidden[0].size())

return next_hidden, input

return forward


def AutogradVarMaskedRNN(num_layers=1, batch_first=False, bidirectional=False, lstm=False):
rec_factory = VarMaskedRecurrent

if bidirectional:
layer = (rec_factory(), rec_factory(reverse=True))
else:
layer = (rec_factory(),)

func = StackedRNN(layer,
num_layers,
lstm=lstm)

def forward(input, cells, hidden, mask):
if batch_first:
input = input.transpose(0, 1)
if mask is not None:
mask = mask.transpose(0, 1)

nexth, output = func(input, hidden, cells, mask)

if batch_first:
output = output.transpose(0, 1)

return output, nexth

return forward


def VarMaskedStep():
def forward(input, hidden, cell, mask):
if mask is None or mask.data.min() > 0.5:
hidden = cell(input, hidden)
elif mask.data.max() > 0.5:
hidden_next = cell(input, hidden)
# hack to handle LSTM
if isinstance(hidden, tuple):
hx, cx = hidden
hp1, cp1 = hidden_next
hidden = (hx + (hp1 - hx) * mask, cx + (cp1 - cx) * mask)
else:
hidden = hidden + (hidden_next - hidden) * mask
# hack to handle LSTM
output = hidden[0] if isinstance(hidden, tuple) else hidden

return hidden, output

return forward


def StackedStep(layer, num_layers, lstm=False):
def forward(input, hidden, cells, mask):
assert (len(cells) == num_layers)
next_hidden = []

if lstm:
hidden = list(zip(*hidden))

for l in range(num_layers):
hy, output = layer(input, hidden[l], cells[l], mask)
next_hidden.append(hy)
input = output

if lstm:
next_h, next_c = zip(*next_hidden)
next_hidden = (
torch.cat(next_h, 0).view(num_layers, *next_h[0].size()),
torch.cat(next_c, 0).view(num_layers, *next_c[0].size())
)
else:
next_hidden = torch.cat(next_hidden, 0).view(num_layers, *next_hidden[0].size())

return next_hidden, input

return forward


def AutogradVarMaskedStep(num_layers=1, lstm=False):
layer = VarMaskedStep()

func = StackedStep(layer,
num_layers,
lstm=lstm)

def forward(input, cells, hidden, mask):
nexth, output = func(input, hidden, cells, mask)
return output, nexth

return forward


class VarMaskedRNNBase(nn.Module):
def __init__(self, Cell, input_size, hidden_size,
num_layers=1, bias=True, batch_first=False,
dropout=(0, 0), bidirectional=False, initializer=None,initial_method = None, **kwargs):

super(VarMaskedRNNBase, self).__init__()
self.Cell = Cell
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.bidirectional = bidirectional
self.lstm = False
num_directions = 2 if bidirectional else 1

self.all_cells = []
for layer in range(num_layers):
for direction in range(num_directions):
layer_input_size = input_size if layer == 0 else hidden_size * num_directions

cell = self.Cell(layer_input_size, hidden_size, self.bias, p=dropout, initializer=initializer, **kwargs)
self.all_cells.append(cell)
self.add_module('cell%d' % (layer * num_directions + direction), cell)
initial_parameter(self, initial_method)
def reset_parameters(self):
for cell in self.all_cells:
cell.reset_parameters()

def reset_noise(self, batch_size):
for cell in self.all_cells:
cell.reset_noise(batch_size)

def forward(self, input, mask=None, hx=None):
batch_size = input.size(0) if self.batch_first else input.size(1)
if hx is None:
num_directions = 2 if self.bidirectional else 1
hx = torch.tensor(input.data.new(self.num_layers * num_directions, batch_size, self.hidden_size).zero_(),
requires_grad=True)
if self.lstm:
hx = (hx, hx)

func = AutogradVarMaskedRNN(num_layers=self.num_layers,
batch_first=self.batch_first,
bidirectional=self.bidirectional,
lstm=self.lstm)

self.reset_noise(batch_size)

output, hidden = func(input, self.all_cells, hx, None if mask is None else mask.view(mask.size() + (1,)))
return output, hidden

def step(self, input, hx=None, mask=None):
'''
execute one step forward (only for one-directional RNN).
Args:
input (batch, input_size): input tensor of this step.
hx (num_layers, batch, hidden_size): the hidden state of last step.
mask (batch): the mask tensor of this step.
Returns:
output (batch, hidden_size): tensor containing the output of this step from the last layer of RNN.
hn (num_layers, batch, hidden_size): tensor containing the hidden state of this step
'''
assert not self.bidirectional, "step only cannot be applied to bidirectional RNN."
batch_size = input.size(0)
if hx is None:
hx = torch.tensor(input.data.new(self.num_layers, batch_size, self.hidden_size).zero_(), requires_grad=True)
if self.lstm:
hx = (hx, hx)

func = AutogradVarMaskedStep(num_layers=self.num_layers, lstm=self.lstm)

output, hidden = func(input, self.all_cells, hx, mask)
return output, hidden


class VarMaskedFastLSTM(VarMaskedRNNBase):
def __init__(self, *args, **kwargs):
super(VarMaskedFastLSTM, self).__init__(VarFastLSTMCell, *args, **kwargs)
self.lstm = True


class VarRNNCellBase(nn.Module):
def __repr__(self):
s = '{name}({input_size}, {hidden_size}'
if 'bias' in self.__dict__ and self.bias is not True:
s += ', bias={bias}'
if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
s += ', nonlinearity={nonlinearity}'
s += ')'
return s.format(name=self.__class__.__name__, **self.__dict__)

def reset_noise(self, batch_size):
"""
Should be overriden by all subclasses.
Args:
batch_size: (int) batch size of input.
"""
raise NotImplementedError


class VarFastLSTMCell(VarRNNCellBase):
"""
A long short-term memory (LSTM) cell with variational dropout.
.. math::
\begin{array}{ll}
i = \mathrm{sigmoid}(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
f = \mathrm{sigmoid}(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
g = \tanh(W_{ig} x + b_{ig} + W_{hc} h + b_{hg}) \\
o = \mathrm{sigmoid}(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
c' = f * c + i * g \\
h' = o * \tanh(c') \\
\end{array}
"""

def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None,initial_method =None):
super(VarFastLSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size))
if bias:
self.bias_ih = Parameter(torch.Tensor(4 * hidden_size))
self.bias_hh = Parameter(torch.Tensor(4 * hidden_size))
else:
self.register_parameter('bias_ih', None)
self.register_parameter('bias_hh', None)

self.initializer = default_initializer(self.hidden_size) if initializer is None else initializer
self.reset_parameters()
p_in, p_hidden = p
if p_in < 0 or p_in > 1:
raise ValueError("input dropout probability has to be between 0 and 1, "
"but got {}".format(p_in))
if p_hidden < 0 or p_hidden > 1:
raise ValueError("hidden state dropout probability has to be between 0 and 1, "
"but got {}".format(p_hidden))
self.p_in = p_in
self.p_hidden = p_hidden
self.noise_in = None
self.noise_hidden = None
initial_parameter(self, initial_method)
def reset_parameters(self):
for weight in self.parameters():
if weight.dim() == 1:
weight.data.zero_()
else:
self.initializer(weight.data)

def reset_noise(self, batch_size):
if self.training:
if self.p_in:
noise = self.weight_ih.data.new(batch_size, self.input_size)
self.noise_in = torch.tensor(noise.bernoulli_(1.0 - self.p_in) / (1.0 - self.p_in))
else:
self.noise_in = None

if self.p_hidden:
noise = self.weight_hh.data.new(batch_size, self.hidden_size)
self.noise_hidden = torch.tensor(noise.bernoulli_(1.0 - self.p_hidden) / (1.0 - self.p_hidden))
else:
self.noise_hidden = None
else:
self.noise_in = None
self.noise_hidden = None

def forward(self, input, hx):
return self.__forward(
input, hx,
self.weight_ih, self.weight_hh,
self.bias_ih, self.bias_hh,
self.noise_in, self.noise_hidden,
)

@staticmethod
def __forward(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
if noise_in is not None:
if input.is_cuda:
input = input * noise_in.cuda(input.get_device())
else:
input = input * noise_in

if input.is_cuda:
w_ih = w_ih.cuda(input.get_device())
w_hh = w_hh.cuda(input.get_device())
hidden = [h.cuda(input.get_device()) for h in hidden]
b_ih = b_ih.cuda(input.get_device())
b_hh = b_hh.cuda(input.get_device())
igates = F.linear(input, w_ih.cuda(input.get_device()))
hgates = F.linear(hidden[0], w_hh) if noise_hidden is None \
else F.linear(hidden[0] * noise_hidden.cuda(input.get_device()), w_hh)
state = fusedBackend.LSTMFused.apply
# print("use backend")
# use some magic function
return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh)

hx, cx = hidden
if noise_hidden is not None:
hx = hx * noise_hidden
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

ingate = F.sigmoid(ingate)
forgetgate = F.sigmoid(forgetgate)
cellgate = F.tanh(cellgate)
outgate = F.sigmoid(outgate)

cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * F.tanh(cy)

return hy, cy
from fastNLP.modules.utils import initial_parameter


class VarRnnCellWrapper(nn.Module):
"""Wrapper for normal RNN Cells, make it support variational dropout
"""
def __init__(self, cell, hidden_size, input_p, hidden_p):
super(VarRnnCellWrapper, self).__init__()
self.cell = cell
@@ -394,31 +17,26 @@ class VarRnnCellWrapper(nn.Module):
self.input_p = input_p
self.hidden_p = hidden_p

def forward(self, input, hidden):
def forward(self, input, hidden, mask_x=None, mask_h=None):
"""
:param input: [seq_len, batch_size, input_size]
:param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size]
for other RNN, h_0, [batch_size, hidden_size]

:param mask_x: [batch_size, input_size] dropout mask for input
:param mask_h: [batch_size, hidden_size] dropout mask for hidden
:return output: [seq_len, batch_size, hidden_size]
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
for other RNN, h_n, [batch_size, hidden_size]
"""
is_lstm = isinstance(hidden, tuple)
_, batch_size, input_size = input.shape
mask_x = input.new_ones((batch_size, input_size))
mask_h = input.new_ones((batch_size, self.hidden_size))
nn.functional.dropout(mask_x, p=self.input_p, training=self.training, inplace=True)
nn.functional.dropout(mask_h, p=self.hidden_p, training=self.training, inplace=True)

input_x = input * mask_x.unsqueeze(0)
input = input * mask_x.unsqueeze(0) if mask_x is not None else input
output_list = []
for x in input_x:
for x in input:
if is_lstm:
hx, cx = hidden
hidden = (hx * mask_h, cx)
hidden = (hx * mask_h, cx) if mask_h is not None else (hx, cx)
else:
hidden *= mask_h
hidden = hidden * mask_h if mask_h is not None else hidden
hidden = self.cell(x, hidden)
output_list.append(hidden[0] if is_lstm else hidden)
output = torch.stack(output_list, dim=0)
@@ -426,6 +44,10 @@ class VarRnnCellWrapper(nn.Module):


class VarRNNBase(nn.Module):
"""Implementation of Variational Dropout RNN network.
refer to `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016)
https://arxiv.org/abs/1512.05287`.
"""
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
bias=True, batch_first=False,
input_dropout=0, hidden_dropout=0, bidirectional=False):
@@ -446,6 +68,7 @@ class VarRNNBase(nn.Module):
input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
cell = Cell(input_size, self.hidden_size, bias)
self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout))
initial_parameter(self)

def forward(self, input, hx=None):
is_packed = isinstance(input, PackedSequence)
@@ -466,6 +89,14 @@ class VarRNNBase(nn.Module):

if self.batch_first:
input = input.transpose(0, 1)
batch_size = input.shape[1]

mask_x = input.new_ones((batch_size, self.input_size))
mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions))
mask_h = input.new_ones((batch_size, self.hidden_size))
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True)
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True)
nn.functional.dropout(mask_h, p=self.hidden_dropout, training=self.training, inplace=True)

hidden_list = []
for layer in range(self.num_layers):
@@ -474,11 +105,13 @@ class VarRNNBase(nn.Module):
input_x = input if direction == 0 else input.flip(0)
idx = self.num_directions * layer + direction
cell = self._all_cells[idx]
output_x, hidden_x = cell(input_x, (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx])
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx]
mask_xi = mask_x if layer == 0 else mask_out
output_x, hidden_x = cell(input_x, hi, mask_xi, mask_h)
output_list.append(output_x if direction == 0 else output_x.flip(0))
hidden_list.append(hidden_x)
input = torch.cat(output_list, dim=-1)
output = input.transpose(0, 1) if self.batch_first else input
if is_lstm:
h_list, c_list = zip(*hidden_list)
@@ -487,29 +120,27 @@ class VarRNNBase(nn.Module):
hidden = (hn, cn)
else:
hidden = torch.stack(hidden_list, dim=0)
if is_packed:
output = PackedSequence(output, batch_sizes)

return output, hidden
return output, hidden


class VarLSTM(VarRNNBase):
"""Variational Dropout LSTM.
"""
def __init__(self, *args, **kwargs):
super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)

class VarRNN(VarRNNBase):
"""Variational Dropout RNN.
"""
def __init__(self, *args, **kwargs):
super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs)

if __name__ == '__main__':
net = VarLSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True, bidirectional=True, input_dropout=0.33, hidden_dropout=0.33)
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True, bidirectional=True)
x = torch.randn(2, 8, 10)
y, hidden = net(x)
y0, h0 = lstm(x)
print(y.shape)
print(y0.shape)
print(y)
print(hidden[0])
print(hidden[0].shape)
print(y0)
print(h0[0])
print(h0[0].shape)
class VarGRU(VarRNNBase):
"""Variational Dropout GRU.
"""
def __init__(self, *args, **kwargs):
super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
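
A minimal sketch, mirroring the __main__ test that this commit removes: VarLSTM is used as a drop-in replacement for nn.LSTM, with separate dropout rates for inputs and hidden states that reuse one mask across all time steps.

import torch
from fastNLP.modules.encoder.variational_rnn import VarLSTM

net = VarLSTM(input_size=10, hidden_size=20, num_layers=3, batch_first=True,
              bidirectional=True, input_dropout=0.33, hidden_dropout=0.33)
x = torch.randn(2, 8, 10)                  # [batch_size, seq_len, input_size]
output, (h_n, c_n) = net(x)
print(output.shape)                        # [2, 8, 40] = hidden_size * num_directions
print(h_n.shape)                           # [num_layers * num_directions, batch_size, hidden_size]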

+37 -0 reproduction/Biaffine_parser/cfg.cfg

@@ -0,0 +1,37 @@
[train]
epochs = 50
batch_size = 16
pickle_path = "./save/"
validate = true
save_best_dev = false
use_cuda = true
model_saved_path = "./save/"
task = "parse"


[test]
save_output = true
validate_in_training = true
save_dev_input = false
save_loss = true
batch_size = 16
pickle_path = "./save/"
use_cuda = true
task = "parse"

[model]
word_vocab_size = -1
word_emb_dim = 100
pos_vocab_size = -1
pos_emb_dim = 100
rnn_layers = 3
rnn_hidden_size = 400
arc_mlp_size = 500
label_mlp_size = 100
num_label = -1
dropout = 0.33
use_var_lstm=true
use_greedy_infer=false

[optim]
lr = 2e-3

+260 -0 reproduction/Biaffine_parser/run.py

@@ -0,0 +1,260 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from collections import defaultdict
import math
import torch

from fastNLP.core.trainer import Trainer
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.field import TextField, SeqLabelField
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
from fastNLP.core.tester import Tester
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.loader.embed_loader import EmbedLoader
from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.saver.model_saver import ModelSaver

# if not launched from this file's directory, change into it
if len(os.path.dirname(__file__)) != 0:
os.chdir(os.path.dirname(__file__))

class MyDataLoader(object):
def __init__(self, pickle_path):
self.pickle_path = pickle_path

def load(self, path, word_v=None, pos_v=None, headtag_v=None):
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)

ds = DataSet(name='conll')
for sample in datalist:
# print(sample)
res = self.get_one(sample)
if word_v is not None:
word_v.update(res[0])
pos_v.update(res[1])
headtag_v.update(res[3])
ds.append(Instance(word_seq=TextField(res[0], is_target=False),
pos_seq=TextField(res[1], is_target=False),
head_indices=SeqLabelField(res[2], is_target=True),
head_labels=TextField(res[3], is_target=True),
seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False)))

return ds

def get_one(self, sample):
text = ['<root>']
pos_tags = ['<root>']
heads = [0]
head_tags = ['root']
for w in sample:
t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
if t3 == '_':
continue
text.append(t1)
pos_tags.append(t2)
heads.append(int(t3))
head_tags.append(t4)
return (text, pos_tags, heads, head_tags)

def index_data(self, dataset, word_v, pos_v, tag_v):
dataset.index_field('word_seq', word_v)
dataset.index_field('pos_seq', pos_v)
dataset.index_field('head_labels', tag_v)

# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT"
datadir = "/home/yfshao/UD_English-EWT"
cfgfile = './cfg.cfg'
train_data_name = "en_ewt-ud-train.conllu"
dev_data_name = "en_ewt-ud-dev.conllu"
emb_file_name = '/home/yfshao/glove.6B.100d.txt'
processed_datadir = './save'

# Config Loader
train_args = ConfigSection()
test_args = ConfigSection()
model_args = ConfigSection()
optim_args = ConfigSection()
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})

# Data Loader
def save_data(dirpath, **kwargs):
import _pickle
if not os.path.exists(dirpath):
os.mkdir(dirpath)
for name, data in kwargs.items():
with open(os.path.join(dirpath, name+'.pkl'), 'wb') as f:
_pickle.dump(data, f)


def load_data(dirpath):
import _pickle
datas = {}
for f_name in os.listdir(dirpath):
if not f_name.endswith('.pkl'):
continue
name = f_name[:-4]
with open(os.path.join(dirpath, f_name), 'rb') as f:
datas[name] = _pickle.load(f)
return datas

class MyTester(object):
def __init__(self, batch_size, use_cuda=False, **kwargs):
self.batch_size = batch_size
self.use_cuda = use_cuda

def test(self, model, dataset):
self.model = model.cuda() if self.use_cuda else model
self.model.eval()
batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda)
eval_res = defaultdict(list)
i = 0
for batch_x, batch_y in batchiter:
with torch.no_grad():
pred_y = self.model(**batch_x)
eval_one = self.model.evaluate(**pred_y, **batch_y)
i += self.batch_size
for eval_name, tensor in eval_one.items():
eval_res[eval_name].append(tensor)
tmp = {}
for eval_name, tensorlist in eval_res.items():
tmp[eval_name] = torch.cat(tensorlist, dim=0)

self.res = self.model.metrics(**tmp)

def show_metrics(self):
s = ""
for name, val in self.res.items():
s += '{}: {:.2f}\t'.format(name, val)
return s


loader = MyDataLoader('')
try:
data_dict = load_data(processed_datadir)
word_v = data_dict['word_v']
pos_v = data_dict['pos_v']
tag_v = data_dict['tag_v']
train_data = data_dict['train_data']
dev_data = data_dict['dev_data']
print('use saved pickles')

except Exception as _:
print('load raw data and preprocess')
word_v = Vocabulary(need_default=True, min_freq=2)
pos_v = Vocabulary(need_default=True)
tag_v = Vocabulary(need_default=False)
train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v)
dev_data = loader.load(os.path.join(datadir, dev_data_name))
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)

loader.index_data(train_data, word_v, pos_v, tag_v)
loader.index_data(dev_data, word_v, pos_v, tag_v)
print(len(train_data))
print(len(dev_data))
ep = train_args['epochs']
train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep
model_args['word_vocab_size'] = len(word_v)
model_args['pos_vocab_size'] = len(pos_v)
model_args['num_label'] = len(tag_v)


def train():
# Trainer
trainer = Trainer(**train_args.data)

def _define_optim(obj):
obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data)
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4))

def _update(obj):
obj._scheduler.step()
obj._optimizer.step()

trainer.define_optimizer = lambda: _define_optim(trainer)
trainer.update = lambda: _update(trainer)
trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth)
trainer._create_validator = lambda x: MyTester(**test_args.data)

# Model
model = BiaffineParser(**model_args.data)

# use pretrain embedding
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
model.word_embedding.padding_idx = word_v.padding_idx
model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
model.pos_embedding.padding_idx = pos_v.padding_idx
model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)

try:
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
print('model parameter loaded!')
except Exception as _:
print("No saved model. Continue.")
pass

# Start training
trainer.train(model, train_data, dev_data)
print("Training finished!")

# Saver
saver = ModelSaver("./save/saved_model.pkl")
saver.save_pytorch(model)
print("Model saved!")


def test():
# Tester
tester = MyTester(**test_args.data)

# Model
model = BiaffineParser(**model_args.data)

try:
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
print('model parameter loaded!')
except Exception as _:
print("No saved model. Abort test.")
raise

# Start testing
tester.test(model, dev_data)
print(tester.show_metrics())
print("Testing finished!")



if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Run the Biaffine dependency parser')
parser.add_argument('--mode', help='set the running mode', choices=['train', 'test', 'infer'])
args = parser.parse_args()
if args.mode == 'train':
train()
elif args.mode == 'test':
test()
elif args.mode == 'infer':
infer()
else:
print('no mode specified for model!')
parser.print_help()
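
Typical usage once the hard-coded UD_English-EWT and GloVe paths above point at local copies:

python run.py --mode train     # train and save the model to ./save/saved_model.pkl
python run.py --mode test      # reload the saved model and print UAS/LAS on the dev set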
