
Merge pull request #96 from choosewhatulike/biaffine

add biaffine dependency parser & some modules
tags/v0.2.0
Xipeng Qiu GitHub 6 years ago
commit fdc8f7b0c2
12 changed files with 1007 additions and 423 deletions
  1. fastNLP/core/field.py (+30 -0)
  2. fastNLP/core/vocabulary.py (+68 -27)
  3. fastNLP/loader/config_loader.py (+3 -2)
  4. fastNLP/loader/embed_loader.py (+58 -24)
  5. fastNLP/models/biaffine_parser.py (+364 -0)
  6. fastNLP/modules/dropout.py (+15 -0)
  7. fastNLP/modules/encoder/variational_rnn.py (+123 -354)
  8. reproduction/Biaffine_parser/cfg.cfg (+37 -0)
  9. reproduction/Biaffine_parser/run.py (+260 -0)
  10. test/data_for_tests/glove.6B.50d_test.txt (+12 -0)
  11. test/loader/test_embed_loader.py (+33 -0)
  12. test/modules/test_variational_rnn.py (+4 -16)

fastNLP/core/field.py (+30 -0)

@@ -93,5 +93,35 @@ class LabelField(Field):
return torch.LongTensor([self._index])




class SeqLabelField(Field):
def __init__(self, label_seq, is_target=True):
super(SeqLabelField, self).__init__(is_target)
self.label_seq = label_seq
self._index = None

def get_length(self):
return len(self.label_seq)

def index(self, vocab):
if self._index is None:
self._index = [vocab[c] for c in self.label_seq]
return self._index

def to_tensor(self, padding_length):
pads = [0] * (padding_length - self.get_length())
if self._index is None:
if self.get_length() == 0:
return torch.LongTensor(pads)
elif isinstance(self.label_seq[0], int):
return torch.LongTensor(self.label_seq + pads)
elif isinstance(self.label_seq[0], str):
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
else:
raise RuntimeError(
"Not support type for SeqLabelField. Expect str or int, got {}.".format(type(self.label)))
else:
return torch.LongTensor(self._index + pads)


if __name__ == "__main__":
tf = TextField("test the code".split(), is_target=False)
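
The new SeqLabelField stores a label sequence, maps it to indices through a Vocabulary, and zero-pads it when converted to a tensor. A minimal usage sketch (the labels and vocabulary here are illustrative, not taken from the PR):

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.field import SeqLabelField

vocab = Vocabulary(need_default=True)
vocab.update(["B", "I", "O"])                  # hypothetical tag set

field = SeqLabelField(["B", "I", "O", "O"], is_target=True)
field.index(vocab)                             # each label string -> vocabulary index
tensor = field.to_tensor(padding_length=6)     # LongTensor of length 6, padded with 0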

fastNLP/core/vocabulary.py (+68 -27)

@@ -18,6 +18,15 @@ def isiterable(p_object):
return False
return True


def check_build_vocab(func):
def _wrapper(self, *args, **kwargs):
if self.word2idx is None:
self.build_vocab()
self.build_reverse_vocab()
elif self.idx2word is None:
self.build_reverse_vocab()
return func(self, *args, **kwargs)
return _wrapper


class Vocabulary(object):
"""Use for word and index one to one mapping
@@ -30,30 +39,23 @@ class Vocabulary(object):
vocab["word"] vocab["word"]
vocab.to_word(5) vocab.to_word(5)
""" """

def __init__(self, need_default=True):
def __init__(self, need_default=True, max_size=None, min_freq=None):
""" """
:param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True.

:param int max_size: set the max number of words in Vocabulary. Default: None
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None
""" """
if need_default:
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
self.padding_label = DEFAULT_PADDING_LABEL
self.unknown_label = DEFAULT_UNKNOWN_LABEL
else:
self.word2idx = {}
self.padding_label = None
self.unknown_label = None

self.max_size = max_size
self.min_freq = min_freq
self.word_count = {}
self.has_default = need_default
self.word2idx = None
self.idx2word = None


def __len__(self):
return len(self.word2idx)


def update(self, word):
"""add word or list of words into Vocabulary
:param word: a list of string or a single string
"""
if not isinstance(word, str) and isiterable(word):
@@ -61,12 +63,48 @@ class Vocabulary(object):
for w in word:
self.update(w)
else:
# it's a word to be added
if word not in self.word2idx:
self.word2idx[word] = len(self)
if self.idx2word is not None:
self.idx2word = None
# it's a word to be added
if word not in self.word_count:
self.word_count[word] = 1
else:
self.word_count[word] += 1
self.word2idx = None



def build_vocab(self):
"""build 'word to index' dict, and filter the word using `max_size` and `min_freq`
"""
if self.has_default:
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
self.padding_label = DEFAULT_PADDING_LABEL
self.unknown_label = DEFAULT_UNKNOWN_LABEL
else:
self.word2idx = {}
self.padding_label = None
self.unknown_label = None

words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True)
if self.min_freq is not None:
words = list(filter(lambda kv: kv[1] >= self.min_freq, words))
if self.max_size is not None and len(words) > self.max_size:
words = words[:self.max_size]
for w, _ in words:
self.word2idx[w] = len(self.word2idx)

def build_reverse_vocab(self):
"""build 'index to word' dict based on 'word to index' dict
"""
self.idx2word = {self.word2idx[w] : w for w in self.word2idx}

@check_build_vocab
def __len__(self):
return len(self.word2idx)

@check_build_vocab
def has_word(self, w):
return w in self.word2idx

@check_build_vocab
def __getitem__(self, w):
"""To support usage like::


@@ -74,32 +112,35 @@ class Vocabulary(object):
""" """
if w in self.word2idx: if w in self.word2idx:
return self.word2idx[w] return self.word2idx[w]
else:
elif self.has_default:
return self.word2idx[DEFAULT_UNKNOWN_LABEL]
else:
raise ValueError("word {} not in vocabulary".format(w))


@check_build_vocab
def to_index(self, w):
""" like to_index(w) function, turn a word to the index
if w is not in Vocabulary, return the unknown label
:param str w:
"""
return self[w]


@property
@check_build_vocab
def unknown_idx(self):
if self.unknown_label is None:
return None
return self.word2idx[self.unknown_label]


@property
@check_build_vocab
def padding_idx(self):
if self.padding_label is None:
return None
return self.word2idx[self.padding_label]


def build_reverse_vocab(self):
"""build 'index to word' dict based on 'word to index' dict
"""
self.idx2word = {self.word2idx[w]: w for w in self.word2idx}

@check_build_vocab
def to_word(self, idx):
"""given a word's index, return the word itself




fastNLP/loader/config_loader.py (+3 -2)

@@ -8,9 +8,10 @@ from fastNLP.loader.base_loader import BaseLoader
class ConfigLoader(BaseLoader):
"""loader for configuration files"""


def __int__(self, data_path):
def __init__(self, data_path=None):
super(ConfigLoader, self).__init__()
self.config = self.parse(super(ConfigLoader, self).load(data_path))
if data_path is not None:
self.config = self.parse(super(ConfigLoader, self).load(data_path))


@staticmethod
def parse(string):


fastNLP/loader/embed_loader.py (+58 -24)

@@ -1,10 +1,10 @@
import _pickle
import os


import numpy as np
import torch


from fastNLP.loader.base_loader import BaseLoader
from fastNLP.core.vocabulary import Vocabulary


class EmbedLoader(BaseLoader):
"""docstring for EmbedLoader"""
@@ -13,38 +13,72 @@ class EmbedLoader(BaseLoader):
super(EmbedLoader, self).__init__(data_path)


@staticmethod
def load_embedding(emb_dim, emb_file, word_dict, emb_pkl):
def _load_glove(emb_file):
"""Read file as a glove embedding

file format:
embeddings are split by line,
for one embedding, word and numbers split by space
Example::

word_1 float_1 float_2 ... float_emb_dim
word_2 float_1 float_2 ... float_emb_dim
...
"""
emb = {}
with open(emb_file, 'r', encoding='utf-8') as f:
for line in f:
line = list(filter(lambda w: len(w)>0, line.strip().split(' ')))
if len(line) > 0:
emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
return emb
@staticmethod
def _load_pretrain(emb_file, emb_type):
"""Read txt data from embedding file and convert to np.array as pre-trained embedding

:param emb_file: str, the pre-trained embedding file path
:param emb_type: str, the pre-trained embedding data format
:return dict: {str: np.array}
"""
if emb_type == 'glove':
return EmbedLoader._load_glove(emb_file)
else:
raise Exception("embedding type {} not support yet".format(emb_type))

@staticmethod
def load_embedding(emb_dim, emb_file, emb_type, vocab, emb_pkl):
"""Load the pre-trained embedding and combine with the given dictionary. """Load the pre-trained embedding and combine with the given dictionary.


:param emb_file: str, the pre-trained embedding.
The embedding file should have the following format:
Each line is a word embedding, where a word string is followed by multiple floats.
Floats are separated by space. The word and the first float are separated by space.
:param word_dict: dict, a mapping from word to index.
:param emb_dim: int, the dimension of the embedding. Should be the same as pre-trained embedding.
:param emb_file: str, the pre-trained embedding file path.
:param emb_type: str, the pre-trained embedding format, support glove now
:param vocab: Vocabulary, a mapping from word to index, can be provided by user or built from pre-trained embedding
:param emb_pkl: str, the embedding pickle file.
:return embedding_np: numpy array of shape (len(word_dict), emb_dim)

vocab: input vocab or vocab built by pre-train
TODO: fragile code
"""
# If the embedding pickle exists, load it and return.
if os.path.exists(emb_pkl):
with open(emb_pkl, "rb") as f:
embedding_np = _pickle.load(f)
return embedding_np
embedding_np, vocab = _pickle.load(f)
return embedding_np, vocab
# Otherwise, load the pre-trained embedding.
with open(emb_file, "r", encoding="utf-8") as f:
# begin with a random embedding
embedding_np = np.random.uniform(-1, 1, size=(len(word_dict), emb_dim))
for line in f:
line = line.strip().split()
if len(line) != emb_dim + 1:
# skip this line if two embedding dimension not match
continue
if line[0] in word_dict:
# find the word and replace its embedding with a pre-trained one
embedding_np[word_dict[line[0]]] = [float(i) for i in line[1:]]
pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
if vocab is None:
# build vocabulary from pre-trained embedding
vocab = Vocabulary()
for w in pretrain.keys():
vocab.update(w)
embedding_np = torch.randn(len(vocab), emb_dim)
for w, v in pretrain.items():
if len(v.shape) > 1 or emb_dim != v.shape[0]:
raise ValueError('pretrained embedding dim is {}, which does not match the required {}'.format(v.shape, (emb_dim,)))
if vocab.has_word(w):
embedding_np[vocab[w]] = v

# save and return the result
with open(emb_pkl, "wb") as f:
_pickle.dump(embedding_np, f)
return embedding_np
_pickle.dump((embedding_np, vocab), f)
return embedding_np, vocab
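
load_embedding now returns both the embedding tensor and the Vocabulary, and can build the vocabulary from the pre-trained file when vocab is None. A hedged sketch of a call (file paths are placeholders):

from fastNLP.loader.embed_loader import EmbedLoader
from fastNLP.core.vocabulary import Vocabulary

vocab = Vocabulary()                          # or pass vocab=None to build it from the embedding file
vocab.update(["the", "of", "and"])
emb, vocab = EmbedLoader.load_embedding(
    emb_dim=50,
    emb_file="path/to/glove.6B.50d.txt",      # placeholder path
    emb_type="glove",
    vocab=vocab,
    emb_pkl="path/to/word_emb.pkl",           # placeholder cache path
)
# emb is a torch tensor of shape (len(vocab), emb_dim); words missing from the file keep random vectors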

fastNLP/models/biaffine_parser.py (+364 -0)

@@ -0,0 +1,364 @@
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
import copy
import numpy as np
import torch
from collections import defaultdict
from torch import nn
from torch.nn import functional as F
from fastNLP.modules.utils import initial_parameter
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP.modules.dropout import TimestepDropout

def mst(scores):
"""
with some modification to support parser output for MST decoding
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692
"""
length = scores.shape[0]
min_score = -np.inf
mask = np.zeros((length, length))
np.fill_diagonal(mask, -np.inf)
scores = scores + mask
heads = np.argmax(scores, axis=1)
heads[0] = 0
tokens = np.arange(1, length)
roots = np.where(heads[tokens] == 0)[0] + 1
if len(roots) < 1:
root_scores = scores[tokens, 0]
head_scores = scores[tokens, heads[tokens]]
new_root = tokens[np.argmax(root_scores / head_scores)]
heads[new_root] = 0
elif len(roots) > 1:
root_scores = scores[roots, 0]
scores[roots, 0] = 0
new_heads = np.argmax(scores[roots][:, tokens], axis=1) + 1
new_root = roots[np.argmin(
scores[roots, new_heads] / root_scores)]
heads[roots] = new_heads
heads[new_root] = 0

edges = defaultdict(set)
vertices = set((0,))
for dep, head in enumerate(heads[tokens]):
vertices.add(dep + 1)
edges[head].add(dep + 1)
for cycle in _find_cycle(vertices, edges):
dependents = set()
to_visit = set(cycle)
while len(to_visit) > 0:
node = to_visit.pop()
if node not in dependents:
dependents.add(node)
to_visit.update(edges[node])
cycle = np.array(list(cycle))
old_heads = heads[cycle]
old_scores = scores[cycle, old_heads]
non_heads = np.array(list(dependents))
scores[np.repeat(cycle, len(non_heads)),
np.repeat([non_heads], len(cycle), axis=0).flatten()] = min_score
new_heads = np.argmax(scores[cycle][:, tokens], axis=1) + 1
new_scores = scores[cycle, new_heads] / old_scores
change = np.argmax(new_scores)
changed_cycle = cycle[change]
old_head = old_heads[change]
new_head = new_heads[change]
heads[changed_cycle] = new_head
edges[new_head].add(changed_cycle)
edges[old_head].remove(changed_cycle)

return heads


def _find_cycle(vertices, edges):
"""
https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py
"""
_index = 0
_stack = []
_indices = {}
_lowlinks = {}
_onstack = defaultdict(lambda: False)
_SCCs = []

def _strongconnect(v):
nonlocal _index
_indices[v] = _index
_lowlinks[v] = _index
_index += 1
_stack.append(v)
_onstack[v] = True

for w in edges[v]:
if w not in _indices:
_strongconnect(w)
_lowlinks[v] = min(_lowlinks[v], _lowlinks[w])
elif _onstack[w]:
_lowlinks[v] = min(_lowlinks[v], _indices[w])

if _lowlinks[v] == _indices[v]:
SCC = set()
while True:
w = _stack.pop()
_onstack[w] = False
SCC.add(w)
if not(w != v):
break
_SCCs.append(SCC)

for v in vertices:
if v not in _indices:
_strongconnect(v)

return [SCC for SCC in _SCCs if len(SCC) > 1]


class GraphParser(nn.Module):
"""Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding
"""
def __init__(self):
super(GraphParser, self).__init__()

def forward(self, x):
raise NotImplementedError

def _greedy_decoder(self, arc_matrix, seq_mask=None):
_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
_, heads = torch.max(matrix, dim=2)
if seq_mask is not None:
heads *= seq_mask.long()
return heads

def _mst_decoder(self, arc_matrix, seq_mask=None):
batch_size, seq_len, _ = arc_matrix.shape
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix)
ans = matrix.new_zeros(batch_size, seq_len).long()
for i, graph in enumerate(matrix):
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device)
if seq_mask is not None:
ans *= seq_mask.long()
return ans
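
GraphParser offers two decoders: _greedy_decoder picks the highest-scoring head for each token independently, while _mst_decoder runs the mst routine above to guarantee a well-formed tree. A toy illustration of mst on a made-up score matrix (rows are dependents, columns are candidate heads, index 0 is the root):

import numpy as np
from fastNLP.models.biaffine_parser import mst

scores = np.array([[0., 0., 0., 0.],
                   [9., 0., 1., 1.],
                   [1., 8., 0., 1.],
                   [1., 1., 7., 0.]])
print(mst(scores))   # -> [0 0 1 2]: token 1 attaches to the root, 2 to 1, 3 to 2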


class ArcBiaffine(nn.Module):
"""helper module for Biaffine Dependency Parser predicting arc
"""
def __init__(self, hidden_size, bias=True):
super(ArcBiaffine, self).__init__()
self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=True)
self.has_bias = bias
if self.has_bias:
self.bias = nn.Parameter(torch.Tensor(hidden_size), requires_grad=True)
else:
self.register_parameter("bias", None)
initial_parameter(self)

def forward(self, head, dep):
"""
:param head arc-head tensor = [batch, length, emb_dim]
:param dep arc-dependent tensor = [batch, length, emb_dim]

:return output tensor = [batch, length, length]
"""
output = dep.matmul(self.U)
output = output.bmm(head.transpose(-1, -2))
if self.has_bias:
output += head.matmul(self.bias).unsqueeze(1)
return output


class LabelBilinear(nn.Module):
"""helper module for Biaffine Dependency Parser predicting label
"""
def __init__(self, in1_features, in2_features, num_label, bias=True):
super(LabelBilinear, self).__init__()
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias)
self.lin1 = nn.Linear(in1_features, num_label, bias=False)
self.lin2 = nn.Linear(in2_features, num_label, bias=False)

def forward(self, x1, x2):
output = self.bilinear(x1, x2)
output += self.lin1(x1) + self.lin2(x2)
return output


class BiaffineParser(GraphParser):
"""Biaffine Dependency Parser implemantation.
refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
<https://arxiv.org/abs/1611.01734>`_ .
"""
def __init__(self,
word_vocab_size,
word_emb_dim,
pos_vocab_size,
pos_emb_dim,
rnn_layers,
rnn_hidden_size,
arc_mlp_size,
label_mlp_size,
num_label,
dropout,
use_var_lstm=False,
use_greedy_infer=False):

super(BiaffineParser, self).__init__()
self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim)
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
if use_var_lstm:
self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
batch_first=True,
input_dropout=dropout,
hidden_dropout=dropout,
bidirectional=True)
else:
self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
batch_first=True,
dropout=dropout,
bidirectional=True)

rnn_out_size = 2 * rnn_hidden_size
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size),
nn.ELU())
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp)
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size),
nn.ELU())
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp)
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
self.normal_dropout = nn.Dropout(p=dropout)
self.timestep_dropout = TimestepDropout(p=dropout)
self.use_greedy_infer = use_greedy_infer
initial_parameter(self)

def forward(self, word_seq, pos_seq, seq_mask, gold_heads=None, **_):
"""
:param word_seq: [batch_size, seq_len] sequence of word's indices
:param pos_seq: [batch_size, seq_len] sequence of POS tag indices
:param seq_mask: [batch_size, seq_len] sequence of length masks
:param gold_heads: [batch_size, seq_len] sequence of golden heads
:return dict: parsing results
arc_pred: [batch_size, seq_len, seq_len]
label_pred: [batch_size, seq_len, seq_len]
seq_mask: [batch_size, seq_len]
head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads
"""
# prepare embeddings
batch_size, seq_len = word_seq.shape
# print('forward {} {}'.format(batch_size, seq_len))
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)

# get sequence mask
seq_mask = seq_mask.long()

word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0]
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1]
x = torch.cat([word, pos], dim=2) # -> [N,L,C]

# lstm, extract features
feat, _ = self.lstm(x) # -> [N,L,C]

# for arc biaffine
# mlp, reduce dim
arc_dep = self.timestep_dropout(self.arc_dep_mlp(feat))
arc_head = self.timestep_dropout(self.arc_head_mlp(feat))
label_dep = self.timestep_dropout(self.label_dep_mlp(feat))
label_head = self.timestep_dropout(self.label_head_mlp(feat))

# biaffine arc classifier
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]
flip_mask = (seq_mask == 0)
arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)

# use gold or predicted arc to predict label
if gold_heads is None:
# use greedy decoding in training
if self.training or self.use_greedy_infer:
heads = self._greedy_decoder(arc_pred, seq_mask)
else:
heads = self._mst_decoder(arc_pred, seq_mask)
head_pred = heads
else:
head_pred = None
heads = gold_heads

label_head = label_head[batch_range, heads].contiguous()
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label]
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask}
if head_pred is not None:
res_dict['head_pred'] = head_pred
return res_dict

def loss(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **_):
"""
Compute loss.

:param arc_pred: [batch_size, seq_len, seq_len]
:param label_pred: [batch_size, seq_len, seq_len]
:param head_indices: [batch_size, seq_len]
:param head_labels: [batch_size, seq_len]
:param seq_mask: [batch_size, seq_len]
:return: loss value
"""

batch_size, seq_len, _ = arc_pred.shape
arc_logits = F.log_softmax(arc_pred, dim=2)
label_logits = F.log_softmax(label_pred, dim=2)
batch_index = torch.arange(start=0, end=batch_size, device=arc_logits.device).long().unsqueeze(1)
child_index = torch.arange(start=0, end=seq_len, device=arc_logits.device).long().unsqueeze(0)
arc_loss = arc_logits[batch_index, child_index, head_indices]
label_loss = label_logits[batch_index, child_index, head_labels]

arc_loss = arc_loss[:, 1:]
label_loss = label_loss[:, 1:]

float_mask = seq_mask[:, 1:].float()
length = (seq_mask.sum() - batch_size).float()
arc_nll = -(arc_loss*float_mask).sum() / length
label_nll = -(label_loss*float_mask).sum() / length
return arc_nll + label_nll

def evaluate(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **kwargs):
"""
Evaluate the performance of prediction.

:return dict: performance results.
head_pred_correct: number of correctly predicted heads.
label_pred_correct: number of correctly predicted labels.
total_tokens: number of predicted tokens
"""
if 'head_pred' in kwargs:
head_pred = kwargs['head_pred']
elif self.use_greedy_infer:
head_pred = self._greedy_decoder(arc_pred, seq_mask)
else:
head_pred = self._mst_decoder(arc_pred, seq_mask)

head_pred_correct = (head_pred == head_indices).long() * seq_mask
_, label_preds = torch.max(label_pred, dim=2)
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
return {"head_pred_correct": head_pred_correct.sum(dim=1),
"label_pred_correct": label_pred_correct.sum(dim=1),
"total_tokens": seq_mask.sum(dim=1)}

def metrics(self, head_pred_correct, label_pred_correct, total_tokens, **_):
"""
Compute the metrics of model

:param head_pred_correct: number of correctly predicted heads.
:param label_pred_correct: number of correctly predicted labels.
:param total_tokens: number of predicted tokens
:return dict: the metrics results
UAS: unlabeled attachment score, the accuracy of head prediction
LAS: labeled attachment score, the accuracy of head-and-label prediction
"""
return {"UAS": head_pred_correct.sum().float() / total_tokens.sum().float() * 100,
"LAS": label_pred_correct.sum().float() / total_tokens.sum().float() * 100}


fastNLP/modules/dropout.py (+15 -0)

@@ -0,0 +1,15 @@
import torch

class TimestepDropout(torch.nn.Dropout):
"""This module accepts a `[batch_size, num_timesteps, embedding_dim)]` and use a single
dropout mask of shape `(batch_size, embedding_dim)` to apply on every time step.
"""
def forward(self, x):
dropout_mask = x.new_ones(x.shape[0], x.shape[-1])
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True)
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim]
if self.inplace:
x *= dropout_mask
return
else:
return x * dropout_mask
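
TimestepDropout differs from ordinary dropout in that one mask is sampled per (sample, feature) pair and reused at every time step, so a dropped embedding dimension stays dropped for the whole sequence. A quick check of that behaviour (shapes chosen arbitrarily):

import torch
from fastNLP.modules.dropout import TimestepDropout

drop = TimestepDropout(p=0.5)
drop.train()
x = torch.ones(2, 4, 8)                         # [batch_size, num_timesteps, embedding_dim]
y = drop(x)
print(torch.equal(y[:, 0, :], y[:, 1, :]))      # True: all time steps of a sample share one mask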

fastNLP/modules/encoder/variational_rnn.py (+123 -354)

@@ -2,384 +2,153 @@ import math


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import PackedSequence


from fastNLP.modules.utils import initial_parameter


def default_initializer(hidden_size):
stdv = 1.0 / math.sqrt(hidden_size)

def forward(tensor):
nn.init.uniform_(tensor, -stdv, stdv)

return forward


def VarMaskedRecurrent(reverse=False):
def forward(input, hidden, cell, mask):
output = []
steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
for i in steps:
if mask is None or mask[i].data.min() > 0.5:
hidden = cell(input[i], hidden)
elif mask[i].data.max() > 0.5:
hidden_next = cell(input[i], hidden)
# hack to handle LSTM
if isinstance(hidden, tuple):
hx, cx = hidden
hp1, cp1 = hidden_next
hidden = (hx + (hp1 - hx) * mask[i], cx + (cp1 - cx) * mask[i])
else:
hidden = hidden + (hidden_next - hidden) * mask[i]
# hack to handle LSTM
output.append(hidden[0] if isinstance(hidden, tuple) else hidden)

if reverse:
output.reverse()
output = torch.cat(output, 0).view(input.size(0), *output[0].size())

return hidden, output

return forward


def StackedRNN(inners, num_layers, lstm=False):
num_directions = len(inners)
total_layers = num_layers * num_directions

def forward(input, hidden, cells, mask):
assert (len(cells) == total_layers)
next_hidden = []

if lstm:
hidden = list(zip(*hidden))

for i in range(num_layers):
all_output = []
for j, inner in enumerate(inners):
l = i * num_directions + j
hy, output = inner(input, hidden[l], cells[l], mask)
next_hidden.append(hy)
all_output.append(output)

input = torch.cat(all_output, input.dim() - 1)

if lstm:
next_h, next_c = zip(*next_hidden)
next_hidden = (
torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
torch.cat(next_c, 0).view(total_layers, *next_c[0].size())
)
else:
next_hidden = torch.cat(next_hidden, 0).view(total_layers, *next_hidden[0].size())

return next_hidden, input

return forward


def AutogradVarMaskedRNN(num_layers=1, batch_first=False, bidirectional=False, lstm=False):
rec_factory = VarMaskedRecurrent

if bidirectional:
layer = (rec_factory(), rec_factory(reverse=True))
else:
layer = (rec_factory(),)

func = StackedRNN(layer,
num_layers,
lstm=lstm)

def forward(input, cells, hidden, mask):
if batch_first:
input = input.transpose(0, 1)
if mask is not None:
mask = mask.transpose(0, 1)

nexth, output = func(input, hidden, cells, mask)

if batch_first:
output = output.transpose(0, 1)

return output, nexth

return forward

try:
from torch import flip
except ImportError:
def flip(x, dims):
indices = [slice(None)] * x.dim()
for dim in dims:
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
return x[tuple(indices)]

class VarRnnCellWrapper(nn.Module):
"""Wrapper for normal RNN Cells, make it support variational dropout
"""
def __init__(self, cell, hidden_size, input_p, hidden_p):
super(VarRnnCellWrapper, self).__init__()
self.cell = cell
self.hidden_size = hidden_size
self.input_p = input_p
self.hidden_p = hidden_p


def VarMaskedStep():
def forward(input, hidden, cell, mask):
if mask is None or mask.data.min() > 0.5:
hidden = cell(input, hidden)
elif mask.data.max() > 0.5:
hidden_next = cell(input, hidden)
# hack to handle LSTM
if isinstance(hidden, tuple):
def forward(self, input, hidden, mask_x=None, mask_h=None):
"""
:param input: [seq_len, batch_size, input_size]
:param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size]
for other RNN, h_0, [batch_size, hidden_size]
:param mask_x: [batch_size, input_size] dropout mask for input
:param mask_h: [batch_size, hidden_size] dropout mask for hidden
:return output: [seq_len, batch_size, hidden_size]
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
for other RNN, h_n, [batch_size, hidden_size]
"""
is_lstm = isinstance(hidden, tuple)
input = input * mask_x.unsqueeze(0) if mask_x is not None else input
output_list = []
for x in input:
if is_lstm:
hx, cx = hidden
hp1, cp1 = hidden_next
hidden = (hx + (hp1 - hx) * mask, cx + (cp1 - cx) * mask)
hidden = (hx * mask_h, cx) if mask_h is not None else (hx, cx)
else:
hidden = hidden + (hidden_next - hidden) * mask
# hack to handle LSTM
output = hidden[0] if isinstance(hidden, tuple) else hidden

return hidden, output

return forward


def StackedStep(layer, num_layers, lstm=False):
def forward(input, hidden, cells, mask):
assert (len(cells) == num_layers)
next_hidden = []

if lstm:
hidden = list(zip(*hidden))

for l in range(num_layers):
hy, output = layer(input, hidden[l], cells[l], mask)
next_hidden.append(hy)
input = output

if lstm:
next_h, next_c = zip(*next_hidden)
next_hidden = (
torch.cat(next_h, 0).view(num_layers, *next_h[0].size()),
torch.cat(next_c, 0).view(num_layers, *next_c[0].size())
)
else:
next_hidden = torch.cat(next_hidden, 0).view(num_layers, *next_hidden[0].size())

return next_hidden, input

return forward


def AutogradVarMaskedStep(num_layers=1, lstm=False):
layer = VarMaskedStep()

func = StackedStep(layer,
num_layers,
lstm=lstm)

def forward(input, cells, hidden, mask):
nexth, output = func(input, hidden, cells, mask)
return output, nexth

return forward

hidden = hidden * mask_h if mask_h is not None else hidden
hidden = self.cell(x, hidden)
output_list.append(hidden[0] if is_lstm else hidden)
output = torch.stack(output_list, dim=0)
return output, hidden


class VarMaskedRNNBase(nn.Module):
def __init__(self, Cell, input_size, hidden_size,
num_layers=1, bias=True, batch_first=False,
dropout=(0, 0), bidirectional=False, initializer=None,initial_method = None, **kwargs):


super(VarMaskedRNNBase, self).__init__()
self.Cell = Cell
class VarRNNBase(nn.Module):
"""Implementation of Variational Dropout RNN network.
refer to `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016)
https://arxiv.org/abs/1512.05287`.
"""
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
bias=True, batch_first=False,
input_dropout=0, hidden_dropout=0, bidirectional=False):
super(VarRNNBase, self).__init__()
self.mode = mode
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.input_dropout = input_dropout
self.hidden_dropout = hidden_dropout
self.bidirectional = bidirectional
self.lstm = False
num_directions = 2 if bidirectional else 1

self.all_cells = []
for layer in range(num_layers):
for direction in range(num_directions):
layer_input_size = input_size if layer == 0 else hidden_size * num_directions

cell = self.Cell(layer_input_size, hidden_size, self.bias, p=dropout, initializer=initializer, **kwargs)
self.all_cells.append(cell)
self.add_module('cell%d' % (layer * num_directions + direction), cell)
initial_parameter(self, initial_method)
def reset_parameters(self):
for cell in self.all_cells:
cell.reset_parameters()

def reset_noise(self, batch_size):
for cell in self.all_cells:
cell.reset_noise(batch_size)
self.num_directions = 2 if bidirectional else 1
self._all_cells = nn.ModuleList()
for layer in range(self.num_layers):
for direction in range(self.num_directions):
input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
cell = Cell(input_size, self.hidden_size, bias)
self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout))
initial_parameter(self)

def forward(self, input, hx=None):
is_packed = isinstance(input, PackedSequence)
is_lstm = (self.mode == "LSTM")
if is_packed:
input, batch_sizes = input
max_batch_size = int(batch_sizes[0])
else:
batch_sizes = None
max_batch_size = input.size(0) if self.batch_first else input.size(1)


def forward(self, input, mask=None, hx=None):
batch_size = input.size(0) if self.batch_first else input.size(1)
if hx is None:
num_directions = 2 if self.bidirectional else 1
hx = torch.tensor(input.data.new(self.num_layers * num_directions, batch_size, self.hidden_size).zero_(),
requires_grad=True)
if self.lstm:
hx = input.new_zeros(self.num_layers * self.num_directions,
max_batch_size, self.hidden_size,
requires_grad=False)
if is_lstm:
hx = (hx, hx)


func = AutogradVarMaskedRNN(num_layers=self.num_layers,
batch_first=self.batch_first,
bidirectional=self.bidirectional,
lstm=self.lstm)

self.reset_noise(batch_size)

output, hidden = func(input, self.all_cells, hx, None if mask is None else mask.view(mask.size() + (1,)))
return output, hidden

def step(self, input, hx=None, mask=None):
'''
execute one step forward (only for one-directional RNN).
Args:
input (batch, input_size): input tensor of this step.
hx (num_layers, batch, hidden_size): the hidden state of last step.
mask (batch): the mask tensor of this step.
Returns:
output (batch, hidden_size): tensor containing the output of this step from the last layer of RNN.
hn (num_layers, batch, hidden_size): tensor containing the hidden state of this step
'''
assert not self.bidirectional, "step only cannot be applied to bidirectional RNN."
batch_size = input.size(0)
if hx is None:
hx = torch.tensor(input.data.new(self.num_layers, batch_size, self.hidden_size).zero_(), requires_grad=True)
if self.lstm:
hx = (hx, hx)
if self.batch_first:
input = input.transpose(0, 1)
batch_size = input.shape[1]

mask_x = input.new_ones((batch_size, self.input_size))
mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions))
mask_h = input.new_ones((batch_size, self.hidden_size))
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True)
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True)
nn.functional.dropout(mask_h, p=self.hidden_dropout, training=self.training, inplace=True)

hidden_list = []
for layer in range(self.num_layers):
output_list = []
for direction in range(self.num_directions):
input_x = input if direction == 0 else flip(input, [0])
idx = self.num_directions * layer + direction
cell = self._all_cells[idx]
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx]
mask_xi = mask_x if layer == 0 else mask_out
output_x, hidden_x = cell(input_x, hi, mask_xi, mask_h)
output_list.append(output_x if direction == 0 else flip(output_x, [0]))
hidden_list.append(hidden_x)
input = torch.cat(output_list, dim=-1)

output = input.transpose(0, 1) if self.batch_first else input
if is_lstm:
h_list, c_list = zip(*hidden_list)
hn = torch.stack(h_list, dim=0)
cn = torch.stack(c_list, dim=0)
hidden = (hn, cn)
else:
hidden = torch.stack(hidden_list, dim=0)


func = AutogradVarMaskedStep(num_layers=self.num_layers, lstm=self.lstm)
if is_packed:
output = PackedSequence(output, batch_sizes)


output, hidden = func(input, self.all_cells, hx, mask)
return output, hidden




class VarMaskedFastLSTM(VarMaskedRNNBase):
class VarLSTM(VarRNNBase):
"""Variational Dropout LSTM.
"""
def __init__(self, *args, **kwargs):
super(VarMaskedFastLSTM, self).__init__(VarFastLSTMCell, *args, **kwargs)
self.lstm = True


class VarRNNCellBase(nn.Module):
def __repr__(self):
s = '{name}({input_size}, {hidden_size}'
if 'bias' in self.__dict__ and self.bias is not True:
s += ', bias={bias}'
if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
s += ', nonlinearity={nonlinearity}'
s += ')'
return s.format(name=self.__class__.__name__, **self.__dict__)
super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)


def reset_noise(self, batch_size):
"""
Should be overriden by all subclasses.
Args:
batch_size: (int) batch size of input.
"""
raise NotImplementedError


class VarFastLSTMCell(VarRNNCellBase):
"""
A long short-term memory (LSTM) cell with variational dropout.
.. math::
\begin{array}{ll}
i = \mathrm{sigmoid}(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
f = \mathrm{sigmoid}(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
g = \tanh(W_{ig} x + b_{ig} + W_{hc} h + b_{hg}) \\
o = \mathrm{sigmoid}(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
c' = f * c + i * g \\
h' = o * \tanh(c') \\
\end{array}
class VarRNN(VarRNNBase):
"""Variational Dropout RNN.
""" """
def __init__(self, *args, **kwargs):
super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs)


def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None,initial_method =None):
super(VarFastLSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size))
if bias:
self.bias_ih = Parameter(torch.Tensor(4 * hidden_size))
self.bias_hh = Parameter(torch.Tensor(4 * hidden_size))
else:
self.register_parameter('bias_ih', None)
self.register_parameter('bias_hh', None)

self.initializer = default_initializer(self.hidden_size) if initializer is None else initializer
self.reset_parameters()
p_in, p_hidden = p
if p_in < 0 or p_in > 1:
raise ValueError("input dropout probability has to be between 0 and 1, "
"but got {}".format(p_in))
if p_hidden < 0 or p_hidden > 1:
raise ValueError("hidden state dropout probability has to be between 0 and 1, "
"but got {}".format(p_hidden))
self.p_in = p_in
self.p_hidden = p_hidden
self.noise_in = None
self.noise_hidden = None
initial_parameter(self, initial_method)
def reset_parameters(self):
for weight in self.parameters():
if weight.dim() == 1:
weight.data.zero_()
else:
self.initializer(weight.data)

def reset_noise(self, batch_size):
if self.training:
if self.p_in:
noise = self.weight_ih.data.new(batch_size, self.input_size)
self.noise_in = torch.tensor(noise.bernoulli_(1.0 - self.p_in) / (1.0 - self.p_in))
else:
self.noise_in = None

if self.p_hidden:
noise = self.weight_hh.data.new(batch_size, self.hidden_size)
self.noise_hidden = torch.tensor(noise.bernoulli_(1.0 - self.p_hidden) / (1.0 - self.p_hidden))
else:
self.noise_hidden = None
else:
self.noise_in = None
self.noise_hidden = None

def forward(self, input, hx):
return self.__forward(
input, hx,
self.weight_ih, self.weight_hh,
self.bias_ih, self.bias_hh,
self.noise_in, self.noise_hidden,
)

@staticmethod
def __forward(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
if noise_in is not None:
if input.is_cuda:
input = input * noise_in.cuda(input.get_device())
else:
input = input * noise_in

if input.is_cuda:
w_ih = w_ih.cuda(input.get_device())
w_hh = w_hh.cuda(input.get_device())
hidden = [h.cuda(input.get_device()) for h in hidden]
b_ih = b_ih.cuda(input.get_device())
b_hh = b_hh.cuda(input.get_device())
igates = F.linear(input, w_ih.cuda(input.get_device()))
hgates = F.linear(hidden[0], w_hh) if noise_hidden is None \
else F.linear(hidden[0] * noise_hidden.cuda(input.get_device()), w_hh)
state = fusedBackend.LSTMFused.apply
# print("use backend")
# use some magic function
return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh)

hx, cx = hidden
if noise_hidden is not None:
hx = hx * noise_hidden
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

ingate = F.sigmoid(ingate)
forgetgate = F.sigmoid(forgetgate)
cellgate = F.tanh(cellgate)
outgate = F.sigmoid(outgate)

cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * F.tanh(cy)

return hy, cy
class VarGRU(VarRNNBase):
"""Variational Dropout GRU.
"""
def __init__(self, *args, **kwargs):
super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
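
The rewritten module exposes VarLSTM, VarGRU and VarRNN, which wrap the standard torch cells and reuse one dropout mask for the inputs and one for the hidden states across all time steps (variational dropout). A short sketch of a forward pass with arbitrary sizes:

import torch
from fastNLP.modules.encoder.variational_rnn import VarLSTM

rnn = VarLSTM(input_size=12, hidden_size=10, num_layers=2,
              batch_first=True, bidirectional=True,
              input_dropout=0.33, hidden_dropout=0.33)
x = torch.randn(16, 32, 12)           # [batch_size, seq_len, input_size]
out, (h_n, c_n) = rnn(x)
print(out.shape)                      # (16, 32, 20): both directions concatenated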

reproduction/Biaffine_parser/cfg.cfg (+37 -0)

@@ -0,0 +1,37 @@
[train]
epochs = 50
batch_size = 16
pickle_path = "./save/"
validate = true
save_best_dev = false
use_cuda = true
model_saved_path = "./save/"
task = "parse"


[test]
save_output = true
validate_in_training = true
save_dev_input = false
save_loss = true
batch_size = 16
pickle_path = "./save/"
use_cuda = true
task = "parse"

[model]
word_vocab_size = -1
word_emb_dim = 100
pos_vocab_size = -1
pos_emb_dim = 100
rnn_layers = 3
rnn_hidden_size = 400
arc_mlp_size = 500
label_mlp_size = 100
num_label = -1
dropout = 0.33
use_var_lstm=true
use_greedy_infer=false

[optim]
lr = 2e-3
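
run.py (below) reads this file through ConfigLoader into per-section ConfigSection objects; the -1 placeholders (vocabulary sizes, num_label) are filled in after the vocabularies are built. Roughly, mirroring the calls in run.py:

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection

train_args, model_args = ConfigSection(), ConfigSection()
ConfigLoader.load_config("cfg.cfg", {"train": train_args, "model": model_args})
model_args['word_vocab_size'] = 10000   # replace the -1 placeholder once the vocabulary is known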

reproduction/Biaffine_parser/run.py (+260 -0)

@@ -0,0 +1,260 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from collections import defaultdict
import math
import torch

from fastNLP.core.trainer import Trainer
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.field import TextField, SeqLabelField
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
from fastNLP.core.tester import Tester
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.loader.embed_loader import EmbedLoader
from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.saver.model_saver import ModelSaver

# not in the file's dir
if len(os.path.dirname(__file__)) != 0:
os.chdir(os.path.dirname(__file__))

class MyDataLoader(object):
def __init__(self, pickle_path):
self.pickle_path = pickle_path

def load(self, path, word_v=None, pos_v=None, headtag_v=None):
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)

ds = DataSet(name='conll')
for sample in datalist:
# print(sample)
res = self.get_one(sample)
if word_v is not None:
word_v.update(res[0])
pos_v.update(res[1])
headtag_v.update(res[3])
ds.append(Instance(word_seq=TextField(res[0], is_target=False),
pos_seq=TextField(res[1], is_target=False),
head_indices=SeqLabelField(res[2], is_target=True),
head_labels=TextField(res[3], is_target=True),
seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False)))

return ds

def get_one(self, sample):
text = ['<root>']
pos_tags = ['<root>']
heads = [0]
head_tags = ['root']
for w in sample:
t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
if t3 == '_':
continue
text.append(t1)
pos_tags.append(t2)
heads.append(int(t3))
head_tags.append(t4)
return (text, pos_tags, heads, head_tags)

def index_data(self, dataset, word_v, pos_v, tag_v):
dataset.index_field('word_seq', word_v)
dataset.index_field('pos_seq', pos_v)
dataset.index_field('head_labels', tag_v)

# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT"
datadir = "/home/yfshao/UD_English-EWT"
cfgfile = './cfg.cfg'
train_data_name = "en_ewt-ud-train.conllu"
dev_data_name = "en_ewt-ud-dev.conllu"
emb_file_name = '/home/yfshao/glove.6B.100d.txt'
processed_datadir = './save'

# Config Loader
train_args = ConfigSection()
test_args = ConfigSection()
model_args = ConfigSection()
optim_args = ConfigSection()
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})

# Data Loader
def save_data(dirpath, **kwargs):
import _pickle
if not os.path.exists(dirpath):
os.mkdir(dirpath)
for name, data in kwargs.items():
with open(os.path.join(dirpath, name+'.pkl'), 'wb') as f:
_pickle.dump(data, f)


def load_data(dirpath):
import _pickle
datas = {}
for f_name in os.listdir(dirpath):
if not f_name.endswith('.pkl'):
continue
name = f_name[:-4]
with open(os.path.join(dirpath, f_name), 'rb') as f:
datas[name] = _pickle.load(f)
return datas

class MyTester(object):
def __init__(self, batch_size, use_cuda=False, **kwargs):
self.batch_size = batch_size
self.use_cuda = use_cuda

def test(self, model, dataset):
self.model = model.cuda() if self.use_cuda else model
self.model.eval()
batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda)
eval_res = defaultdict(list)
i = 0
for batch_x, batch_y in batchiter:
with torch.no_grad():
pred_y = self.model(**batch_x)
eval_one = self.model.evaluate(**pred_y, **batch_y)
i += self.batch_size
for eval_name, tensor in eval_one.items():
eval_res[eval_name].append(tensor)
tmp = {}
for eval_name, tensorlist in eval_res.items():
tmp[eval_name] = torch.cat(tensorlist, dim=0)

self.res = self.model.metrics(**tmp)

def show_metrics(self):
s = ""
for name, val in self.res.items():
s += '{}: {:.2f}\t'.format(name, val)
return s


loader = MyDataLoader('')
try:
data_dict = load_data(processed_datadir)
word_v = data_dict['word_v']
pos_v = data_dict['pos_v']
tag_v = data_dict['tag_v']
train_data = data_dict['train_data']
dev_data = data_dict['dev_data']
print('use saved pickles')

except Exception as _:
print('load raw data and preprocess')
word_v = Vocabulary(need_default=True, min_freq=2)
pos_v = Vocabulary(need_default=True)
tag_v = Vocabulary(need_default=False)
train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v)
dev_data = loader.load(os.path.join(datadir, dev_data_name))
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)

loader.index_data(train_data, word_v, pos_v, tag_v)
loader.index_data(dev_data, word_v, pos_v, tag_v)
print(len(train_data))
print(len(dev_data))
ep = train_args['epochs']
train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep
model_args['word_vocab_size'] = len(word_v)
model_args['pos_vocab_size'] = len(pos_v)
model_args['num_label'] = len(tag_v)


def train():
# Trainer
trainer = Trainer(**train_args.data)

def _define_optim(obj):
obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data)
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4))

def _update(obj):
obj._scheduler.step()
obj._optimizer.step()

trainer.define_optimizer = lambda: _define_optim(trainer)
trainer.update = lambda: _update(trainer)
trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth)
trainer._create_validator = lambda x: MyTester(**test_args.data)

# Model
model = BiaffineParser(**model_args.data)

# use pretrain embedding
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
model.word_embedding.padding_idx = word_v.padding_idx
model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
model.pos_embedding.padding_idx = pos_v.padding_idx
model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)

try:
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
print('model parameter loaded!')
except Exception as _:
print("No saved model. Continue.")
pass

# Start training
trainer.train(model, train_data, dev_data)
print("Training finished!")

# Saver
saver = ModelSaver("./save/saved_model.pkl")
saver.save_pytorch(model)
print("Model saved!")


def test():
# Tester
tester = MyTester(**test_args.data)

# Model
model = BiaffineParser(**model_args.data)

try:
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
print('model parameter loaded!')
except Exception as _:
print("No saved model. Abort test.")
raise

# Start training
tester.test(model, dev_data)
print(tester.show_metrics())
print("Testing finished!")



if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Run a biaffine dependency parser')
parser.add_argument('--mode', help='set the running mode', choices=['train', 'test', 'infer'])
args = parser.parse_args()
if args.mode == 'train':
train()
elif args.mode == 'test':
test()
elif args.mode == 'infer':
infer()
else:
print('no mode specified for model!')
parser.print_help()

test/data_for_tests/glove.6B.50d_test.txt (+12 -0)

@@ -0,0 +1,12 @@
the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581
, 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -0.42852 -0.55641 -0.364 -0.23938 0.13001 -0.063734 -0.39575 -0.48162 0.23291 0.090201 -0.13324 0.078639 -0.41634 -0.15428 0.10068 0.48891 0.31226 -0.1252 -0.037512 -1.5179 0.12612 -0.02442 -0.042961 -0.28351 3.5416 -0.11956 -0.014533 -0.1499 0.21864 -0.33412 -0.13872 0.31806 0.70358 0.44858 -0.080262 0.63003 0.32111 -0.46765 0.22786 0.36034 -0.37818 -0.56657 0.044691 0.30392
. 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -0.43478 -0.31086 -0.44999 -0.29486 0.16608 0.11963 -0.41328 -0.42353 0.59868 0.28825 -0.11547 -0.041848 -0.67989 -0.25063 0.18472 0.086876 0.46582 0.015035 0.043474 -1.4671 -0.30384 -0.023441 0.30589 -0.21785 3.746 0.0042284 -0.18436 -0.46209 0.098329 -0.11907 0.23919 0.1161 0.41705 0.056763 -6.3681e-05 0.068987 0.087939 -0.10285 -0.13931 0.22314 -0.080803 -0.35652 0.016413 0.10216
of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375
to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044
and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097
in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285
a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796
" 0.25769 0.45629 -0.76974 -0.37679 0.59272 -0.063527 0.20545 -0.57385 -0.29009 -0.13662 0.32728 1.4719 -0.73681 -0.12036 0.71354 -0.46098 0.65248 0.48887 -0.51558 0.039951 -0.34307 -0.014087 0.86488 0.3546 0.7999 -1.4995 -1.8153 0.41128 0.23921 -0.43139 3.6623 -0.79834 -0.54538 0.16943 -0.82017 -0.3461 0.69495 -1.2256 -0.17992 -0.057474 0.030498 -0.39543 -0.38515 -1.0002 0.087599 -0.31009 -0.34677 -0.31438 0.75004 0.97065
's 0.23727 0.40478 -0.20547 0.58805 0.65533 0.32867 -0.81964 -0.23236 0.27428 0.24265 0.054992 0.16296 -1.2555 -0.086437 0.44536 0.096561 -0.16519 0.058378 -0.38598 0.086977 0.0033869 0.55095 -0.77697 -0.62096 0.092948 -2.5685 -0.67739 0.10151 -0.48643 -0.057805 3.1859 -0.017554 -0.16138 0.055486 -0.25885 -0.33938 -0.19928 0.26049 0.10478 -0.55934 -0.12342 0.65961 -0.51802 -0.82995 -0.082739 0.28155 -0.423 -0.27378 -0.007901 -0.030231



test/loader/test_embed_loader.py (+33 -0)

@@ -0,0 +1,33 @@
import unittest
import os

import torch

from fastNLP.loader.embed_loader import EmbedLoader
from fastNLP.core.vocabulary import Vocabulary


class TestEmbedLoader(unittest.TestCase):
glove_path = './test/data_for_tests/glove.6B.50d_test.txt'
pkl_path = './save'
raw_texts = ["i am a cat",
"this is a test of new batch",
"ha ha",
"I am a good boy .",
"This is the most beautiful girl ."
]
texts = [text.strip().split() for text in raw_texts]
vocab = Vocabulary()
vocab.update(texts)
def test1(self):
emb, _ = EmbedLoader.load_embedding(50, self.glove_path, 'glove', self.vocab, self.pkl_path)
self.assertTrue(emb.shape[0] == (len(self.vocab)))
self.assertTrue(emb.shape[1] == 50)
os.remove(self.pkl_path)
def test2(self):
try:
_ = EmbedLoader.load_embedding(100, self.glove_path, 'glove', self.vocab, self.pkl_path)
self.fail(msg="load mismatched embedding")
except ValueError:
pass

test/modules/test_variational_rnn.py (+4 -16)

@@ -3,35 +3,23 @@ import unittest
import numpy as np
import torch


from fastNLP.modules.encoder.variational_rnn import VarMaskedFastLSTM
from fastNLP.modules.encoder.variational_rnn import VarLSTM




class TestMaskedRnn(unittest.TestCase):
def test_case_1(self):
masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
masked_rnn = VarLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
x = torch.tensor([[[1.0], [2.0]]])
print(x.size())
y = masked_rnn(x)
mask = torch.tensor([[[1], [1]]])
y = masked_rnn(x, mask=mask)
mask = torch.tensor([[[1], [0]]])
y = masked_rnn(x, mask=mask)



def test_case_2(self):
input_size = 12
batch = 16
hidden = 10
masked_rnn = VarMaskedFastLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True)

x = torch.randn((batch, input_size))
output, _ = masked_rnn.step(x)
self.assertEqual(tuple(output.shape), (batch, hidden))
masked_rnn = VarLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True)


xx = torch.randn((batch, 32, input_size))
y, _ = masked_rnn(xx)
self.assertEqual(tuple(y.shape), (batch, 32, hidden))

xx = torch.randn((batch, 32, input_size))
mask = torch.from_numpy(np.random.randint(0, 2, size=(batch, 32))).to(xx)
y, _ = masked_rnn(xx, mask=mask)
self.assertEqual(tuple(y.shape), (batch, 32, hidden))
