
The current version's word segmentation accuracy now reaches the normal segmentation scores.

tags/v0.2.0
yh_cc committed 5 years ago · commit ea1c8c1100
3 changed files with 30 additions and 14 deletions
  1. +2  -1   fastNLP/core/sampler.py
  2. +2  -2   reproduction/chinese_word_segment/process/cws_processor.py
  3. +26 -11  reproduction/chinese_word_segment/train_context.py

+2 -1  fastNLP/core/sampler.py

@@ -78,7 +78,8 @@ class BucketSampler(BaseSampler):
             for i in range(num_batch_per_bucket):
                 batchs.append(left_init_indexes[i*self.batch_size:(i+1)*self.batch_size])
             left_init_indexes = left_init_indexes[num_batch_per_bucket*self.batch_size:]
-
+        if (left_init_indexes)!=0:
+            batchs.append(left_init_indexes)
         np.random.shuffle(batchs)

         return list(chain(*batchs))
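
One note on the new check: in Python, a list compared to an integer with != is always unequal, so (left_init_indexes)!=0 is true even for an empty list; a length (or plain truthiness) test is presumably what is meant. A two-line illustration:

    left_init_indexes = []
    print(left_init_indexes != 0)        # True even though the list is empty
    print(len(left_init_indexes) != 0)   # False -- the emptiness check likely intended

As committed, the only effect is that an empty trailing batch may be appended, which chain() then flattens away.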


+2 -2  reproduction/chinese_word_segment/process/cws_processor.py

@@ -182,10 +182,10 @@ class Pre2Post2BigramProcessor(BigramProcessor):
 # Processor了


 class VocabProcessor(Processor):
-    def __init__(self, field_name):
+    def __init__(self, field_name, min_count=1, max_vocab_size=None):
         super(VocabProcessor, self).__init__(field_name, None)
-        self.vocab = Vocabulary()
+        self.vocab = Vocabulary(min_freq=min_count, max_size=max_vocab_size)

     def process(self, *datasets):
         for dataset in datasets:
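
The two new keyword arguments go straight into fastNLP's Vocabulary: min_count becomes min_freq (entries below that frequency are left out of the vocabulary) and max_vocab_size becomes max_size. A small usage sketch mirroring the bigram vocabulary set up later in this commit (process() is the method defined in the hunk above; tr_dataset is the training DataSet loaded in train_context.py):

    from reproduction.chinese_word_segment.process.cws_processor import VocabProcessor

    char_vocab_proc = VocabProcessor('chars_list')                    # keep every character
    bigram_vocab_proc = VocabProcessor('bigrams_list', min_count=4)   # drop bigrams seen fewer than 4 times

    char_vocab_proc.process(tr_dataset)
    bigram_vocab_proc.process(tr_dataset)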


+26 -11  reproduction/chinese_word_segment/train_context.py

@@ -11,11 +11,15 @@ from reproduction.chinese_word_segment.process.cws_processor import SeqLenProcessor
 from reproduction.chinese_word_segment.process.span_converter import AlphaSpanConverter
 from reproduction.chinese_word_segment.process.span_converter import DigitSpanConverter
+from reproduction.chinese_word_segment.process.span_converter import TimeConverter
+from reproduction.chinese_word_segment.process.span_converter import MixNumAlphaConverter
+from reproduction.chinese_word_segment.process.span_converter import EmailConverter
 from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader
 from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp


-tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_train.txt'
-dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_dev.txt'
+ds_name = 'pku'
+tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name)
+dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name)


 reader = NaiveCWSReader()


@@ -27,8 +31,12 @@ dev_dataset = reader.load(dev_filename)
 fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')

 sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
+sp_proc.add_span_converter(EmailConverter())
+sp_proc.add_span_converter(MixNumAlphaConverter())
 sp_proc.add_span_converter(AlphaSpanConverter())
 sp_proc.add_span_converter(DigitSpanConverter())
+sp_proc.add_span_converter(TimeConverter())
+


 char_proc = CWSCharSegProcessor('sentence', 'chars_list')


@@ -37,7 +45,7 @@ tag_proc = CWSSegAppTagProcessor('sentence', 'tags')
 bigram_proc = Pre2Post2BigramProcessor('chars_list', 'bigrams_list')

 char_vocab_proc = VocabProcessor('chars_list')
-bigram_vocab_proc = VocabProcessor('bigrams_list')
+bigram_vocab_proc = VocabProcessor('bigrams_list', min_count=4)

 # 2. 使用processor
 fs2hs_proc(tr_dataset)
@@ -74,6 +82,8 @@ bigram_index_proc(dev_dataset)
 seq_len_proc(dev_dataset)

 print("Finish preparing data.")
+print("Vocab size:{}, bigram size:{}.".format(char_vocab_proc.get_vocab_size(), bigram_vocab_proc.get_vocab_size()))
+

 # 3. 得到数据集可以用于训练了
 from itertools import chain
@@ -89,11 +99,10 @@ def flat_nested_list(nested_list):
     return list(chain(*nested_list))

 def calculate_pre_rec_f1(model, batcher):
-    true_ys, pred_ys, seq_lens = decode_iterator(model, batcher)
-    refined_true_ys = refine_ys_on_seq_len(true_ys, seq_lens)
-    refined_pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens)
-    true_ys = flat_nested_list(refined_true_ys)
-    pred_ys = flat_nested_list(refined_pred_ys)
+    true_ys, pred_ys = decode_iterator(model, batcher)
+
+    true_ys = flat_nested_list(true_ys)
+    pred_ys = flat_nested_list(pred_ys)

     cor_num = 0
     yp_wordnum = pred_ys.count(1)
@@ -134,7 +143,10 @@ def decode_iterator(model, batcher):
             seq_lens.extend(list(seq_len))
     model.train()

-    return true_ys, pred_ys, seq_lens
+    true_ys = refine_ys_on_seq_len(true_ys, seq_lens)
+    pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens)
+
+    return true_ys, pred_ys

 # TODO pretrain的embedding是怎么解决的?
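
For reference, calculate_pre_rec_f1 turns the flattened tag sequences into a word-level precision/recall/F1, with pred_ys.count(1) serving as the number of predicted words. A minimal sketch of that metric, assuming tag 1 marks the last character of a word and tag 0 means "append to the current word"; the repo's exact counting loop may differ:

    def tags_to_spans(tags):
        # Collect (start, end) character spans of words, assuming tag 1 ends a word.
        spans, start = [], 0
        for i, t in enumerate(tags):
            if t == 1:
                spans.append((start, i))
                start = i + 1
        return spans

    def word_prf(true_tags, pred_tags):
        gold = set(tags_to_spans(true_tags))
        pred = set(tags_to_spans(pred_tags))
        cor_num = len(gold & pred)                  # words with identical boundaries
        precision = cor_num / (len(pred) + 1e-6)
        recall = cor_num / (len(gold) + 1e-6)
        return precision, recall, 2 * precision * recall / (precision + recall + 1e-6)

    # gold "AB|C|DE" vs. predicted "AB|CDE": only "AB" is recovered exactly
    print(word_prf([0, 1, 1, 0, 1], [0, 1, 0, 0, 1]))   # P=0.5, R=0.33, F=0.4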


@@ -161,7 +173,7 @@ cws_model.cuda()

 num_epochs = 3
 loss_fn = FocalLoss(class_num=tag_size)
-optimizer = optim.Adagrad(cws_model.parameters(), lr=0.01)
+optimizer = optim.Adagrad(cws_model.parameters(), lr=0.02)


 print_every = 50
@@ -179,6 +191,8 @@ for num_epoch in range(num_epochs):
     pbar.set_description_str('Epoch:%d' % (num_epoch + 1))
     cws_model.train()
     for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
+        optimizer.zero_grad()
+
         pred_dict = cws_model(batch_x) # B x L x tag_size

         seq_lens = pred_dict['seq_lens']
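
The added optimizer.zero_grad() clears gradients at the start of every batch; without it, .backward() keeps accumulating gradients across iterations and the updates drift. The standard PyTorch step order, shown as a self-contained toy sketch (placeholder model and data, not the variables from train_context.py):

    import torch
    from torch import nn, optim

    model = nn.Linear(8, 2)                        # toy stand-in for the CWS model
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adagrad(model.parameters(), lr=0.02)

    for step in range(3):
        batch_x = torch.randn(4, 8)
        batch_y = torch.randint(0, 2, (4,))

        optimizer.zero_grad()                      # drop gradients left from the previous batch
        loss = loss_fn(model(batch_x), batch_y)
        loss.backward()                            # accumulate fresh gradients for this batch only
        optimizer.step()                           # apply the update
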
@@ -217,6 +231,7 @@ for num_epoch in range(num_epochs):
             }
             best_epoch = num_epoch

+cws_model.load_state_dict(best_state_dict)

 # 4. 组装需要存下的内容
 pp = Pipeline()
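
The new load_state_dict call restores the weights of the best dev epoch before the pipeline is assembled and the test set is decoded. One general PyTorch caveat worth noting here (how best_state_dict is built is outside this diff): state_dict() returns references to the live parameter tensors, so a snapshot taken mid-training should be deep-copied if training continues afterwards. A minimal sketch with placeholder names:

    import copy
    from torch import nn

    model = nn.Linear(4, 2)                               # placeholder model
    best_state_dict = copy.deepcopy(model.state_dict())   # snapshot that will not track later updates

    # ... training continues; refresh the snapshot whenever dev F1 improves ...

    model.load_state_dict(best_state_dict)                # restore best-dev weights before export
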
@@ -229,7 +244,7 @@ pp.add_processor(char_index_proc)
 pp.add_processor(bigram_index_proc)
 pp.add_processor(seq_len_proc)

-te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_test.txt'
+te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
 te_dataset = reader.load(te_filename)
 pp(te_dataset)




