- import sys
- sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])
- import torch
- import argparse
- from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
- from fastNLP.core.dataset import DataSet
- from fastNLP.core.instance import Instance
# Command-line interface.  All three paths are mandatory: without them the
# script would only fail later on an empty or ``None`` path (for --new_data,
# at the final open() after the whole dataset has been processed).
parser = argparse.ArgumentParser()
parser.add_argument('--pipe', type=str, required=True,
                    help="file given to torch.load; must hold a 'pipeline' entry")
parser.add_argument('--gold_data', type=str, required=True,
                    help='gold-annotated file read with ConllxDataLoader')
parser.add_argument('--new_data', type=str, required=True,
                    help='output file that receives the selected sentences')
args = parser.parse_args()
# Load the saved processing pipeline.
# NOTE(review): torch.load unpickles the file, so only point --pipe at
# files from a trusted source.
pipe = torch.load(args.pipe)['pipeline']

# Processors whose field is 'word_list' / 'pos_list' are redirected to the
# gold fields (presumably so they act on the gold columns of the dataset
# built from --gold_data -- confirm against the pipeline definition).
field_renames = (
    ('word_list', 'gold_words'),
    ('pos_list', 'gold_pos'),
)
for proc in pipe:
    for old_name, new_name in field_renames:
        if proc.field_name == old_name:
            # Report the original field name before overwriting it.
            print(proc.field_name)
            proc.field_name = new_name
            break
# Build one Instance per sentence: fields 0 and 1 of the add_seg_tag()
# sample become 'words' and 'tag', fields 0-3 of the loader sample at the
# same position become the gold_* fields.
data = ConllxDataLoader().load(args.gold_data)
ds = DataSet()
for seg_sample, gold_sample in zip(add_seg_tag(data), data):
    instance_fields = dict(
        words=seg_sample[0],
        tag=seg_sample[1],
        gold_words=gold_sample[0],
        gold_pos=gold_sample[1],
        gold_heads=gold_sample[2],
        gold_head_tags=gold_sample[3],
    )
    ds.append(Instance(**instance_fields))

# Run the loaded pipeline over the whole dataset.
ds = pipe(ds)
# Per-instance score cut-offs used by is_ok(): an instance is kept only
# when each score is strictly greater than its threshold.
seg_threshold, pos_threshold, parse_threshold = 0.0, 0.0, 0.74
def get_heads(ins, head_f, word_f):
    """Map every head index of an instance to an entry of its word field.

    Args:
        ins: Instance (or mapping) indexable by field name.
        head_f: Name of the field holding the head indices.
        word_f: Name of the field holding the words.

    Returns:
        A list with one entry per head index.  A head index of 0 selects
        the word at the token's own position; any other index ``h`` selects
        the word at position ``h - 1``.
    """
    return [
        ins[word_f][head - 1 if head != 0 else position]
        for position, head in enumerate(ins[head_f])
    ]
def evaluate(ins):
    """Score one instance's predictions against its gold annotations.

    Predicted and gold sequences are compared position by position, so a
    prediction only counts when it sits at the same index as the gold item.

    Args:
        ins: Instance (or mapping) providing the predicted fields
            ``word_list``, ``pos_list``, ``heads`` and the reference fields
            ``gold_words``, ``gold_pos``, ``gold_heads``.

    Returns:
        A ``(seg, pos, head)`` tuple: the number of position-wise matches
        for words, POS tags and heads, each divided by the number of gold
        words.  An instance without gold words scores ``(0.0, 0.0, 0.0)``
        instead of raising ``ZeroDivisionError``.
    """
    total = len(ins['gold_words'])
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0, 0.0, 0.0

    def matches(pred_field, gold_field):
        # Count positions where prediction and reference agree.
        return sum(
            1 for pred, gold in zip(ins[pred_field], ins[gold_field])
            if pred == gold
        )

    return (
        matches('word_list', 'gold_words') / total,
        matches('pos_list', 'gold_pos') / total,
        matches('heads', 'gold_heads') / total,
    )
def is_ok(x):
    """Tell whether a scored instance clears all three thresholds.

    Args:
        x: Sequence whose element at index 1 is a ``(seg, pos, head)``
            score triple.

    Returns:
        True only when the three scores are strictly greater than the
        module-level ``seg_threshold``, ``pos_threshold`` and
        ``parse_threshold`` respectively.
    """
    seg_score, pos_score, head_score = x[1]
    if not seg_score > seg_threshold:
        return False
    if not pos_score > pos_threshold:
        return False
    return head_score > parse_threshold
# Score every instance, then keep the (index, scores) pairs accepted by
# is_ok().  Prints the dataset size followed by the number kept.
scored = [(index, evaluate(instance)) for index, instance in enumerate(ds)]
res_list = [entry for entry in scored if is_ok(entry)]
print('{} {}'.format(len(ds), len(res_list)))
# Corpus-level scores over the instances that passed the filter.
seg_cor, pos_cor, head_cor, label_cor, total = 0, 0, 0, 0, 0
for i, _ in res_list:
    ins = ds[i]
    head_pred = ins['heads']
    head_gold = ins['gold_heads']
    label_pred = ins['labels']
    label_gold = ins['gold_head_tags']
    length = len(head_gold)
    total += length
    seg_cor += sum(
        1 for pred, gold in zip(ins['word_list'], ins['gold_words'])
        if pred == gold
    )
    pos_cor += sum(
        1 for pred, gold in zip(ins['pos_list'], ins['gold_pos'])
        if pred == gold
    )
    # Use a separate index name so the instance index ``i`` is not clobbered.
    for k in range(length):
        if head_pred[k] == head_gold[k]:
            head_cor += 1
            # A label is only counted when its head is correct as well.
            if label_gold[k] == label_pred[k]:
                label_cor += 1

# When no instance was selected every counter is 0; dividing by 1 then
# reports 0.0 for each score instead of raising ZeroDivisionError.
denominator = total if total else 1
print('SEG: {}, POS: {}, UAS: {}, LAS: {}'.format(
    seg_cor / denominator, pos_cor / denominator,
    head_cor / denominator, label_cor / denominator))
# Write the gold annotations of every selected instance to the output file.
# Each token becomes one tab-separated line with the columns
#   index (1-based), word, word, POS, '_', '_', head, head tag
# and each sentence is followed by a blank line.
colln_path = args.gold_data  # kept as a name; the gold file is not re-read here
new_colln_path = args.new_data
index_list = [x[0] for x in res_list]
# Set membership keeps the per-instance check O(1).
selected_indices = set(index_list)
with open(new_colln_path, 'w', encoding='utf-8') as f2:
    for idx, ins in enumerate(ds):
        if idx not in selected_indices:
            continue
        length = len(ins['gold_words'])
        pad = ['_'] * length
        rows = zip(
            map(str, range(1, length + 1)),
            ins['gold_words'],
            ins['gold_words'],
            ins['gold_pos'],
            pad,
            pad,
            map(str, ins['gold_heads']),
            ins['gold_head_tags'],
        )
        for row in rows:
            f2.write('\t'.join(row))
            f2.write('\n')
        f2.write('\n')