
Added a BucketSampler to Sampler; CWS training is now basically working.

tags/v0.2.0
yh_cc committed 6 years ago
commit 3cb98ddcf2
6 changed files with 86 additions and 52 deletions

  1. fastNLP/core/dataset.py                                        +2   -1
  2. fastNLP/core/fieldarray.py                                     +9   -5
  3. fastNLP/core/sampler.py                                       +42   -1
  4. reproduction/chinese_word_segment/models/cws_model.py          +5  -20
  5. reproduction/chinese_word_segment/process/cws_processor.py     +2   -2
  6. reproduction/chinese_word_segment/train_context.py            +26  -23

fastNLP/core/dataset.py  (+2 -1)

@@ -72,7 +72,8 @@ class DataSet(object):
         self.field_arrays[name].append(field)
 
     def add_field(self, name, fields):
-        assert len(self) == len(fields)
+        if len(self.field_arrays) != 0:
+            assert len(self) == len(fields)
         self.field_arrays[name] = FieldArray(name, fields)
 
     def delete_field(self, name):
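The new guard means the very first add_field call on an empty DataSet no longer trips the length assertion; only once at least one field exists must later fields match the dataset length. A minimal usage sketch with hypothetical values, assuming a DataSet can be constructed empty in this version:

from fastNLP.core.dataset import DataSet

ds = DataSet()                                        # no field arrays yet, so the length check is skipped
ds.add_field('raw_chars', [['上', '海'], ['北', '京', '人']])  # first field defines len(ds) == 2
ds.add_field('seq_lens', [2, 3])                      # later fields must match len(ds)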


fastNLP/core/fieldarray.py  (+9 -5)

@@ -28,11 +28,15 @@ class FieldArray(object):
             return self.content[idxes]
         assert self.need_tensor is True
         batch_size = len(idxes)
-        max_len = max([len(self.content[i]) for i in idxes])
-        array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
-
-        for i, idx in enumerate(idxes):
-            array[i][:len(self.content[idx])] = self.content[idx]
+        # TODO when this FieldArray holds single-value content such as seq_length, no padding is needed; to be discussed further
+        if isinstance(self.content[0], int) or isinstance(self.content[0], float):
+            array = np.array([self.content[i] for i in idxes], dtype=type(self.content[0]))
+        else:
+            max_len = max([len(self.content[i]) for i in idxes])
+            array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
+
+            for i, idx in enumerate(idxes):
+                array[i][:len(self.content[idx])] = self.content[idx]
         return array
 
     def __len__(self):
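The added branch only pads fields whose items are sequences; scalar fields such as seq_lens come back as a plain 1-D array. A standalone sketch of the same logic (a hypothetical helper, not fastNLP API), useful for seeing the two output shapes side by side:

import numpy as np

def to_batch_array(content, idxes, padding_val=0):
    # scalar content (e.g. seq_lens): no padding, just gather into a 1-D array
    if isinstance(content[0], (int, float)):
        return np.array([content[i] for i in idxes], dtype=type(content[0]))
    # sequence content (e.g. token ids): right-pad each row to the longest in the batch
    max_len = max(len(content[i]) for i in idxes)
    array = np.full((len(idxes), max_len), padding_val, dtype=np.int32)
    for i, idx in enumerate(idxes):
        array[i, :len(content[idx])] = content[idx]
    return array

# to_batch_array([3, 2, 4], [0, 2])           -> array([3, 4])
# to_batch_array([[1, 2, 3], [4, 5]], [0, 1]) -> array([[1, 2, 3], [4, 5, 0]], dtype=int32)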


fastNLP/core/sampler.py  (+42 -1)

@@ -1,6 +1,6 @@
 import numpy as np
 import torch
+from itertools import chain
 
 
 def convert_to_torch_tensor(data_list, use_cuda):
     """Convert lists into (cuda) Tensors.
@@ -43,6 +43,47 @@ class RandomSampler(BaseSampler):
     def __call__(self, data_set):
         return list(np.random.permutation(len(data_set)))
 
 
+class BucketSampler(BaseSampler):
+
+    def __init__(self, num_buckets=10, batch_size=32):
+        self.num_buckets = num_buckets
+        self.batch_size = batch_size
+
+    def __call__(self, data_set):
+        assert 'seq_lens' in data_set, "BucketSampler only supports data_set with seq_lens right now."
+
+        seq_lens = data_set['seq_lens'].content
+        total_sample_num = len(seq_lens)
+
+        bucket_indexes = []
+        num_sample_per_bucket = total_sample_num // self.num_buckets
+        for i in range(self.num_buckets):
+            bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)])
+        bucket_indexes[-1][1] = total_sample_num
+
+        sorted_seq_lens = list(sorted([(idx, seq_len) for
+                                       idx, seq_len in zip(range(total_sample_num), seq_lens)],
+                                      key=lambda x: x[1]))
+
+        batchs = []
+
+        left_init_indexes = []
+        for b_idx in range(self.num_buckets):
+            start_idx = bucket_indexes[b_idx][0]
+            end_idx = bucket_indexes[b_idx][1]
+            sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx]
+            left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens])
+            num_batch_per_bucket = len(left_init_indexes) // self.batch_size
+            np.random.shuffle(left_init_indexes)
+            for i in range(num_batch_per_bucket):
+                batchs.append(left_init_indexes[i * self.batch_size:(i + 1) * self.batch_size])
+            left_init_indexes = left_init_indexes[num_batch_per_bucket * self.batch_size:]
+
+        np.random.shuffle(batchs)
+
+        return list(chain(*batchs))
+
+
 def simple_sort_bucketing(lengths):
     """


reproduction/chinese_word_segment/models/cws_model.py  (+5 -20)

@@ -68,7 +68,6 @@ class CWSBiLSTMEncoder(BaseModel):
         if not bigrams is None:
             bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
             x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)
-
         sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True)
         packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True)
@@ -97,36 +96,22 @@ class CWSBiLSTMSegApp(BaseModel):
 
     def forward(self, batch_dict):
         device = self.parameters().__next__().device
-        chars = batch_dict['indexed_chars_list'].to(device)
-        if 'bigram' in batch_dict:
-            bigrams = batch_dict['indexed_chars_list'].to(device)
+        chars = batch_dict['indexed_chars_list'].to(device).long()
+        if 'indexed_bigrams_list' in batch_dict:
+            bigrams = batch_dict['indexed_bigrams_list'].to(device).long()
         else:
             bigrams = None
-        seq_lens = batch_dict['seq_lens'].to(device)
+        seq_lens = batch_dict['seq_lens'].to(device).long()
 
         feats = self.encoder_model(chars, bigrams, seq_lens)
         probs = self.decoder_model(feats)
 
         pred_dict = {}
         pred_dict['seq_lens'] = seq_lens
-        pred_dict['pred_prob'] = probs
+        pred_dict['pred_probs'] = probs
 
         return pred_dict
 
     def predict(self, batch_dict):
         pass
 
-    def loss_fn(self, pred_dict, true_dict):
-        seq_lens = pred_dict['seq_lens']
-        masks = seq_lens_to_mask(seq_lens).float()
-
-        pred_prob = pred_dict['pred_prob']
-        true_y = true_dict['tags']
-
-        # TODO the loss is hard-coded here for now
-        loss = F.cross_entropy(pred_prob.view(-1, self.tag_size),
-                               true_y.view(-1), reduction='none') * masks.view(-1) / torch.sum(masks)
-
-        return loss
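The loss that used to live here now sits in train_context.py (below) but still depends on seq_lens_to_mask from reproduction.chinese_word_segment.utils, whose body is not part of this diff. A typical equivalent, shown only to make the mask shapes in that loss explicit (a hypothetical reimplementation, not the repository's actual helper):

import torch

def seq_lens_to_mask(seq_lens):
    # mask[i, j] == 1.0 for real positions j < seq_lens[i], 0.0 for padding positions
    max_len = int(seq_lens.max())
    positions = torch.arange(max_len, device=seq_lens.device).unsqueeze(0)  # (1, max_len)
    return (positions < seq_lens.unsqueeze(1)).float()                      # (batch, max_len)

# seq_lens_to_mask(torch.tensor([3, 1])) -> tensor([[1., 1., 1.], [1., 0., 0.]])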

reproduction/chinese_word_segment/process/cws_processor.py  (+2 -2)

@@ -110,9 +110,9 @@ class CWSTagProcessor(Processor):
         for ins in dataset:
             sentence = ins[self.field_name]
             tag_list = self._generate_tag(sentence)
-            new_tag_field = SeqLabelField(tag_list)
-            ins[self.new_added_field_name] = new_tag_field
+            ins[self.new_added_field_name] = tag_list
         dataset.set_is_target(**{self.new_added_field_name:True})
+        dataset.set_need_tensor(**{self.new_added_field_name:True})
         return dataset
 
     def _tags_from_word_len(self, word_len):


reproduction/chinese_word_segment/train_context.py  (+26 -23)

@@ -1,6 +1,4 @@
-from fastNLP.core.instance import Instance
-from fastNLP.core.dataset import DataSet
 from fastNLP.api.pipeline import Pipeline
 from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
 from fastNLP.api.processor import IndexerProcessor
@@ -143,7 +141,7 @@ def decode_iterator(model, batcher):
 from reproduction.chinese_word_segment.utils import FocalLoss
 from reproduction.chinese_word_segment.utils import seq_lens_to_mask
 from fastNLP.core.batch import Batch
-from fastNLP.core.sampler import RandomSampler
+from fastNLP.core.sampler import BucketSampler
 from fastNLP.core.sampler import SequentialSampler
 
 import torch
@@ -159,6 +157,7 @@ cws_model = CWSBiLSTMSegApp(char_vocab_proc.get_vocab_size(), embed_dim=100,
                             bigram_embed_dim=100, num_bigram_per_char=8,
                             hidden_size=200, bidirectional=True, embed_drop_p=None,
                             num_layers=1, tag_size=tag_size)
+cws_model.cuda()
 
 num_epochs = 3
 loss_fn = FocalLoss(class_num=tag_size)
@@ -167,7 +166,7 @@ optimizer = optim.Adagrad(cws_model.parameters(), lr=0.01)
 
 print_every = 50
 batch_size = 32
-tr_batcher = Batch(tr_dataset, batch_size, RandomSampler(), use_cuda=False)
+tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
 dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
 num_batch_per_epoch = len(tr_dataset) // batch_size
 best_f1 = 0
@@ -181,10 +180,12 @@ for num_epoch in range(num_epochs):
     cws_model.train()
     for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
         pred_dict = cws_model(batch_x)  # B x L x tag_size
-        seq_lens = batch_x['seq_lens']
-        masks = seq_lens_to_mask(seq_lens)
-        tags = batch_y['tags']
-        loss = torch.sum(loss_fn(pred_dict['pred_prob'].view(-1, tag_size),
+
+        seq_lens = pred_dict['seq_lens']
+        masks = seq_lens_to_mask(seq_lens).float()
+        tags = batch_y['tags'].long().to(seq_lens.device)
+
+        loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
                                  tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
         # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())
@@ -201,20 +202,20 @@ for num_epoch in range(num_epochs):
             pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
             avg_loss = 0
             pbar.update(print_every)
+    tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
     # validation set
     pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
     print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
                                                      pre*100,
                                                      rec*100))
     if best_f1 < f1:
         best_f1 = f1
         # cache the best parameters; they may be used later when saving the model
         best_state_dict = {
             key: value.clone() for key, value in
             cws_model.state_dict().items()
         }
         best_epoch = num_epoch




 # 4. assemble the contents that need to be saved
@@ -224,4 +225,6 @@ pp.add_processor(sp_proc)
 pp.add_processor(char_proc)
 pp.add_processor(bigram_proc)
 pp.add_processor(char_index_proc)
-pp.add_processor(bigram_index_proc)
+pp.add_processor(bigram_index_proc)
+pp.add_processor(seq_len_proc)
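For reference, the loss line in the training loop above reduces per-token losses to a mean over real (unmasked) tokens. A small sketch of the same arithmetic using plain cross-entropy on dummy tensors; the script itself uses FocalLoss, which, judging by how its output is multiplied by the mask, is assumed here to also return per-element losses:

import torch
import torch.nn.functional as F

batch_size, max_len, tag_size = 2, 5, 4
logits = torch.randn(batch_size, max_len, tag_size)       # stands in for pred_dict['pred_probs']
tags = torch.randint(0, tag_size, (batch_size, max_len))  # stands in for batch_y['tags']
masks = torch.tensor([[1., 1., 1., 0., 0.],
                      [1., 1., 1., 1., 1.]])               # from seq_lens_to_mask(seq_lens)

per_token = F.cross_entropy(logits.view(-1, tag_size), tags.view(-1), reduction='none')
loss = torch.sum(per_token * masks.view(-1)) / torch.sum(masks)  # padding positions contribute nothing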

