|
@@ -175,12 +175,11 @@ class LabelBilinear(nn.Module): |
|
|
def __init__(self, in1_features, in2_features, num_label, bias=True): |
|
|
def __init__(self, in1_features, in2_features, num_label, bias=True): |
|
|
super(LabelBilinear, self).__init__() |
|
|
super(LabelBilinear, self).__init__() |
|
|
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) |
|
|
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) |
|
|
self.lin1 = nn.Linear(in1_features, num_label, bias=False) |
|
|
|
|
|
self.lin2 = nn.Linear(in2_features, num_label, bias=False) |
|
|
|
|
|
|
|
|
self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) |
|
|
|
|
|
|
|
|
def forward(self, x1, x2): |
|
|
def forward(self, x1, x2): |
|
|
output = self.bilinear(x1, x2) |
|
|
output = self.bilinear(x1, x2) |
|
|
output += self.lin1(x1) + self.lin2(x2) |
|
|
|
|
|
|
|
|
output += self.lin(torch.cat([x1, x2], dim=2)) |
|
|
return output |
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -226,15 +225,16 @@ class BiaffineParser(GraphParser): |
|
|
|
|
|
|
|
|
rnn_out_size = 2 * rnn_hidden_size |
|
|
rnn_out_size = 2 * rnn_hidden_size |
|
|
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), |
|
|
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), |
|
|
nn.ELU()) |
|
|
|
|
|
|
|
|
nn.ELU(), |
|
|
|
|
|
TimestepDropout(p=dropout),) |
|
|
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) |
|
|
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) |
|
|
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), |
|
|
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), |
|
|
nn.ELU()) |
|
|
|
|
|
|
|
|
nn.ELU(), |
|
|
|
|
|
TimestepDropout(p=dropout),) |
|
|
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) |
|
|
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) |
|
|
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) |
|
|
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) |
|
|
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) |
|
|
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) |
|
|
self.normal_dropout = nn.Dropout(p=dropout) |
|
|
self.normal_dropout = nn.Dropout(p=dropout) |
|
|
self.timestep_dropout = TimestepDropout(p=dropout) |
|
|
|
|
|
self.use_greedy_infer = use_greedy_infer |
|
|
self.use_greedy_infer = use_greedy_infer |
|
|
initial_parameter(self) |
|
|
initial_parameter(self) |
|
|
|
|
|
|
|
@@ -267,10 +267,10 @@ class BiaffineParser(GraphParser): |
|
|
|
|
|
|
|
|
# for arc biaffine |
|
|
# for arc biaffine |
|
|
# mlp, reduce dim |
|
|
# mlp, reduce dim |
|
|
arc_dep = self.timestep_dropout(self.arc_dep_mlp(feat)) |
|
|
|
|
|
arc_head = self.timestep_dropout(self.arc_head_mlp(feat)) |
|
|
|
|
|
label_dep = self.timestep_dropout(self.label_dep_mlp(feat)) |
|
|
|
|
|
label_head = self.timestep_dropout(self.label_head_mlp(feat)) |
|
|
|
|
|
|
|
|
arc_dep = self.arc_dep_mlp(feat) |
|
|
|
|
|
arc_head = self.arc_head_mlp(feat) |
|
|
|
|
|
label_dep = self.label_dep_mlp(feat) |
|
|
|
|
|
label_head = self.label_head_mlp(feat) |
|
|
|
|
|
|
|
|
# biaffine arc classifier |
|
|
# biaffine arc classifier |
|
|
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] |
|
|
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] |
|
|