@@ -134,5 +134,5 @@ if tqdm is None: | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context.pkl', model_dir='.') | |||||
pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context-4e86fd93.pkl', model_dir='.') | |||||
print(type(pipeline)) | print(type(pipeline)) |
@@ -90,7 +90,7 @@ class CWSBiLSTMSegApp(BaseModel): | |||||
self.encoder_model = CWSBiLSTMEncoder(vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char, | self.encoder_model = CWSBiLSTMEncoder(vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char, | ||||
hidden_size, bidirectional, embed_drop_p, num_layers) | hidden_size, bidirectional, embed_drop_p, num_layers) | ||||
size_layer = [hidden_size, 100, tag_size] | |||||
size_layer = [hidden_size, 200, tag_size] | |||||
self.decoder_model = MLP(size_layer) | self.decoder_model = MLP(size_layer) | ||||
@@ -194,6 +194,7 @@ class VocabProcessor(Processor): | |||||
tokens = ins[self.field_name] | tokens = ins[self.field_name] | ||||
self.vocab.update(tokens) | self.vocab.update(tokens) | ||||
def get_vocab(self): | def get_vocab(self): | ||||
self.vocab.build_vocab() | self.vocab.build_vocab() | ||||
return self.vocab | return self.vocab | ||||
@@ -6,23 +6,42 @@ from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 | from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 | ||||
ds_name = 'ncc' | |||||
def f1(): | |||||
ds_name = 'pku' | |||||
test_dict = torch.load('models/test_context.pkl') | |||||
test_dict = torch.load('models/test_context.pkl') | |||||
pp = test_dict['pipeline'] | |||||
model = test_dict['model'].cuda() | |||||
pp = test_dict['pipeline'] | |||||
model = test_dict['model'].cuda() | |||||
reader = NaiveCWSReader() | |||||
te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/{}_raw_data/{}_raw_test.txt'.format(ds_name, ds_name, | |||||
ds_name) | |||||
te_dataset = reader.load(te_filename) | |||||
pp(te_dataset) | |||||
reader = NaiveCWSReader() | |||||
te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/{}_raw_data/{}_raw_test.txt'.format(ds_name, ds_name, | |||||
ds_name) | |||||
te_dataset = reader.load(te_filename) | |||||
pp(te_dataset) | |||||
batch_size = 64 | |||||
te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False) | |||||
pre, rec, f1 = calculate_pre_rec_f1(model, te_batcher) | |||||
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100, | |||||
pre * 100, | |||||
rec * 100)) | |||||
batch_size = 64 | |||||
te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False) | |||||
pre, rec, f1 = calculate_pre_rec_f1(model, te_batcher) | |||||
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100, | |||||
pre * 100, | |||||
rec * 100)) | |||||
def f2():
    """Segment each benchmark corpus with a saved CWS model and write the output files.

    Builds a ``CWS`` object from ``models/maml-cws.pkl``. For every dataset
    name in the list below, reads all lines of ``<name>_raw.txt``, passes them
    to ``cws.predict`` and writes each returned item to ``<name>_seg.txt``.
    Takes no arguments and returns nothing; its only effects are printing the
    dataset name and writing the output files.
    """
    # Imported inside the function, as in the original, rather than at module level.
    from fastNLP.api.api import CWS

    cws = CWS('models/maml-cws.pkl')
    datasets = ['msr', 'as', 'pku', 'ctb', 'ncc', 'cityu', 'ckip', 'sxu']
    for dataset in datasets:
        print(dataset)
        # Read with an explicit encoding so decoding does not depend on the
        # platform's default locale; the output below is already written as UTF-8.
        with open('/hdd/fudanNLP/CWS/others/benchmark/raw_and_gold/{}_raw.txt'.format(dataset), 'r',
                  encoding='utf-8') as f:
            lines = f.readlines()
        results = cws.predict(lines)
        with open('/hdd/fudanNLP/CWS/others/benchmark/fastNLP_output/{}_seg.txt'.format(dataset), 'w',
                  encoding='utf-8') as f:
            # NOTE(review): items are written exactly as returned, with no separator
            # added -- presumably predict() keeps each line's trailing newline; confirm.
            f.writelines(results)
f1() |
@@ -19,10 +19,15 @@ from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp | |||||
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 | from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 | ||||
ds_name = 'msr' | |||||
tr_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, | |||||
ds_name = 'pku' | |||||
# tr_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, | |||||
# ds_name) | |||||
# dev_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, | |||||
# ds_name) | |||||
tr_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_train.txt'.format(ds_name, | |||||
ds_name) | ds_name) | ||||
dev_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, | |||||
dev_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, | |||||
ds_name) | ds_name) | ||||
reader = NaiveCWSReader() | reader = NaiveCWSReader() | ||||
@@ -189,7 +194,7 @@ pp.add_processor(seq_len_proc) | |||||
te_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name) | |||||
te_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name) | |||||
te_dataset = reader.load(te_filename) | te_dataset = reader.load(te_filename) | ||||
pp(te_dataset) | pp(te_dataset) | ||||
@@ -231,9 +236,8 @@ pp.add_processor(output_proc) | |||||
# TODO it seems the test pipeline and the infer pipeline need to be distinguished here | # TODO it seems the test pipeline and the infer pipeline need to be distinguished here | ||||
infer_context_dict = {'pipeline': pp, | |||||
'model': cws_model} | |||||
torch.save(infer_context_dict, 'models/infer_context.pkl') | |||||
infer_context_dict = {'pipeline': pp} | |||||
torch.save(infer_context_dict, 'models/infer_cws.pkl') | |||||
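# TODO still need to consider how to substitute the results back into the original text | # TODO still need to consider how to substitute the results back into the original text | ||||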
@@ -34,19 +34,27 @@ def calculate_pre_rec_f1(model, batcher): | |||||
yp_wordnum = pred_ys.count(1) | yp_wordnum = pred_ys.count(1) | ||||
yt_wordnum = true_ys.count(1) | yt_wordnum = true_ys.count(1) | ||||
start = 0 | start = 0 | ||||
for i in range(len(true_ys)): | |||||
if true_ys[0]==1 and pred_ys[0]==1: | |||||
cor_num += 1 | |||||
start = 1 | |||||
for i in range(1, len(true_ys)): | |||||
if true_ys[i] == 1: | if true_ys[i] == 1: | ||||
flag = True | flag = True | ||||
for j in range(start, i + 1): | |||||
if true_ys[j] != pred_ys[j]: | |||||
flag = False | |||||
break | |||||
if true_ys[start-1] != pred_ys[start-1]: | |||||
flag = False | |||||
else: | |||||
for j in range(start, i + 1): | |||||
if true_ys[j] != pred_ys[j]: | |||||
flag = False | |||||
break | |||||
if flag: | if flag: | ||||
cor_num += 1 | cor_num += 1 | ||||
start = i + 1 | start = i + 1 | ||||
P = cor_num / (float(yp_wordnum) + 1e-6) | P = cor_num / (float(yp_wordnum) + 1e-6) | ||||
R = cor_num / (float(yt_wordnum) + 1e-6) | R = cor_num / (float(yt_wordnum) + 1e-6) | ||||
F = 2 * P * R / (P + R + 1e-6) | F = 2 * P * R / (P + R + 1e-6) | ||||
print(cor_num, yt_wordnum, yp_wordnum) | |||||
return P, R, F | return P, R, F | ||||