Browse Source

[bugfix]修复 coreference resolution复现代码中参数名字不对应的bug (#323)

* pipeline

* 修复找不到对应参数的bug

* 增加requirement文件
tags/v0.6.0
liuxiaoxiong GitHub 4 years ago
parent
commit
acdebfccbc
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 7 deletions
  1. +2
    -1
      reproduction/coreference_resolution/model/metric.py
  2. +23
    -6
      reproduction/coreference_resolution/model/model_re.py
  3. +5
    -0
      reproduction/coreference_resolution/requirements.txt
  4. +2
    -0
      reproduction/coreference_resolution/train.py

+ 2
- 1
reproduction/coreference_resolution/model/metric.py View File

@@ -17,7 +17,8 @@ class CRMetric(MetricBase):
self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)] self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)]


# TODO 改名为evaluate,输入也 # TODO 改名为evaluate,输入也
def evaluate(self, predicted, mention_to_predicted,clusters):
def evaluate(self, predicted, mention_to_predicted,target):
clusters = target
for e in self.evaluators: for e in self.evaluators:
e.update(predicted,mention_to_predicted, clusters) e.update(predicted,mention_to_predicted, clusters)




+ 23
- 6
reproduction/coreference_resolution/model/model_re.py View File

@@ -17,7 +17,6 @@ torch.cuda.manual_seed(0) # gpu
np.random.seed(0) # numpy np.random.seed(0) # numpy
random.seed(0) random.seed(0)



class ffnn(nn.Module): class ffnn(nn.Module):
def __init__(self, input_size, hidden_size, output_size): def __init__(self, input_size, hidden_size, output_size):
super(ffnn, self).__init__() super(ffnn, self).__init__()
@@ -565,19 +564,37 @@ class Model(BaseModel):


return ans return ans


def predict(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len):
def predict(self, words1 , words2, words3, words4, chars, seq_len):
"""
实际输入都是tensor
:param sentences: 句子,被fastNLP转化成了numpy,
:param doc_np: 被fastNLP转化成了Tensor
:param speaker_ids_np: 被fastNLP转化成了Tensor
:param genre: 被fastNLP转化成了Tensor
:param char_index: 被fastNLP转化成了Tensor
:param seq_len: 被fastNLP转化成了Tensor
:return:
"""
sentences = words1
doc_np = words2
speaker_ids_np = words3
genre = words4
char_index = chars

# def predict(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len):
ans = self(sentences, ans = self(sentences,
doc_np, doc_np,
speaker_ids_np, speaker_ids_np,
genre, genre,
char_index, char_index,
seq_len) seq_len)

predicted_antecedents = self.get_predicted_antecedents(ans["antecedents"], ans["antecedent_scores"])
predicted_clusters, mention_to_predicted = self.get_predicted_clusters(ans["mention_start_tensor"],
ans["mention_end_tensor"],
predicted_antecedents = self.get_predicted_antecedents(ans["antecedents"], ans["antecedent_scores"].cpu())
predicted_clusters, mention_to_predicted = self.get_predicted_clusters(ans["mention_start_tensor"].cpu(),
ans["mention_end_tensor"].cpu(),
predicted_antecedents) predicted_antecedents)



return {'predicted':predicted_clusters,"mention_to_predicted":mention_to_predicted} return {'predicted':predicted_clusters,"mention_to_predicted":mention_to_predicted}






+ 5
- 0
reproduction/coreference_resolution/requirements.txt View File

@@ -0,0 +1,5 @@
prettytable==0.7.2
allennlp==0.8.2
scikit-learn==0.22.2
pyhocon==0.3.50
torch==1.1

+ 2
- 0
reproduction/coreference_resolution/train.py View File

@@ -1,3 +1,5 @@
import sys
sys.path.append('../..')


import torch import torch
from torch.optim import Adam from torch.optim import Adam


Loading…
Cancel
Save