
Merge branch 'dev' of github.com:choosewhatulike/fastNLP-private into dev

tags/v0.4.10 · yh_cc, 5 years ago · af6a9da78d
8 changed files with 49 additions and 30 deletions
  1. fastNLP/core/dataset.py (+1, -1)
  2. fastNLP/modules/aggregator/attention.py (+21, -19)
  3. test/core/test_predictor.py (+1, -2)
  4. test/io/test_config_saver.py (+1, -1)
  5. test/models/model_runner.py (+5, -1)
  6. test/models/test_biaffine_parser.py (+16, -2)
  7. test/models/test_star_trans.py (+3, -3)
  8. test/modules/test_other_modules.py (+1, -1)

fastNLP/core/dataset.py (+1, -1)

@@ -97,7 +97,7 @@
     # Split the sentence into words; see DataSet.apply() for details
     dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words')
     # Or use DataSet.apply_field()
-    dataset.apply(lambda sent:sent.split(), field_name='sentence', new_field_name='words')
+    dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words')
     # Besides lambdas, you can also define a function and pass it in
     def get_words(instance):
         sentence = instance['sentence']
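For context on the corrected example: DataSet.apply hands the whole Instance to the callable, while DataSet.apply_field hands it only the value of field_name, which is why the two lambdas take different arguments. A minimal sketch with toy data (not from this commit):

    from fastNLP.core.dataset import DataSet

    # toy DataSet with a single text field (illustrative data)
    ds = DataSet({'sentence': ["this is a test", "fastNLP makes NLP easy"]})

    # apply(): the callable receives the whole Instance, so the field is indexed explicitly
    ds.apply(lambda ins: ins['sentence'].split(), new_field_name='words')

    # apply_field(): the callable receives only the value of field_name
    ds.apply_field(lambda sent: sent.split(), field_name='sentence', new_field_name='words')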


fastNLP/modules/aggregator/attention.py (+21, -19)

@@ -14,7 +14,7 @@ class DotAttention(nn.Module):
     """
     TODO
     """
-    def __init__(self, key_size, value_size, dropout=0.1):
+    def __init__(self, key_size, value_size, dropout=0):
         super(DotAttention, self).__init__()
         self.key_size = key_size
         self.value_size = value_size
@@ -25,14 +25,14 @@ class DotAttention(nn.Module):
     def forward(self, Q, K, V, mask_out=None):
         """

-        :param Q: [batch, seq_len, key_size]
-        :param K: [batch, seq_len, key_size]
-        :param V: [batch, seq_len, value_size]
-        :param mask_out: [batch, seq_len]
+        :param Q: [batch, seq_len_q, key_size]
+        :param K: [batch, seq_len_k, key_size]
+        :param V: [batch, seq_len_k, value_size]
+        :param mask_out: [batch, 1, seq_len] or [batch, seq_len_q, seq_len_k]
         """
         output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
         if mask_out is not None:
-            output.masked_fill_(mask_out, -float('inf'))
+            output.masked_fill_(mask_out, -1e8)
         output = self.softmax(output)
         output = self.drop(output)
         return torch.matmul(output, V)
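Switching the fill value from -float('inf') to -1e8 is presumably a NaN guard: if every key position in a row is masked, softmax over a row that is entirely -inf returns NaN, while a large negative constant still yields a finite (uniform) distribution. A standalone PyTorch sketch of the difference (not part of this commit):

    import torch
    import torch.nn.functional as F

    scores = torch.zeros(1, 3)                  # one query attending to 3 keys
    mask = torch.tensor([[True, True, True]])   # every key masked out

    # filling with -inf: softmax of a fully masked row is NaN
    print(F.softmax(scores.masked_fill(mask, -float('inf')), dim=-1))  # tensor([[nan, nan, nan]])

    # filling with a large negative constant keeps the softmax finite (uniform here)
    print(F.softmax(scores.masked_fill(mask, -1e8), dim=-1))           # tensor([[0.3333, 0.3333, 0.3333]])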
@@ -58,7 +58,8 @@ class MultiHeadAttention(nn.Module):
         self.q_in = nn.Linear(input_size, in_size)
         self.k_in = nn.Linear(input_size, in_size)
         self.v_in = nn.Linear(input_size, in_size)
-        self.attention = DotAttention(key_size=key_size, value_size=value_size)
+        # follow the paper, do not apply dropout within dot-product
+        self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=0)
         self.out = nn.Linear(value_size * num_head, input_size)
         self.drop = TimestepDropout(dropout)
         self.reset_parameters()
@@ -73,28 +74,29 @@ class MultiHeadAttention(nn.Module):
     def forward(self, Q, K, V, atte_mask_out=None):
         """

-        :param Q: [batch, seq_len, model_size]
-        :param K: [batch, seq_len, model_size]
-        :param V: [batch, seq_len, model_size]
+        :param Q: [batch, seq_len_q, model_size]
+        :param K: [batch, seq_len_k, model_size]
+        :param V: [batch, seq_len_k, model_size]
         :param seq_mask: [batch, seq_len]
         """
-        batch, seq_len, _ = Q.size()
+        batch, sq, _ = Q.size()
+        sk = K.size(1)
         d_k, d_v, n_head = self.key_size, self.value_size, self.num_head
         # input linear
-        q = self.q_in(Q).view(batch, seq_len, n_head, d_k)
-        k = self.k_in(K).view(batch, seq_len, n_head, d_k)
-        v = self.v_in(V).view(batch, seq_len, n_head, d_k)
+        q = self.q_in(Q).view(batch, sq, n_head, d_k)
+        k = self.k_in(K).view(batch, sk, n_head, d_k)
+        v = self.v_in(V).view(batch, sk, n_head, d_v)

         # transpose q, k and v to do batch attention
-        q = q.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k)
-        k = k.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k)
-        v = v.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_v)
+        q = q.permute(2, 0, 1, 3).contiguous().view(-1, sq, d_k)
+        k = k.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_k)
+        v = v.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_v)
         if atte_mask_out is not None:
             atte_mask_out = atte_mask_out.repeat(n_head, 1, 1)
-        atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, seq_len, d_v)
+        atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, sq, d_v)

         # concat all heads, do output linear
-        atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, seq_len, -1)
+        atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1)
         output = self.drop(self.out(atte))
         return output
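Keeping separate query and key lengths (sq, sk) means the module can now be used for cross-attention, where K and V come from a different, differently sized sequence than Q; the old code assumed a single shared seq_len. A rough usage sketch, assuming the constructor takes the input_size, key_size, value_size, num_head and dropout fields referenced above (check the class definition in attention.py):

    import torch
    from fastNLP.modules.aggregator.attention import MultiHeadAttention

    # assumed constructor signature
    attn = MultiHeadAttention(input_size=32, key_size=8, value_size=8, num_head=4, dropout=0.1)

    queries = torch.rand(2, 5, 32)      # [batch=2, seq_len_q=5, model_size=32]
    memory = torch.rand(2, 9, 32)       # [batch=2, seq_len_k=9, model_size=32]

    out = attn(Q=queries, K=memory, V=memory)
    print(out.shape)                    # expected: torch.Size([2, 5, 32]); output follows the query length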




test/core/test_predictor.py (+1, -2)

@@ -7,7 +7,6 @@ import torch
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.predictor import Predictor
-from fastNLP.modules.encoder.linear import Linear


 def prepare_fake_dataset():
@@ -27,7 +26,7 @@ def prepare_fake_dataset():
 class LinearModel(torch.nn.Module):
     def __init__(self):
         super(LinearModel, self).__init__()
-        self.linear = Linear(2, 1)
+        self.linear = torch.nn.Linear(2, 1)

     def forward(self, x):
         return {"predict": self.linear(x)}


test/io/test_config_saver.py (+1, -1)

@@ -1,7 +1,7 @@
 import os
 import unittest

-from fastNLP.io import ConfigSection, ConfigLoader, ConfigSaver
+# from fastNLP.io import ConfigSection, ConfigLoader, ConfigSaver


 class TestConfigSaver(unittest.TestCase):


test/models/model_runner.py (+5, -1)

@@ -24,7 +24,7 @@ Example::
     RUNNER.run_model(model, data=get_mydata(),
                      loss=Myloss(), metrics=Mymetric())
 """
-from fastNLP import Trainer, Tester, DataSet
+from fastNLP import Trainer, Tester, DataSet, Callback
 from fastNLP import AccuracyMetric
 from fastNLP import CrossEntropyLoss
 from fastNLP.core.const import Const as C
@@ -42,6 +42,10 @@ POS_TAGGING = 'pos_tagging'
 NLI = 'nli'

 class ModelRunner():
+    class Checker(Callback):
+        def on_backward_begin(self, loss):
+            assert loss.to('cpu').numpy().isfinate()
+
     def gen_seq(self, length, vocab_size):
         """generate fake sequence indexes with given length"""
         # reserve 0 for padding
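Note that the assertion as committed calls .numpy().isfinate(), which would raise an AttributeError: isfinate is a typo, and NumPy arrays have no such method in any case. A working finiteness check in the same spirit might look like the following sketch (not what the commit contains):

    import torch
    from fastNLP import Callback

    class Checker(Callback):
        """Fail fast if the loss becomes NaN or infinite."""
        def on_backward_begin(self, loss):
            # torch.isfinite is False for NaN, +inf and -inf
            assert torch.isfinite(loss).all(), "non-finite loss encountered"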


test/models/test_biaffine_parser.py (+16, -2)

@@ -25,10 +25,24 @@ def prepare_parser_data():
                       is_input=True, is_target=True)
     return ds


 class TestBiaffineParser(unittest.TestCase):
     def test_train(self):
-        model = BiaffineParser(init_embed=(VOCAB_SIZE, 30),
-                               pos_vocab_size=VOCAB_SIZE, pos_emb_dim=30,
+        model = BiaffineParser(init_embed=(VOCAB_SIZE, 10),
+                               pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10,
+                               rnn_hidden_size=10,
+                               arc_mlp_size=10,
+                               label_mlp_size=10,
                                num_label=NUM_CLS, encoder='var-lstm')
         ds = prepare_parser_data()
         RUNNER.run_model(model, ds, loss=ParserLoss(), metrics=ParserMetric())
+
+    def test_train2(self):
+        model = BiaffineParser(init_embed=(VOCAB_SIZE, 10),
+                               pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10,
+                               rnn_hidden_size=16,
+                               arc_mlp_size=10,
+                               label_mlp_size=10,
+                               num_label=NUM_CLS, encoder='transformer')
+        ds = prepare_parser_data()
+        RUNNER.run_model(model, ds, loss=ParserLoss(), metrics=ParserMetric())

test/models/test_star_trans.py (+3, -3)

@@ -4,13 +4,13 @@ from fastNLP.models.star_transformer import STNLICls, STSeqCls, STSeqLabel

 # add star-transformer tests, for 3 kinds of tasks.
 def test_cls():
-    model = STSeqCls((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
+    model = STSeqCls((VOCAB_SIZE, 10), NUM_CLS, dropout=0)
     RUNNER.run_model_with_task(TEXT_CLS, model)

 def test_nli():
-    model = STNLICls((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
+    model = STNLICls((VOCAB_SIZE, 10), NUM_CLS, dropout=0)
     RUNNER.run_model_with_task(NLI, model)

 def test_seq_label():
-    model = STSeqLabel((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
+    model = STSeqLabel((VOCAB_SIZE, 10), NUM_CLS, dropout=0)
     RUNNER.run_model_with_task(POS_TAGGING, model)

test/modules/test_other_modules.py (+1, -1)

@@ -2,7 +2,7 @@ import unittest

 import torch

-from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine
+# from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine
 from fastNLP.modules.encoder.star_transformer import StarTransformer





