
Fixed several issues encountered along the way and added some methods for the word segmentation task

tags/v0.2.0
yh_cc · 5 years ago · parent commit 69a138eb18
8 changed files with 212 additions and 64 deletions
  1. fastNLP/api/processor.py (+8, -2)
  2. fastNLP/core/batch.py (+9, -8)
  3. fastNLP/core/dataset.py (+3, -1)
  4. reproduction/chinese_word_segment/cws_io/__init__.py (renamed, +0, -0)
  5. reproduction/chinese_word_segment/cws_io/cws_reader.py (renamed, +0, -0)
  6. reproduction/chinese_word_segment/models/cws_model.py (+11, -14)
  7. reproduction/chinese_word_segment/process/cws_processor.py (+27, -9)
  8. reproduction/chinese_word_segment/train_context.py (+154, -30)

fastNLP/api/processor.py (+8, -2)

@@ -67,7 +67,7 @@ class FullSpaceToHalfSpaceProcessor(Processor):
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
sentence = ins[self.field_name].text
sentence = ins[self.field_name]
new_sentence = [None]*len(sentence)
for idx, char in enumerate(sentence):
if char in self.convert_map:
@@ -78,12 +78,13 @@ class FullSpaceToHalfSpaceProcessor(Processor):


class IndexerProcessor(Processor):
def __init__(self, vocab, field_name, new_added_field_name):
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False):

assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

super(IndexerProcessor, self).__init__(field_name, new_added_field_name)
self.vocab = vocab
self.delete_old_field = delete_old_field

def set_vocab(self, vocab):
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
@@ -97,6 +98,11 @@ class IndexerProcessor(Processor):
index = [self.vocab.to_index(token) for token in tokens]
ins[self.new_added_field_name] = index

dataset.set_need_tensor(**{self.new_added_field_name:True})

if self.delete_old_field:
dataset.delete_field(self.field_name)

return dataset
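
A minimal usage sketch of the extended IndexerProcessor (an illustration, not part of the commit; vocab and ds are assumed to exist, while the field names match train_context.py below):

from fastNLP.api.processor import IndexerProcessor

# vocab: a fastNLP Vocabulary built over the characters (e.g. char_vocab_proc.get_vocab())
# ds: a DataSet whose instances already carry a 'chars_list' token field
index_proc = IndexerProcessor(vocab, 'chars_list', 'indexed_chars_list',
                              delete_old_field=True)
ds = index_proc(ds)   # adds 'indexed_chars_list', flags it via set_need_tensor,
                      # and, with the new flag, also deletes the old 'chars_list' field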




fastNLP/core/batch.py (+9, -8)

@@ -55,14 +55,15 @@ class Batch(object):

indices = self.idx_list[self.curidx:endidx]

for field_name, field in self.dataset.get_fields():
batch = torch.from_numpy(field.get(indices))
if not field.need_tensor: # TODO: fix this
pass
elif field.is_target:
batch_y[field_name] = batch
else:
batch_x[field_name] = batch
for field_name, field in self.dataset.get_fields().items():
if field.need_tensor:
batch = torch.from_numpy(field.get(indices))
if not field.need_tensor:
pass
elif field.is_target:
batch_y[field_name] = batch
else:
batch_x[field_name] = batch

self.curidx = endidx
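
For orientation, a rough sketch of how this loop is driven (constructed the same way train_context.py below does; only fields flagged through set_need_tensor end up in the batches):

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler

# tr_dataset: a DataSet whose 'indexed_chars_list', 'seq_lens' and 'tags' fields
# have been flagged via set_need_tensor / set_is_target by the processors.
tr_batcher = Batch(tr_dataset, 32, RandomSampler(), use_cuda=False)
for batch_x, batch_y in tr_batcher:
    # batch_x: tensor fields that are not targets (e.g. 'indexed_chars_list', 'seq_lens')
    # batch_y: tensor fields marked as targets (e.g. 'tags')
    pass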



fastNLP/core/dataset.py (+3, -1)

@@ -75,11 +75,13 @@ class DataSet(object):
assert len(self) == len(fields)
self.field_arrays[name] = FieldArray(name, fields)

def delete_field(self, name):
self.field_arrays.pop(name)

def get_fields(self):
return self.field_arrays

def __getitem__(self, name):
assert name in self.field_arrays
return self.field_arrays[name]

def __len__(self):
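
A small sketch of the new delete_field helper together with get_fields (illustrative only; ds stands for any DataSet):

# Sketch: drop a field that is no longer needed and inspect what remains.
ds.delete_field('chars_list')   # remove the raw token field
for name, field_array in ds.get_fields().items():
    print(name, field_array.need_tensor, field_array.is_target)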


reproduction/chinese_word_segment/io/__init__.py → reproduction/chinese_word_segment/cws_io/__init__.py (renamed)


reproduction/chinese_word_segment/io/cws_reader.py → reproduction/chinese_word_segment/cws_io/cws_reader.py (renamed)


reproduction/chinese_word_segment/models/cws_model.py (+11, -14)

@@ -35,13 +35,6 @@ class CWSBiLSTMEncoder(BaseModel):
self.bigram_embedding = nn.Embedding(num_embeddings=bigram_vocab_num, embedding_dim=bigram_embed_dim)
self.input_size += self.num_bigram_per_char*bigram_embed_dim

if self.num_criterion!=None:
if bidirectional:
self.backward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion,
embedding_dim=self.hidden_size)
self.forward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion,
embedding_dim=self.hidden_size)

if not self.embed_drop_p is None:
self.embedding_drop = nn.Dropout(p=self.embed_drop_p)

@@ -102,13 +95,14 @@ class CWSBiLSTMSegApp(BaseModel):
self.decoder_model = MLP(size_layer)


def forward(self, **kwargs):
chars = kwargs['chars']
if 'bigram' in kwargs:
bigrams = kwargs['bigrams']
def forward(self, batch_dict):
device = self.parameters().__next__().device
chars = batch_dict['indexed_chars_list'].to(device)
if 'bigram' in batch_dict:
bigrams = batch_dict['indexed_chars_list'].to(device)
else:
bigrams = None
seq_lens = kwargs['seq_lens']
seq_lens = batch_dict['seq_lens'].to(device)

feats = self.encoder_model(chars, bigrams, seq_lens)
probs = self.decoder_model(feats)
@@ -119,6 +113,10 @@ class CWSBiLSTMSegApp(BaseModel):

return pred_dict

def predict(self, batch_dict):
pass


def loss_fn(self, pred_dict, true_dict):
seq_lens = pred_dict['seq_lens']
masks = seq_lens_to_mask(seq_lens).float()
@@ -131,5 +129,4 @@ class CWSBiLSTMSegApp(BaseModel):
true_y.view(-1), reduction='none')*masks.view(-1)/torch.sum(masks)


return loss

return loss
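
A hedged sketch of the reworked forward(batch_dict) interface (cws_model stands for the CWSBiLSTMSegApp instance built in train_context.py below; the tensor values are made up):

import torch

batch_x = {
    'indexed_chars_list': torch.LongTensor([[2, 5, 9, 1]]),  # B x L indexed characters
    'seq_lens': torch.LongTensor([4]),                        # true sequence lengths
}
pred_dict = cws_model(batch_x)   # the model moves the tensors to its own device internally
# pred_dict carries the per-character tag scores and echoes back 'seq_lens'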

reproduction/chinese_word_segment/process/cws_processor.py (+27, -9)

@@ -21,7 +21,7 @@ class SpeicalSpanProcessor(Processor):
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
sentence = ins[self.field_name].text
sentence = ins[self.field_name]
for span_converter in self.span_converters:
sentence = span_converter.find_certain_span_and_replace(sentence)
ins[self.new_added_field_name] = sentence
@@ -42,10 +42,9 @@ class CWSCharSegProcessor(Processor):
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
sentence = ins[self.field_name].text
sentence = ins[self.field_name]
chars = self._split_sent_into_chars(sentence)
new_token_field = TokenListFiled(chars, is_target=False)
ins[self.new_added_field_name] = new_token_field
ins[self.new_added_field_name] = chars

return dataset

@@ -109,10 +108,11 @@ class CWSTagProcessor(Processor):
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
sentence = ins[self.field_name].text
sentence = ins[self.field_name]
tag_list = self._generate_tag(sentence)
new_tag_field = SeqLabelField(tag_list)
ins[self.new_added_field_name] = new_tag_field
dataset.set_is_target(**{self.new_added_field_name:True})
return dataset

def _tags_from_word_len(self, word_len):
@@ -123,6 +123,8 @@ class CWSSegAppTagProcessor(CWSTagProcessor):
def __init__(self, field_name, new_added_field_name=None):
super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name)

self.tag_size = 2

def _tags_from_word_len(self, word_len):
tag_list = []
for _ in range(word_len-1):
@@ -140,10 +142,9 @@ class BigramProcessor(Processor):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))

for ins in dataset:
characters = ins[self.field_name].content
characters = ins[self.field_name]
bigrams = self._generate_bigram(characters)
new_token_field = TokenListFiled(bigrams)
ins[self.new_added_field_name] = new_token_field
ins[self.new_added_field_name] = bigrams

return dataset

@@ -190,9 +191,26 @@ class VocabProcessor(Processor):
for dataset in datasets:
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
tokens = ins[self.field_name].content
tokens = ins[self.field_name]
self.vocab.update(tokens)

def get_vocab(self):
self.vocab.build_vocab()
return self.vocab

def get_vocab_size(self):
return len(self.vocab)


class SeqLenProcessor(Processor):
def __init__(self, field_name, new_added_field_name='seq_lens'):

super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
length = len(ins[self.field_name])
ins[self.new_added_field_name] = length
dataset.set_need_tensor(**{self.new_added_field_name:True})
return dataset
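
A compact sketch of the new SeqLenProcessor in use (field names mirror train_context.py below):

# Sketch: record each instance's length so Batch can emit it as a tensor.
seq_len_proc = SeqLenProcessor('indexed_chars_list')   # writes 'seq_lens' by default
tr_dataset = seq_len_proc(tr_dataset)
# every instance now has ins['seq_lens'] == len(ins['indexed_chars_list']),
# and the field is flagged through set_need_tensor.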

reproduction/chinese_word_segment/train_context.py (+154, -30)

@@ -9,35 +9,22 @@ from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegPr
from reproduction.chinese_word_segment.process.cws_processor import CWSSegAppTagProcessor
from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor
from reproduction.chinese_word_segment.process.cws_processor import VocabProcessor
from reproduction.chinese_word_segment.process.cws_processor import SeqLenProcessor

from reproduction.chinese_word_segment.process.span_converter import AlphaSpanConverter
from reproduction.chinese_word_segment.process.span_converter import DigitSpanConverter
from reproduction.chinese_word_segment.io.cws_reader import NaiveCWSReader
from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp

tr_filename = ''
dev_filename = ''
tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_train.txt'
dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_dev.txt'

reader = NaiveCWSReader()

tr_sentences = reader.load(tr_filename, cut_long_sent=True)
dev_sentences = reader.load(dev_filename)
tr_dataset = reader.load(tr_filename, cut_long_sent=True)
dev_dataset = reader.load(dev_filename)


# TODO how to assemble these into a DataSet
def construct_dataset(sentences):
dataset = DataSet()
for sentence in sentences:
instance = Instance()
instance['raw_sentence'] = sentence
dataset.append(instance)

return dataset


tr_dataset = construct_dataset(tr_sentences)
dev_dataset = construct_dataset(dev_sentences)

# 1. Prepare the processors
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')

@@ -45,14 +32,14 @@ sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
sp_proc.add_span_converter(AlphaSpanConverter())
sp_proc.add_span_converter(DigitSpanConverter())

char_proc = CWSCharSegProcessor('sentence', 'char_list')
char_proc = CWSCharSegProcessor('sentence', 'chars_list')

tag_proc = CWSSegAppTagProcessor('sentence', 'tag')
tag_proc = CWSSegAppTagProcessor('sentence', 'tags')

bigram_proc = Pre2Post2BigramProcessor('char_list', 'bigram_list')
bigram_proc = Pre2Post2BigramProcessor('chars_list', 'bigrams_list')

char_vocab_proc = VocabProcessor('char_list')
bigram_vocab_proc = VocabProcessor('bigram_list')
char_vocab_proc = VocabProcessor('chars_list')
bigram_vocab_proc = VocabProcessor('bigrams_list')

# 2. Apply the processors
fs2hs_proc(tr_dataset)
@@ -66,15 +53,18 @@ bigram_proc(tr_dataset)
char_vocab_proc(tr_dataset)
bigram_vocab_proc(tr_dataset)

char_index_proc = IndexerProcessor(char_vocab_proc.get_vocab(), 'chars_list', 'indexed_chars_list')
bigram_index_proc = IndexerProcessor(bigram_vocab_proc.get_vocab(), 'bigrams_list','indexed_bigrams_list')
char_index_proc = IndexerProcessor(char_vocab_proc.get_vocab(), 'chars_list', 'indexed_chars_list',
delete_old_field=True)
bigram_index_proc = IndexerProcessor(bigram_vocab_proc.get_vocab(), 'bigrams_list','indexed_bigrams_list',
delete_old_field=True)
seq_len_proc = SeqLenProcessor('indexed_chars_list')

char_index_proc(tr_dataset)
bigram_index_proc(tr_dataset)
seq_len_proc(tr_dataset)

# 2.1 Process dev_dataset
fs2hs_proc(dev_dataset)

sp_proc(dev_dataset)

char_proc(dev_dataset)
@@ -83,14 +73,148 @@ bigram_proc(dev_dataset)

char_index_proc(dev_dataset)
bigram_index_proc(dev_dataset)
seq_len_proc(dev_dataset)

print("Finish preparing data.")

# 3. The datasets are now ready for training
# TODO how are pretrained embeddings handled?
cws_model = CWSBiLSTMSegApp(vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=2)
from itertools import chain

def refine_ys_on_seq_len(ys, seq_lens):
refined_ys = []
for b_idx, length in enumerate(seq_lens):
refined_ys.append(list(ys[b_idx][:length]))

return refined_ys

def flat_nested_list(nested_list):
return list(chain(*nested_list))

def calculate_pre_rec_f1(model, batcher):
true_ys, pred_ys, seq_lens = decode_iterator(model, batcher)
refined_true_ys = refine_ys_on_seq_len(true_ys, seq_lens)
refined_pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens)
true_ys = flat_nested_list(refined_true_ys)
pred_ys = flat_nested_list(refined_pred_ys)

cor_num = 0
yp_wordnum = pred_ys.count(1)
yt_wordnum = true_ys.count(1)
start = 0
for i in range(len(true_ys)):
if true_ys[i] == 1:
flag = True
for j in range(start, i + 1):
if true_ys[j] != pred_ys[j]:
flag = False
break
if flag:
cor_num += 1
start = i + 1
P = cor_num / (float(yp_wordnum) + 1e-6)
R = cor_num / (float(yt_wordnum) + 1e-6)
F = 2 * P * R / (P + R + 1e-6)
return P, R, F

def decode_iterator(model, batcher):
true_ys = []
pred_ys = []
seq_lens = []
with torch.no_grad():
model.eval()
for batch_x, batch_y in batcher:
pred_dict = model(batch_x)
seq_len = pred_dict['seq_lens'].cpu().numpy()
probs = pred_dict['pred_probs']
_, pred_y = probs.max(dim=-1)
true_y = batch_y['tags']
pred_y = pred_y.cpu().numpy()
true_y = true_y.cpu().numpy()

true_ys.extend(list(true_y))
pred_ys.extend(list(pred_y))
seq_lens.extend(list(seq_len))
model.train()

return true_ys, pred_ys, seq_lens

# TODO how are pretrained embeddings handled?

from reproduction.chinese_word_segment.utils import FocalLoss
from reproduction.chinese_word_segment.utils import seq_lens_to_mask
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler

import torch
from torch import optim
import sys
from tqdm import tqdm


tag_size = tag_proc.tag_size

cws_model = CWSBiLSTMSegApp(char_vocab_proc.get_vocab_size(), embed_dim=100,
bigram_vocab_num=bigram_vocab_proc.get_vocab_size(),
bigram_embed_dim=100, num_bigram_per_char=8,
hidden_size=200, bidirectional=True, embed_drop_p=None,
num_layers=1, tag_size=tag_size)

num_epochs = 3
loss_fn = FocalLoss(class_num=tag_size)
optimizer = optim.Adagrad(cws_model.parameters(), lr=0.01)


print_every = 50
batch_size = 32
tr_batcher = Batch(tr_dataset, batch_size, RandomSampler(), use_cuda=False)
dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
num_batch_per_epoch = len(tr_dataset) // batch_size
best_f1 = 0
best_epoch = 0
for num_epoch in range(num_epochs):
print('X' * 10 + ' Epoch: {}/{} '.format(num_epoch + 1, num_epochs) + 'X' * 10)
sys.stdout.flush()
avg_loss = 0
with tqdm(total=num_batch_per_epoch, leave=True) as pbar:
pbar.set_description_str('Epoch:%d' % (num_epoch + 1))
cws_model.train()
for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
pred_dict = cws_model(batch_x) # B x L x tag_size
seq_lens = batch_x['seq_lens']
masks = seq_lens_to_mask(seq_lens)
tags = batch_y['tags']
loss = torch.sum(loss_fn(pred_dict['pred_prob'].view(-1, tag_size),
tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
# loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())

avg_loss += loss.item()

loss.backward()
for group in optimizer.param_groups:
for param in group['params']:
param.grad.clamp_(-5, 5)

optimizer.step()

if batch_idx % print_every == 0:
pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
avg_loss = 0
pbar.update(print_every)

# validation set
pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
pre*100,
rec*100))
if best_f1<f1:
best_f1 = f1
# cache the best parameters; they may be saved later
best_state_dict = {
key:value.clone() for key, value in
cws_model.state_dict().items()
}
best_epoch = num_epoch


# 4. Assemble everything that needs to be saved
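
One possible way to fill in step 4 (purely a sketch; the commit leaves this step empty, and the file names below are made up):

import pickle

torch.save(best_state_dict, 'cws_best_model.pkl')   # cached weights from the best epoch
with open('cws_vocabs.pkl', 'wb') as f:
    pickle.dump({'char_vocab': char_vocab_proc.get_vocab(),
                 'bigram_vocab': bigram_vocab_proc.get_vocab()}, f)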

