
update const files

commit ae8f74b31d (tags/v0.4.10)
Author: xuyige, 6 years ago
2 changed files with 23 additions and 7 deletions
  1. fastNLP/core/const.py   (+8, -0)
  2. fastNLP/models/snli.py  (+15, -7)

fastNLP/core/const.py (+8, -0)

@@ -7,6 +7,7 @@ class Const:
         INPUT_LEN  sequence length    seq_len (plural: seq_len1, seq_len2)
         OUTPUT     model output       pred (plural: pred1, pred2)
         TARGET     ground-truth target  target (plural: target1, target2)
+        LOSS       loss function      loss (plural: loss1, loss2)

     """
     INPUT = 'words'
@@ -14,6 +15,7 @@ class Const:
     INPUT_LEN = 'seq_len'
     OUTPUT = 'pred'
     TARGET = 'target'
+    LOSS = 'loss'

     @staticmethod
     def INPUTS(i):
@@ -44,3 +46,9 @@ class Const:
         """Get the name of the i-th ``TARGET``"""
         i = int(i) + 1
         return Const.TARGET + str(i)
+
+    @staticmethod
+    def LOSSES(i):
+        """Get the name of the i-th ``LOSS``"""
+        i = int(i) + 1
+        return Const.LOSS + str(i)
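
Taken together, these additions let loss values be named the same way as inputs, outputs, and targets. A minimal usage sketch (assuming a fastNLP checkout at this tag); the resulting strings follow directly from the definitions above:

    from fastNLP.core.const import Const

    # The new singular key, alongside INPUT, OUTPUT, and TARGET:
    Const.LOSS         # 'loss'

    # Numbered keys for models reporting several losses; suffixes are
    # 1-based because of the `i = int(i) + 1` shift inside LOSSES:
    Const.LOSSES(0)    # 'loss1'
    Const.LOSSES(1)    # 'loss2'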

fastNLP/models/snli.py (+15, -7)

@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn

 from .base_model import BaseModel
+from ..core.const import Const
 from ..modules import decoder as Decoder
 from ..modules import encoder as Encoder
 from ..modules import aggregator as Aggregator
@@ -40,7 +41,7 @@ class ESIM(BaseModel):
             (self.vocab_size, self.embed_dim), dropout=self.dropout,
         )

-        self.embedding_layer = Encoder.Linear(self.embed_dim, self.hidden_size)
+        self.embedding_layer = nn.Linear(self.embed_dim, self.hidden_size)

         self.encoder = Encoder.LSTM(
             input_size=self.embed_dim, hidden_size=self.hidden_size, num_layers=1, bias=True,
@@ -51,7 +52,7 @@ class ESIM(BaseModel):
         self.mean_pooling = Aggregator.MeanPoolWithMask()
         self.max_pooling = Aggregator.MaxPoolWithMask()

-        self.inference_layer = Encoder.Linear(self.hidden_size * 4, self.hidden_size)
+        self.inference_layer = nn.Linear(self.hidden_size * 4, self.hidden_size)

         self.decoder = Encoder.LSTM(
             input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=1, bias=True,
@@ -60,12 +61,13 @@ class ESIM(BaseModel):

         self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout)

-    def forward(self, words1, words2, seq_len1=None, seq_len2=None):
+    def forward(self, words1, words2, seq_len1=None, seq_len2=None, target=None):
         """ Forward function
         :param torch.Tensor words1: [batch size(B), premise seq len(PL)] token representation of the premise
         :param torch.Tensor words2: [B, hypothesis seq len(HL)] token representation of the hypothesis
         :param torch.LongTensor seq_len1: [B] lengths of the premises
         :param torch.LongTensor seq_len2: [B] lengths of the hypotheses
+        :param torch.LongTensor target: [B] ground-truth labels
         :return: dict prediction: [B, n_labels(N)] predicted results
         """
@@ -124,17 +126,23 @@ class ESIM(BaseModel):

         prediction = torch.tanh(self.output(v))  # prediction: [B, N]

-        return {'pred': prediction}
+        if target is not None:
+            func = nn.CrossEntropyLoss()
+            loss = func(prediction, target)
+            return {Const.OUTPUT: prediction, Const.LOSS: loss}

-    def predict(self, words1, words2, seq_len1, seq_len2):
+        return {Const.OUTPUT: prediction}
+
+    def predict(self, words1, words2, seq_len1=None, seq_len2=None, target=None):
         """ Predict function

         :param torch.Tensor words1: [batch size(B), premise seq len(PL)] token representation of the premise
         :param torch.Tensor words2: [B, hypothesis seq len(HL)] token representation of the hypothesis
         :param torch.LongTensor seq_len1: [B] lengths of the premises
         :param torch.LongTensor seq_len2: [B] lengths of the hypotheses
+        :param torch.LongTensor target: [B] ground-truth labels
         :return: dict prediction: [B, n_labels(N)] predicted results
         """
-        prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred']
-        return {'pred': torch.argmax(prediction, dim=-1)}
+        prediction = self.forward(words1, words2, seq_len1, seq_len2)[Const.OUTPUT]
+        return {Const.OUTPUT: torch.argmax(prediction, dim=-1)}
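
The net effect in snli.py: forward now computes a cross-entropy loss inline whenever a target is supplied and returns it under Const.LOSS next to the prediction, while predict still calls forward without a target (its own target parameter is accepted but unused). A minimal sketch of the same return contract; ToyModel and its dimensions are hypothetical stand-ins, not part of this commit:

    import torch
    import torch.nn as nn
    from fastNLP.core.const import Const

    class ToyModel(nn.Module):
        # Hypothetical stand-in that follows the updated ESIM contract.
        def __init__(self, hidden=8, n_labels=3):
            super().__init__()
            self.fc = nn.Linear(hidden, n_labels)

        def forward(self, x, target=None):
            pred = self.fc(x)  # [B, n_labels]
            if target is not None:
                # Training path: return prediction and loss, keyed by the
                # constants ('pred' and 'loss'), as ESIM.forward now does.
                loss = nn.CrossEntropyLoss()(pred, target)
                return {Const.OUTPUT: pred, Const.LOSS: loss}
            # Inference path: prediction only.
            return {Const.OUTPUT: pred}

    model = ToyModel()
    x = torch.randn(4, 8)
    y = torch.randint(0, 3, (4,))
    model(x, target=y)[Const.LOSS].backward()  # the loss is differentiable

Returning the loss from forward keyed by Const.LOSS lets a trainer read it straight from the output dict instead of recomputing it from the raw prediction.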


