diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py index a2a00a29..845e372f 100644 --- a/fastNLP/models/biaffine_parser.py +++ b/fastNLP/models/biaffine_parser.py @@ -175,12 +175,11 @@ class LabelBilinear(nn.Module): def __init__(self, in1_features, in2_features, num_label, bias=True): super(LabelBilinear, self).__init__() self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) - self.lin1 = nn.Linear(in1_features, num_label, bias=False) - self.lin2 = nn.Linear(in2_features, num_label, bias=False) + self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) def forward(self, x1, x2): output = self.bilinear(x1, x2) - output += self.lin1(x1) + self.lin2(x2) + output += self.lin(torch.cat([x1, x2], dim=2)) return output @@ -226,15 +225,16 @@ class BiaffineParser(GraphParser): rnn_out_size = 2 * rnn_hidden_size self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), - nn.ELU()) + nn.ELU(), + TimestepDropout(p=dropout),) self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), - nn.ELU()) + nn.ELU(), + TimestepDropout(p=dropout),) self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) self.normal_dropout = nn.Dropout(p=dropout) - self.timestep_dropout = TimestepDropout(p=dropout) self.use_greedy_infer = use_greedy_infer initial_parameter(self) @@ -267,10 +267,10 @@ class BiaffineParser(GraphParser): # for arc biaffine # mlp, reduce dim - arc_dep = self.timestep_dropout(self.arc_dep_mlp(feat)) - arc_head = self.timestep_dropout(self.arc_head_mlp(feat)) - label_dep = self.timestep_dropout(self.label_dep_mlp(feat)) - label_head = self.timestep_dropout(self.label_head_mlp(feat)) + arc_dep = self.arc_dep_mlp(feat) + arc_head = self.arc_head_mlp(feat) + label_dep = self.label_dep_mlp(feat) + label_head = self.label_head_mlp(feat) # biaffine arc classifier arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]