@@ -72,7 +72,8 @@ class DataSet(object):
             self.field_arrays[name].append(field)
 
     def add_field(self, name, fields):
-        assert len(self) == len(fields)
+        if len(self.field_arrays) != 0:
+            assert len(self) == len(fields)
         self.field_arrays[name] = FieldArray(name, fields)
 
     def delete_field(self, name):
@@ -28,11 +28,15 @@ class FieldArray(object):
             return self.content[idxes]
         assert self.need_tensor is True
         batch_size = len(idxes)
-        max_len = max([len(self.content[i]) for i in idxes])
-        array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
-        for i, idx in enumerate(idxes):
-            array[i][:len(self.content[idx])] = self.content[idx]
+        # TODO when this FieldArray holds scalar content such as seq_length, no padding is needed; to be discussed further
+        if isinstance(self.content[0], (int, float)):
+            array = np.array([self.content[i] for i in idxes], dtype=type(self.content[0]))
+        else:
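+            # pad variable-length sequences up to the batch max_len with padding_val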
+            max_len = max([len(self.content[i]) for i in idxes])
+            array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
+            for i, idx in enumerate(idxes):
+                array[i][:len(self.content[idx])] = self.content[idx]
         return array
 
     def __len__(self):
@@ -1,6 +1,6 @@
 import numpy as np
 import torch
+from itertools import chain
 
 
 def convert_to_torch_tensor(data_list, use_cuda):
     """Convert lists into (cuda) Tensors.
@@ -43,6 +43,47 @@ class RandomSampler(BaseSampler):
     def __call__(self, data_set):
         return list(np.random.permutation(len(data_set)))
 
 
+class BucketSampler(BaseSampler):
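+    """Sample indexes so that instances with similar seq_lens end up in the same batch.
+
+    Samples are sorted by length and split into num_buckets contiguous buckets; batches
+    are drawn bucket by bucket and the batch order is shuffled at the end.
+    """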
+    def __init__(self, num_buckets=10, batch_size=32):
+        self.num_buckets = num_buckets
+        self.batch_size = batch_size
+
+    def __call__(self, data_set):
+        assert 'seq_lens' in data_set, "BucketSampler only supports a DataSet with a seq_lens field right now."
+        seq_lens = data_set['seq_lens'].content
+        total_sample_num = len(seq_lens)
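+        # [start, end) index ranges into the length-sorted samples; the last bucket
+        # absorbs the remainder when total_sample_num is not divisible by num_buckets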
+        bucket_indexes = []
+        num_sample_per_bucket = total_sample_num // self.num_buckets
+        for i in range(self.num_buckets):
+            bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)])
+        bucket_indexes[-1][1] = total_sample_num
+
+        sorted_seq_lens = sorted(enumerate(seq_lens), key=lambda x: x[1])
+
+        batches = []
+        left_init_indexes = []
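+        # fill batches bucket by bucket; indexes that do not make up a full batch are
+        # carried over (left_init_indexes) into the next, longer-length bucket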
+        for b_idx in range(self.num_buckets):
+            start_idx = bucket_indexes[b_idx][0]
+            end_idx = bucket_indexes[b_idx][1]
+            sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx]
+            left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens])
+            num_batch_per_bucket = len(left_init_indexes) // self.batch_size
+            np.random.shuffle(left_init_indexes)
+            for i in range(num_batch_per_bucket):
+                batches.append(left_init_indexes[i * self.batch_size:(i + 1) * self.batch_size])
+            left_init_indexes = left_init_indexes[num_batch_per_bucket * self.batch_size:]
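+        # shuffle the batch order so the model does not see lengths in ascending order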
+        np.random.shuffle(batches)
+
+        return list(chain(*batches))
+
 
 def simple_sort_bucketing(lengths):
     """
@@ -68,7 +68,6 @@ class CWSBiLSTMEncoder(BaseModel):
         if bigrams is not None:
             bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
             x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)
 
         sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True)
         packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True)
@@ -97,36 +96,22 @@ class CWSBiLSTMSegApp(BaseModel):
     def forward(self, batch_dict):
         device = self.parameters().__next__().device
-        chars = batch_dict['indexed_chars_list'].to(device)
-        if 'bigram' in batch_dict:
-            bigrams = batch_dict['indexed_chars_list'].to(device)
+        chars = batch_dict['indexed_chars_list'].to(device).long()
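+        # index tensors are cast to int64 (.long()) because embedding lookups expect LongTensor indexes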
+        if 'indexed_bigrams_list' in batch_dict:
+            bigrams = batch_dict['indexed_bigrams_list'].to(device).long()
         else:
             bigrams = None
-        seq_lens = batch_dict['seq_lens'].to(device)
+        seq_lens = batch_dict['seq_lens'].to(device).long()
 
         feats = self.encoder_model(chars, bigrams, seq_lens)
         probs = self.decoder_model(feats)
 
         pred_dict = {}
         pred_dict['seq_lens'] = seq_lens
-        pred_dict['pred_prob'] = probs
+        pred_dict['pred_probs'] = probs
 
         return pred_dict
 
     def predict(self, batch_dict):
         pass
 
     def loss_fn(self, pred_dict, true_dict):
         seq_lens = pred_dict['seq_lens']
         masks = seq_lens_to_mask(seq_lens).float()
         pred_prob = pred_dict['pred_prob']
         true_y = true_dict['tags']
 
         # TODO the loss is currently hard-coded here
         loss = F.cross_entropy(pred_prob.view(-1, self.tag_size),
                                true_y.view(-1), reduction='none') * masks.view(-1) / torch.sum(masks)
 
         return loss
@@ -110,9 +110,9 @@ class CWSTagProcessor(Processor):
         for ins in dataset:
             sentence = ins[self.field_name]
             tag_list = self._generate_tag(sentence)
-            new_tag_field = SeqLabelField(tag_list)
-            ins[self.new_added_field_name] = new_tag_field
+            ins[self.new_added_field_name] = tag_list
+        dataset.set_is_target(**{self.new_added_field_name: True})
         dataset.set_need_tensor(**{self.new_added_field_name: True})
         return dataset
 
     def _tags_from_word_len(self, word_len):
@@ -1,6 +1,4 @@
 from fastNLP.core.instance import Instance
 from fastNLP.core.dataset import DataSet
 from fastNLP.api.pipeline import Pipeline
 from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
 from fastNLP.api.processor import IndexerProcessor
@@ -143,7 +141,7 @@ def decode_iterator(model, batcher):
 from reproduction.chinese_word_segment.utils import FocalLoss
 from reproduction.chinese_word_segment.utils import seq_lens_to_mask
 from fastNLP.core.batch import Batch
-from fastNLP.core.sampler import RandomSampler
+from fastNLP.core.sampler import BucketSampler
 from fastNLP.core.sampler import SequentialSampler
 import torch
@@ -159,6 +157,7 @@ cws_model = CWSBiLSTMSegApp(char_vocab_proc.get_vocab_size(), embed_dim=100,
                             bigram_embed_dim=100, num_bigram_per_char=8,
                             hidden_size=200, bidirectional=True, embed_drop_p=None,
                             num_layers=1, tag_size=tag_size)
+cws_model.cuda()
 
 num_epochs = 3
 loss_fn = FocalLoss(class_num=tag_size)
@@ -167,7 +166,7 @@ optimizer = optim.Adagrad(cws_model.parameters(), lr=0.01)
 print_every = 50
 batch_size = 32
-tr_batcher = Batch(tr_dataset, batch_size, RandomSampler(), use_cuda=False)
+tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
 dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
 num_batch_per_epoch = len(tr_dataset) // batch_size
 best_f1 = 0
@@ -181,10 +180,12 @@ for num_epoch in range(num_epochs):
     cws_model.train()
     for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
         pred_dict = cws_model(batch_x)  # B x L x tag_size
-        seq_lens = batch_x['seq_lens']
-        masks = seq_lens_to_mask(seq_lens)
-        tags = batch_y['tags']
-        loss = torch.sum(loss_fn(pred_dict['pred_prob'].view(-1, tag_size),
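+        # token-level loss: mask out padding positions and normalize by the number of real tokens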
+        seq_lens = pred_dict['seq_lens']
+        masks = seq_lens_to_mask(seq_lens).float()
+        tags = batch_y['tags'].long().to(seq_lens.device)
+        loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
                                  tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
         # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())
@@ -201,20 +202,20 @@ for num_epoch in range(num_epochs):
                 pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
                 avg_loss = 0
                 pbar.update(print_every)
-        # validation set
-        pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
-        print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
-                                                         pre*100,
-                                                         rec*100))
-        if best_f1 < f1:
-            best_f1 = f1
-            # cache the best parameters; they may be used for saving later
-            best_state_dict = {
-                key: value.clone() for key, value in
-                cws_model.state_dict().items()
-            }
-            best_epoch = num_epoch
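+    # re-create the training batcher so BucketSampler re-shuffles buckets and batches for the next epoch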
+    tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
+    # validation set
+    pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
+    print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
+                                                     pre*100,
+                                                     rec*100))
+    if best_f1 < f1:
+        best_f1 = f1
+        # cache the best parameters; they may be used for saving later
+        best_state_dict = {
+            key: value.clone() for key, value in
+            cws_model.state_dict().items()
+        }
+        best_epoch = num_epoch
 # 4. assemble the components that need to be saved
@@ -224,4 +225,6 @@ pp.add_processor(sp_proc)
 pp.add_processor(char_proc)
 pp.add_processor(bigram_proc)
 pp.add_processor(char_index_proc)
-pp.add_processor(bigram_index_proc)
+pp.add_processor(bigram_index_proc)
+pp.add_processor(seq_len_proc)