@@ -72,7 +72,8 @@ class DataSet(object):
         self.field_arrays[name].append(field)

     def add_field(self, name, fields):
-        assert len(self) == len(fields)
+        if len(self.field_arrays) != 0:
+            assert len(self) == len(fields)
         self.field_arrays[name] = FieldArray(name, fields)

     def delete_field(self, name):
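Note: the length assertion in add_field now only fires once the DataSet already holds a field, so the very first field can be added to an empty DataSet. A minimal sketch of the intended behaviour (a no-argument DataSet constructor is assumed here; it is not shown in this diff):

    ds = DataSet()
    ds.add_field('chars', [[1, 2, 3], [4, 5]])   # first field: nothing to compare lengths against yet
    ds.add_field('seq_lens', [3, 2])             # later fields must satisfy len(ds) == len(fields)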
@@ -28,11 +28,15 @@ class FieldArray(object):
             return self.content[idxes]
         assert self.need_tensor is True
         batch_size = len(idxes)
-        max_len = max([len(self.content[i]) for i in idxes])
-        array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
-        for i, idx in enumerate(idxes):
-            array[i][:len(self.content[idx])] = self.content[idx]
+        # TODO: when this FieldArray holds a single value per instance (e.g. seq_length), no padding is needed; to be discussed further
+        if isinstance(self.content[0], int) or isinstance(self.content[0], float):
+            array = np.array([self.content[i] for i in idxes], dtype=type(self.content[0]))
+        else:
+            max_len = max([len(self.content[i]) for i in idxes])
+            array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
+            for i, idx in enumerate(idxes):
+                array[i][:len(self.content[idx])] = self.content[idx]
         return array

     def __len__(self):
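Note: with this change FieldArray returns a plain 1-D array for scalar fields such as seq_lens, and keeps the old pad-to-max-length behaviour for list fields. The padding branch in isolation, as a small numpy-only sketch (padding_val replaced by 0 for illustration):

    import numpy as np

    content = [[1, 2, 3], [4, 5]]
    idxes = [0, 1]
    max_len = max(len(content[i]) for i in idxes)
    array = np.full((len(idxes), max_len), 0, dtype=np.int32)
    for i, idx in enumerate(idxes):
        array[i][:len(content[idx])] = content[idx]
    # array == [[1, 2, 3], [4, 5, 0]]; a scalar field like [3, 2] is now returned without padding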
@@ -1,6 +1,6 @@
 import numpy as np
 import torch
+from itertools import chain

 def convert_to_torch_tensor(data_list, use_cuda):
     """Convert lists into (cuda) Tensors.
@@ -43,6 +43,47 @@ class RandomSampler(BaseSampler):
     def __call__(self, data_set):
         return list(np.random.permutation(len(data_set)))

+class BucketSampler(BaseSampler):
+
+    def __init__(self, num_buckets=10, batch_size=32):
+        self.num_buckets = num_buckets
+        self.batch_size = batch_size
+
+    def __call__(self, data_set):
+        assert 'seq_lens' in data_set, "BucketSampler only supports data_set with seq_lens right now."
+        seq_lens = data_set['seq_lens'].content
+        total_sample_num = len(seq_lens)
+
+        bucket_indexes = []
+        num_sample_per_bucket = total_sample_num // self.num_buckets
+        for i in range(self.num_buckets):
+            bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)])
+        bucket_indexes[-1][1] = total_sample_num
+
+        sorted_seq_lens = list(sorted([(idx, seq_len) for
+                                       idx, seq_len in zip(range(total_sample_num), seq_lens)],
+                                      key=lambda x: x[1]))
+
+        batchs = []
+        left_init_indexes = []
+        for b_idx in range(self.num_buckets):
+            start_idx = bucket_indexes[b_idx][0]
+            end_idx = bucket_indexes[b_idx][1]
+            sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx]
+            left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens])
+            num_batch_per_bucket = len(left_init_indexes) // self.batch_size
+            np.random.shuffle(left_init_indexes)
+            for i in range(num_batch_per_bucket):
+                batchs.append(left_init_indexes[i * self.batch_size:(i + 1) * self.batch_size])
+            left_init_indexes = left_init_indexes[num_batch_per_bucket * self.batch_size:]
+        np.random.shuffle(batchs)
+
+        return list(chain(*batchs))

 def simple_sort_bucketing(lengths):
     """
@@ -68,7 +68,6 @@ class CWSBiLSTMEncoder(BaseModel):
         if not bigrams is None:
             bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
             x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)
         sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True)
         packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True)
@@ -97,36 +96,22 @@ class CWSBiLSTMSegApp(BaseModel):
     def forward(self, batch_dict):
         device = self.parameters().__next__().device
-        chars = batch_dict['indexed_chars_list'].to(device)
-        if 'bigram' in batch_dict:
-            bigrams = batch_dict['indexed_chars_list'].to(device)
+        chars = batch_dict['indexed_chars_list'].to(device).long()
+        if 'indexed_bigrams_list' in batch_dict:
+            bigrams = batch_dict['indexed_bigrams_list'].to(device).long()
         else:
             bigrams = None
-        seq_lens = batch_dict['seq_lens'].to(device)
+        seq_lens = batch_dict['seq_lens'].to(device).long()

         feats = self.encoder_model(chars, bigrams, seq_lens)
         probs = self.decoder_model(feats)

         pred_dict = {}
         pred_dict['seq_lens'] = seq_lens
-        pred_dict['pred_prob'] = probs
+        pred_dict['pred_probs'] = probs

         return pred_dict

     def predict(self, batch_dict):
         pass
-    def loss_fn(self, pred_dict, true_dict):
-        seq_lens = pred_dict['seq_lens']
-        masks = seq_lens_to_mask(seq_lens).float()
-        pred_prob = pred_dict['pred_prob']
-        true_y = true_dict['tags']
-        # TODO: the loss is currently hard-coded here
-        loss = F.cross_entropy(pred_prob.view(-1, self.tag_size),
-                               true_y.view(-1), reduction='none') * masks.view(-1) / torch.sum(masks)
-        return loss
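Note: the model is now prediction-only. forward() casts the indexed fields to LongTensor on the model's own device and returns the keys 'seq_lens' and 'pred_probs'; the hard-coded loss_fn is removed and the loss is computed in the training script instead (see the FocalLoss usage further down). A rough calling sketch with made-up shapes, using the cws_model built below:

    import torch

    batch_dict = {
        'indexed_chars_list': torch.randint(0, 100, (4, 20)),   # B x L character ids
        'seq_lens': torch.tensor([20, 18, 15, 9]),
    }
    pred_dict = cws_model(batch_dict)     # no 'indexed_bigrams_list' key, so bigrams stays None
    probs = pred_dict['pred_probs']       # B x L x tag_size scores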
@@ -110,9 +110,9 @@ class CWSTagProcessor(Processor):
         for ins in dataset:
             sentence = ins[self.field_name]
             tag_list = self._generate_tag(sentence)
-            new_tag_field = SeqLabelField(tag_list)
-            ins[self.new_added_field_name] = new_tag_field
+            ins[self.new_added_field_name] = tag_list
         dataset.set_is_target(**{self.new_added_field_name: True})
+        dataset.set_need_tensor(**{self.new_added_field_name: True})
         return dataset

     def _tags_from_word_len(self, word_len):
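Note: the processor now stores the raw tag list in each instance instead of wrapping it in a SeqLabelField, and additionally marks the new field as needing tensor conversion. Assuming the new field is named 'tags', the two set_* calls above are equivalent to:

    dataset.set_is_target(tags=True)     # the field is a training target
    dataset.set_need_tensor(tags=True)   # the field goes through FieldArray's padding/tensor path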
@@ -1,6 +1,4 @@
-from fastNLP.core.instance import Instance
-from fastNLP.core.dataset import DataSet
 from fastNLP.api.pipeline import Pipeline
 from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
 from fastNLP.api.processor import IndexerProcessor
@@ -143,7 +141,7 @@ def decode_iterator(model, batcher):
 from reproduction.chinese_word_segment.utils import FocalLoss
 from reproduction.chinese_word_segment.utils import seq_lens_to_mask
 from fastNLP.core.batch import Batch
-from fastNLP.core.sampler import RandomSampler
+from fastNLP.core.sampler import BucketSampler
 from fastNLP.core.sampler import SequentialSampler

 import torch
@@ -159,6 +157,7 @@ cws_model = CWSBiLSTMSegApp(char_vocab_proc.get_vocab_size(), embed_dim=100,
                             bigram_embed_dim=100, num_bigram_per_char=8,
                             hidden_size=200, bidirectional=True, embed_drop_p=None,
                             num_layers=1, tag_size=tag_size)
+cws_model.cuda()

 num_epochs = 3
 loss_fn = FocalLoss(class_num=tag_size)
@@ -167,7 +166,7 @@ optimizer = optim.Adagrad(cws_model.parameters(), lr=0.01)
 print_every = 50
 batch_size = 32

-tr_batcher = Batch(tr_dataset, batch_size, RandomSampler(), use_cuda=False)
+tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
 dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
 num_batch_per_epoch = len(tr_dataset) // batch_size
 best_f1 = 0
@@ -181,10 +180,12 @@ for num_epoch in range(num_epochs):
     cws_model.train()
     for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
         pred_dict = cws_model(batch_x)  # B x L x tag_size
-        seq_lens = batch_x['seq_lens']
-        masks = seq_lens_to_mask(seq_lens)
-        tags = batch_y['tags']
-        loss = torch.sum(loss_fn(pred_dict['pred_prob'].view(-1, tag_size),
+        seq_lens = pred_dict['seq_lens']
+        masks = seq_lens_to_mask(seq_lens).float()
+        tags = batch_y['tags'].long().to(seq_lens.device)
+        loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
                                  tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
         # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())
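Note: the loss now reads the model's own seq_lens and the renamed pred_probs key, and masks out padded positions: per-token losses are weighted by the mask and normalized by the number of real tokens, i.e. loss = sum(loss_t * mask_t) / sum(mask_t). A self-contained sketch of that masking pattern, with cross_entropy standing in for FocalLoss:

    import torch
    import torch.nn.functional as F

    B, L, T = 2, 5, 4                                  # batch size, max length, tag_size
    logits = torch.randn(B, L, T)
    tags = torch.randint(0, T, (B, L))
    seq_lens = torch.tensor([5, 3])
    masks = (torch.arange(L)[None, :] < seq_lens[:, None]).float()   # 1 for real tokens, 0 for padding
    per_token = F.cross_entropy(logits.view(-1, T), tags.view(-1), reduction='none')
    loss = torch.sum(per_token * masks.view(-1)) / torch.sum(masks)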
@@ -201,20 +202,20 @@ for num_epoch in range(num_epochs):
            pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
            avg_loss = 0
            pbar.update(print_every)
-    # validation set
-    pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
-    print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
-                                                     pre * 100,
-                                                     rec * 100))
-    if best_f1 < f1:
-        best_f1 = f1
-        # cache the best parameters; they may be saved later
-        best_state_dict = {
-            key: value.clone() for key, value in
-            cws_model.state_dict().items()
-        }
-        best_epoch = num_epoch
+    tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
+    # validation set
+    pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
+    print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
+                                                     pre * 100,
+                                                     rec * 100))
+    if best_f1 < f1:
+        best_f1 = f1
+        # cache the best parameters; they may be saved later
+        best_state_dict = {
+            key: value.clone() for key, value in
+            cws_model.state_dict().items()
+        }
+        best_epoch = num_epoch

# 4. assemble the components to be saved
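Note: two points about the reworked epoch loop. A fresh Batch over tr_dataset is rebuilt at the end of every epoch, presumably because the batcher is consumed after one pass and a new BucketSampler call yields a new bucket shuffle; and the best weights are kept as cloned tensors so that later parameter updates cannot overwrite them. The cloning idiom, which can later be undone with load_state_dict:

    best_state_dict = {key: value.clone() for key, value in cws_model.state_dict().items()}
    # ... after training ...
    cws_model.load_state_dict(best_state_dict)   # roll back to the best validation epoch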
@@ -224,4 +225,6 @@ pp.add_processor(sp_proc)
 pp.add_processor(char_proc)
 pp.add_processor(bigram_proc)
 pp.add_processor(char_index_proc)
-pp.add_processor(bigram_index_proc)
+pp.add_processor(bigram_index_proc)
+pp.add_processor(seq_len_proc)
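Note: the saved Pipeline now also contains bigram_index_proc and the new seq_len_proc, so a raw DataSet run through it ends up with the seq_lens field that both BucketSampler and the masked loss rely on. A hypothetical end-to-end use, assuming the Pipeline applies its processors in insertion order when called on a DataSet (the call syntax is an assumption, not shown in this diff; raw_dataset is a placeholder):

    processed = pp(raw_dataset)   # after this, processed carries indexed chars, bigrams and 'seq_lens'
    batcher = Batch(processed, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)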