Browse Source

Sampler中增加了一个BucketSampler, CWS的训练基本可以实现

tags/v0.2.0
yh_cc 6 years ago
parent
commit
3cb98ddcf2
6 changed files with 86 additions and 52 deletions
  1. +2
    -1
      fastNLP/core/dataset.py
  2. +9
    -5
      fastNLP/core/fieldarray.py
  3. +42
    -1
      fastNLP/core/sampler.py
  4. +5
    -20
      reproduction/chinese_word_segment/models/cws_model.py
  5. +2
    -2
      reproduction/chinese_word_segment/process/cws_processor.py
  6. +26
    -23
      reproduction/chinese_word_segment/train_context.py

+ 2
- 1
fastNLP/core/dataset.py View File

@@ -72,7 +72,8 @@ class DataSet(object):
self.field_arrays[name].append(field)

def add_field(self, name, fields):
assert len(self) == len(fields)
if len(self.field_arrays)!=0:
assert len(self) == len(fields)
self.field_arrays[name] = FieldArray(name, fields)

def delete_field(self, name):


+ 9
- 5
fastNLP/core/fieldarray.py View File

@@ -28,11 +28,15 @@ class FieldArray(object):
return self.content[idxes]
assert self.need_tensor is True
batch_size = len(idxes)
max_len = max([len(self.content[i]) for i in idxes])
array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)

for i, idx in enumerate(idxes):
array[i][:len(self.content[idx])] = self.content[idx]
# TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下
if isinstance(self.content[0], int) or isinstance(self.content[0], float):
array = np.array([self.content[i] for i in idxes], dtype=type(self.content[0]))
else:
max_len = max([len(self.content[i]) for i in idxes])
array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)

for i, idx in enumerate(idxes):
array[i][:len(self.content[idx])] = self.content[idx]
return array

def __len__(self):


+ 42
- 1
fastNLP/core/sampler.py View File

@@ -1,6 +1,6 @@
import numpy as np
import torch
from itertools import chain

def convert_to_torch_tensor(data_list, use_cuda):
"""Convert lists into (cuda) Tensors.
@@ -43,6 +43,47 @@ class RandomSampler(BaseSampler):
def __call__(self, data_set):
return list(np.random.permutation(len(data_set)))

class BucketSampler(BaseSampler):

def __init__(self, num_buckets=10, batch_size=32):
self.num_buckets = num_buckets
self.batch_size = batch_size

def __call__(self, data_set):
assert 'seq_lens' in data_set, "BuckectSampler only support data_set with seq_lens right now."

seq_lens = data_set['seq_lens'].content
total_sample_num = len(seq_lens)

bucket_indexes = []
num_sample_per_bucket = total_sample_num//self.num_buckets
for i in range(self.num_buckets):
bucket_indexes.append([num_sample_per_bucket*i, num_sample_per_bucket*(i+1)])
bucket_indexes[-1][1] = total_sample_num

sorted_seq_lens = list(sorted([(idx, seq_len) for
idx, seq_len in zip(range(total_sample_num), seq_lens)],
key=lambda x:x[1]))

batchs = []

left_init_indexes = []
for b_idx in range(self.num_buckets):
start_idx = bucket_indexes[b_idx][0]
end_idx = bucket_indexes[b_idx][1]
sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx]
left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens])
num_batch_per_bucket = len(left_init_indexes)//self.batch_size
np.random.shuffle(left_init_indexes)
for i in range(num_batch_per_bucket):
batchs.append(left_init_indexes[i*self.batch_size:(i+1)*self.batch_size])
left_init_indexes = left_init_indexes[num_batch_per_bucket*self.batch_size:]

np.random.shuffle(batchs)

return list(chain(*batchs))



def simple_sort_bucketing(lengths):
"""


+ 5
- 20
reproduction/chinese_word_segment/models/cws_model.py View File

@@ -68,7 +68,6 @@ class CWSBiLSTMEncoder(BaseModel):
if not bigrams is None:
bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)

sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True)
packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True)

@@ -97,36 +96,22 @@ class CWSBiLSTMSegApp(BaseModel):

def forward(self, batch_dict):
device = self.parameters().__next__().device
chars = batch_dict['indexed_chars_list'].to(device)
if 'bigram' in batch_dict:
bigrams = batch_dict['indexed_chars_list'].to(device)
chars = batch_dict['indexed_chars_list'].to(device).long()
if 'indexed_bigrams_list' in batch_dict:
bigrams = batch_dict['indexed_bigrams_list'].to(device).long()
else:
bigrams = None
seq_lens = batch_dict['seq_lens'].to(device)
seq_lens = batch_dict['seq_lens'].to(device).long()

feats = self.encoder_model(chars, bigrams, seq_lens)
probs = self.decoder_model(feats)

pred_dict = {}
pred_dict['seq_lens'] = seq_lens
pred_dict['pred_prob'] = probs
pred_dict['pred_probs'] = probs

return pred_dict

def predict(self, batch_dict):
pass


def loss_fn(self, pred_dict, true_dict):
seq_lens = pred_dict['seq_lens']
masks = seq_lens_to_mask(seq_lens).float()

pred_prob = pred_dict['pred_prob']
true_y = true_dict['tags']

# TODO 当前把loss写死了
loss = F.cross_entropy(pred_prob.view(-1, self.tag_size),
true_y.view(-1), reduction='none')*masks.view(-1)/torch.sum(masks)


return loss

+ 2
- 2
reproduction/chinese_word_segment/process/cws_processor.py View File

@@ -110,9 +110,9 @@ class CWSTagProcessor(Processor):
for ins in dataset:
sentence = ins[self.field_name]
tag_list = self._generate_tag(sentence)
new_tag_field = SeqLabelField(tag_list)
ins[self.new_added_field_name] = new_tag_field
ins[self.new_added_field_name] = tag_list
dataset.set_is_target(**{self.new_added_field_name:True})
dataset.set_need_tensor(**{self.new_added_field_name:True})
return dataset

def _tags_from_word_len(self, word_len):


+ 26
- 23
reproduction/chinese_word_segment/train_context.py View File

@@ -1,6 +1,4 @@

from fastNLP.core.instance import Instance
from fastNLP.core.dataset import DataSet
from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
from fastNLP.api.processor import IndexerProcessor
@@ -143,7 +141,7 @@ def decode_iterator(model, batcher):
from reproduction.chinese_word_segment.utils import FocalLoss
from reproduction.chinese_word_segment.utils import seq_lens_to_mask
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import BucketSampler
from fastNLP.core.sampler import SequentialSampler

import torch
@@ -159,6 +157,7 @@ cws_model = CWSBiLSTMSegApp(char_vocab_proc.get_vocab_size(), embed_dim=100,
bigram_embed_dim=100, num_bigram_per_char=8,
hidden_size=200, bidirectional=True, embed_drop_p=None,
num_layers=1, tag_size=tag_size)
cws_model.cuda()

num_epochs = 3
loss_fn = FocalLoss(class_num=tag_size)
@@ -167,7 +166,7 @@ optimizer = optim.Adagrad(cws_model.parameters(), lr=0.01)

print_every = 50
batch_size = 32
tr_batcher = Batch(tr_dataset, batch_size, RandomSampler(), use_cuda=False)
tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
num_batch_per_epoch = len(tr_dataset) // batch_size
best_f1 = 0
@@ -181,10 +180,12 @@ for num_epoch in range(num_epochs):
cws_model.train()
for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
pred_dict = cws_model(batch_x) # B x L x tag_size
seq_lens = batch_x['seq_lens']
masks = seq_lens_to_mask(seq_lens)
tags = batch_y['tags']
loss = torch.sum(loss_fn(pred_dict['pred_prob'].view(-1, tag_size),

seq_lens = pred_dict['seq_lens']
masks = seq_lens_to_mask(seq_lens).float()
tags = batch_y['tags'].long().to(seq_lens.device)

loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
# loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())

@@ -201,20 +202,20 @@ for num_epoch in range(num_epochs):
pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
avg_loss = 0
pbar.update(print_every)
# 验证集
pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
pre*100,
rec*100))
if best_f1<f1:
best_f1 = f1
# 缓存最佳的parameter,可能之后会用于保存
best_state_dict = {
key:value.clone() for key, value in
cws_model.state_dict().items()
}
best_epoch = num_epoch
tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
# 验证集
pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
pre*100,
rec*100))
if best_f1<f1:
best_f1 = f1
# 缓存最佳的parameter,可能之后会用于保存
best_state_dict = {
key:value.clone() for key, value in
cws_model.state_dict().items()
}
best_epoch = num_epoch


# 4. 组装需要存下的内容
@@ -224,4 +225,6 @@ pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_index_proc)
pp.add_processor(bigram_index_proc)
pp.add_processor(bigram_index_proc)
pp.add_processor(seq_len_proc)


Loading…
Cancel
Save