Browse Source

创建了一个测试context

tags/v0.2.0
yh_cc 6 years ago
parent
commit
3e50ca8a72
3 changed files with 34 additions and 2 deletions
  1. +0
    -1
      fastNLP/api/cws.py
  2. +28
    -0
      reproduction/chinese_word_segment/testcontext.py
  3. +6
    -1
      reproduction/chinese_word_segment/train_context.py

+ 0
- 1
fastNLP/api/cws.py View File

@@ -30,4 +30,3 @@ class CWS(API):
# 4. TODO: this should be handed to something iterator-like to predict the result

# 5. TODO: once the result is obtained, consider whether it must be reversed back, and the post_process step

+ 28
- 0
reproduction/chinese_word_segment/testcontext.py View File

@@ -0,0 +1,28 @@


import torch
from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.batch import Batch
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1

# Corpus whose raw test split is scored by this script.
ds_name = 'ncc'

# Restore the saved test context: a dict with a 'pipeline' entry and a 'model' entry.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
test_dict = torch.load('models/test_context.pkl')

pp = test_dict['pipeline']
# The model is moved to the GPU right after it is taken out of the context dict.
model = test_dict['model'].cuda()

# Build the path of the raw test file; the corpus name fills all three placeholders.
te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/{}_raw_data/{}_raw_test.txt'.format(
    ds_name, ds_name, ds_name)

# Read the test file, then run the restored pipeline over the dataset.
# The pipeline's return value is discarded (presumably it updates te_dataset in place -- confirm).
reader = NaiveCWSReader()
te_dataset = reader.load(te_filename)
pp(te_dataset)

# Walk the test set in order, in fixed-size batches, and score the model on it.
batch_size = 64
te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
pre, rec, f1 = calculate_pre_rec_f1(model, te_batcher)

# Report the three scores as percentages with two decimals.
report = "f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100, pre * 100, rec * 100)
print(report)

+ 6
- 1
reproduction/chinese_word_segment/train_context.py View File

@@ -19,7 +19,7 @@ from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp

from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1

# Dataset identifier interpolated into the train/dev paths below.
# NOTE(review): 'pku' is overwritten by 'msr' on the very next line, so only 'msr'
# takes effect here. The two lines look like the old/new sides of a diff hunk --
# confirm which one the file actually keeps.
ds_name = 'pku'
ds_name = 'msr'
# Train and dev file paths, both built from ds_name under that dataset's middle_files directory.
tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name)
dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name)

@@ -197,6 +197,11 @@ print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,

# TODO: it looks like the test pipeline and the dev pipeline need to be distinguished here

# Bundle the preprocessing pipeline and the model into one dict, then save it to
# 'models/test_context.pkl' with torch.save.
test_context_dict = {}
test_context_dict['pipeline'] = pp
test_context_dict['model'] = cws_model
torch.save(test_context_dict, 'models/test_context.pkl')


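# TODO: still need to work out how to map the output back onto the original text:
# 1. cases where the special tags do not need to be replaced
# 2. cases where the special tags need to be replaced back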

Loading…
Cancel
Save