| @@ -175,12 +175,11 @@ class LabelBilinear(nn.Module): | |||||
| def __init__(self, in1_features, in2_features, num_label, bias=True): | def __init__(self, in1_features, in2_features, num_label, bias=True): | ||||
| super(LabelBilinear, self).__init__() | super(LabelBilinear, self).__init__() | ||||
| self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) | self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) | ||||
| self.lin1 = nn.Linear(in1_features, num_label, bias=False) | |||||
| self.lin2 = nn.Linear(in2_features, num_label, bias=False) | |||||
| self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) | |||||
| def forward(self, x1, x2): | def forward(self, x1, x2): | ||||
| output = self.bilinear(x1, x2) | output = self.bilinear(x1, x2) | ||||
| output += self.lin1(x1) + self.lin2(x2) | |||||
| output += self.lin(torch.cat([x1, x2], dim=2)) | |||||
| return output | return output | ||||
| @@ -226,15 +225,16 @@ class BiaffineParser(GraphParser): | |||||
| rnn_out_size = 2 * rnn_hidden_size | rnn_out_size = 2 * rnn_hidden_size | ||||
| self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), | self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), | ||||
| nn.ELU()) | |||||
| nn.ELU(), | |||||
| TimestepDropout(p=dropout),) | |||||
| self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) | self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) | ||||
| self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), | self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), | ||||
| nn.ELU()) | |||||
| nn.ELU(), | |||||
| TimestepDropout(p=dropout),) | |||||
| self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | ||||
| self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | ||||
| self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | ||||
| self.normal_dropout = nn.Dropout(p=dropout) | self.normal_dropout = nn.Dropout(p=dropout) | ||||
| self.timestep_dropout = TimestepDropout(p=dropout) | |||||
| self.use_greedy_infer = use_greedy_infer | self.use_greedy_infer = use_greedy_infer | ||||
| initial_parameter(self) | initial_parameter(self) | ||||
| @@ -267,10 +267,10 @@ class BiaffineParser(GraphParser): | |||||
| # for arc biaffine | # for arc biaffine | ||||
| # mlp, reduce dim | # mlp, reduce dim | ||||
| arc_dep = self.timestep_dropout(self.arc_dep_mlp(feat)) | |||||
| arc_head = self.timestep_dropout(self.arc_head_mlp(feat)) | |||||
| label_dep = self.timestep_dropout(self.label_dep_mlp(feat)) | |||||
| label_head = self.timestep_dropout(self.label_head_mlp(feat)) | |||||
| arc_dep = self.arc_dep_mlp(feat) | |||||
| arc_head = self.arc_head_mlp(feat) | |||||
| label_dep = self.label_dep_mlp(feat) | |||||
| label_head = self.label_head_mlp(feat) | |||||
| # biaffine arc classifier | # biaffine arc classifier | ||||
| arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | ||||