
Merge remote-tracking branch 'origin/dataset' into dataset

tags/v0.2.0
FengZiYjun 5 years ago
commit 12e9a93b52
8 changed files with 518 additions and 1 deletion

  1. +11   -0   fastNLP/api/api.py
  2. +0    -1   fastNLP/api/pipeline.py
  3. +0    -0   reproduction/chinese_word_segment/model/__init__.py
  4. +135  -0   reproduction/chinese_word_segment/model/cws_model.py
  5. +0    -0   reproduction/chinese_word_segment/process/__init__.py
  6. +283  -0   reproduction/chinese_word_segment/process/cws_processor.py
  7. +3    -0   reproduction/chinese_word_segment/train_context.py
  8. +86   -0   reproduction/chinese_word_segment/utils.py

+11  -0   fastNLP/api/api.py

@@ -0,0 +1,11 @@


class API:
    def __init__(self):
        pass

    def predict(self):
        pass

    def load(self):
        pass
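
The API class added here is only a stub; concrete tasks are expected to subclass it and fill in load/predict. A minimal sketch of what such a subclass might look like — the CWS name, the model_path parameter, and the method bodies are hypothetical and not part of this commit:

# Hypothetical subclass sketch (not part of this commit): a task-specific API
# would restore a trained pipeline in load() and run it over raw text in predict().
class CWS(API):
    def __init__(self, model_path=None):
        super().__init__()
        self.pipeline = None
        if model_path is not None:
            self.load(model_path)

    def load(self, model_path):
        # e.g. deserialize a saved Pipeline; intentionally left abstract here
        raise NotImplementedError

    def predict(self, sentences):
        # run the loaded pipeline over a list of raw sentences
        raise NotImplementedError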

+0  -1   fastNLP/api/pipeline.py

@@ -8,7 +8,6 @@ class Pipeline:

    def add_processor(self, processor):
        assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor))
-       processor_name = type(processor)
        self.pipeline.append(processor)

    def process(self, dataset):


+0  -0   reproduction/chinese_word_segment/model/__init__.py


+135  -0   reproduction/chinese_word_segment/model/cws_model.py

@@ -0,0 +1,135 @@

from torch import nn
import torch
import torch.nn.functional as F

from fastNLP.modules.decoder.MLP import MLP
from fastNLP.models.base_model import BaseModel
from reproduction.chinese_word_segment.utils import seq_lens_to_mask

class CWSBiLSTMEncoder(BaseModel):
    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
                 hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1):
        super().__init__()

        self.input_size = 0
        self.num_bigram_per_char = num_bigram_per_char
        self.bidirectional = bidirectional
        self.num_layers = num_layers
        self.embed_drop_p = embed_drop_p
        if self.bidirectional:
            self.hidden_size = hidden_size//2
            self.num_directions = 2
        else:
            self.hidden_size = hidden_size
            self.num_directions = 1

        if bigram_vocab_num is not None:
            assert num_bigram_per_char is not None, "Specify num_bigram_per_char."

        if vocab_num is not None:
            self.char_embedding = nn.Embedding(num_embeddings=vocab_num, embedding_dim=embed_dim)
            self.input_size += embed_dim

        if bigram_vocab_num is not None:
            self.bigram_embedding = nn.Embedding(num_embeddings=bigram_vocab_num, embedding_dim=bigram_embed_dim)
            self.input_size += self.num_bigram_per_char*bigram_embed_dim

        # num_criterion is never set in this version, so this branch never runs;
        # guarded with getattr so the constructor does not raise AttributeError.
        if getattr(self, 'num_criterion', None) is not None:
            if bidirectional:
                self.backward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion,
                                                                 embedding_dim=self.hidden_size)
            self.forward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion,
                                                            embedding_dim=self.hidden_size)

        if self.embed_drop_p is not None:
            self.embedding_drop = nn.Dropout(p=self.embed_drop_p)

        self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, bidirectional=self.bidirectional,
                            batch_first=True, num_layers=self.num_layers)

        self.reset_parameters()

    def reset_parameters(self):
        for name, param in self.named_parameters():
            if 'bias_hh' in name:
                nn.init.constant_(param, 0)
            elif 'bias_ih' in name:
                nn.init.constant_(param, 1)
            else:
                nn.init.xavier_uniform_(param)

    def init_embedding(self, embedding, embed_name):
        if embed_name == 'bigram':
            self.bigram_embedding.weight.data = torch.from_numpy(embedding)
        elif embed_name == 'char':
            self.char_embedding.weight.data = torch.from_numpy(embedding)


    def forward(self, chars, bigrams=None, seq_lens=None):

        batch_size, max_len = chars.size()

        x_tensor = self.char_embedding(chars)

        if bigrams is not None:
            bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
            x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)

        # sort the batch by length so it can be packed for the LSTM, then restore the original order
        sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True)
        packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True)

        outputs, _ = self.lstm(packed_x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        _, desorted_indices = torch.sort(sorted_indices, descending=False)
        outputs = outputs[desorted_indices]

        return outputs


class CWSBiLSTMSegApp(BaseModel):
    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
                 hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=2):
        super(CWSBiLSTMSegApp, self).__init__()

        self.tag_size = tag_size

        self.encoder_model = CWSBiLSTMEncoder(vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char,
                                              hidden_size, bidirectional, embed_drop_p, num_layers)

        size_layer = [hidden_size, 100, tag_size]
        self.decoder_model = MLP(size_layer)


    def forward(self, **kwargs):
        chars = kwargs['chars']
        if 'bigrams' in kwargs:
            bigrams = kwargs['bigrams']
        else:
            bigrams = None
        seq_lens = kwargs['seq_lens']

        feats = self.encoder_model(chars, bigrams, seq_lens)
        probs = self.decoder_model(feats)

        pred_dict = {}
        pred_dict['seq_lens'] = seq_lens
        pred_dict['pred_prob'] = probs

        return pred_dict

    def loss_fn(self, pred_dict, true_dict):
        seq_lens = pred_dict['seq_lens']
        masks = seq_lens_to_mask(seq_lens).float()

        pred_prob = pred_dict['pred_prob']
        true_y = true_dict['tags']

        # TODO: the loss is currently hard-coded to masked cross entropy
        loss = F.cross_entropy(pred_prob.view(-1, self.tag_size),
                               true_y.view(-1), reduction='none')*masks.view(-1)/torch.sum(masks)


        return loss
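
For orientation, a minimal smoke-test sketch of the two models above, assuming fastNLP and this module are importable; the vocabulary size, tensor shapes, and values are made up for illustration:

import torch
from reproduction.chinese_word_segment.model.cws_model import CWSBiLSTMSegApp

# toy batch: 2 sentences of up to 5 characters, char vocab of 100, no bigram features
model = CWSBiLSTMSegApp(vocab_num=100, embed_dim=64, hidden_size=128, tag_size=2)
chars = torch.randint(0, 100, (2, 5))
seq_lens = torch.LongTensor([5, 3])
tags = torch.randint(0, 2, (2, 5))

pred_dict = model(chars=chars, seq_lens=seq_lens)   # pred_dict['pred_prob'] has shape (2, 5, 2)
loss = model.loss_fn(pred_dict, {'tags': tags})     # per-token masked cross entropy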


+0  -0   reproduction/chinese_word_segment/process/__init__.py


+283  -0   reproduction/chinese_word_segment/process/cws_processor.py

@@ -0,0 +1,283 @@

import re


from fastNLP.core.field import SeqLabelField
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet

from fastNLP.api.processor import Processor


_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>'

class FullSpaceToHalfSpaceProcessor(Processor):
    def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
                 change_space=True):
        super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)

        self.change_alpha = change_alpha
        self.change_digit = change_digit
        self.change_punctuation = change_punctuation
        self.change_space = change_space

        FH_SPACE = [(u"　", u" ")]
        FH_NUM = [
            (u"０", u"0"), (u"１", u"1"), (u"２", u"2"), (u"３", u"3"), (u"４", u"4"),
            (u"５", u"5"), (u"６", u"6"), (u"７", u"7"), (u"８", u"8"), (u"９", u"9")]
        FH_ALPHA = [
            (u"ａ", u"a"), (u"ｂ", u"b"), (u"ｃ", u"c"), (u"ｄ", u"d"), (u"ｅ", u"e"),
            (u"ｆ", u"f"), (u"ｇ", u"g"), (u"ｈ", u"h"), (u"ｉ", u"i"), (u"ｊ", u"j"),
            (u"ｋ", u"k"), (u"ｌ", u"l"), (u"ｍ", u"m"), (u"ｎ", u"n"), (u"ｏ", u"o"),
            (u"ｐ", u"p"), (u"ｑ", u"q"), (u"ｒ", u"r"), (u"ｓ", u"s"), (u"ｔ", u"t"),
            (u"ｕ", u"u"), (u"ｖ", u"v"), (u"ｗ", u"w"), (u"ｘ", u"x"), (u"ｙ", u"y"),
            (u"ｚ", u"z"),
            (u"Ａ", u"A"), (u"Ｂ", u"B"), (u"Ｃ", u"C"), (u"Ｄ", u"D"), (u"Ｅ", u"E"),
            (u"Ｆ", u"F"), (u"Ｇ", u"G"), (u"Ｈ", u"H"), (u"Ｉ", u"I"), (u"Ｊ", u"J"),
            (u"Ｋ", u"K"), (u"Ｌ", u"L"), (u"Ｍ", u"M"), (u"Ｎ", u"N"), (u"Ｏ", u"O"),
            (u"Ｐ", u"P"), (u"Ｑ", u"Q"), (u"Ｒ", u"R"), (u"Ｓ", u"S"), (u"Ｔ", u"T"),
            (u"Ｕ", u"U"), (u"Ｖ", u"V"), (u"Ｗ", u"W"), (u"Ｘ", u"X"), (u"Ｙ", u"Y"),
            (u"Ｚ", u"Z")]
        # Use punctuation conversion with caution: a full-width period in something like
        # "5.12特大地震" would be rewritten by the conversion and change the original text.
        FH_PUNCTUATION = [
            (u'％', u'%'), (u'！', u'!'), (u'＂', u'\"'), (u'＇', u'\''), (u'＃', u'#'),
            (u'￥', u'$'), (u'＆', u'&'), (u'（', u'('), (u'）', u')'), (u'＊', u'*'),
            (u'＋', u'+'), (u'，', u','), (u'－', u'-'), (u'．', u'.'), (u'／', u'/'),
            (u'：', u':'), (u'；', u';'), (u'＜', u'<'), (u'＝', u'='), (u'＞', u'>'),
            (u'？', u'?'), (u'＠', u'@'), (u'［', u'['), (u'］', u']'), (u'＼', u'\\'),
            (u'＾', u'^'), (u'＿', u'_'), (u'｀', u'`'), (u'～', u'~'), (u'｛', u'{'),
            (u'｝', u'}'), (u'｜', u'|')]
        FHs = []
        if self.change_alpha:
            FHs = FH_ALPHA
        if self.change_digit:
            FHs += FH_NUM
        if self.change_punctuation:
            FHs += FH_PUNCTUATION
        if self.change_space:
            FHs += FH_SPACE
        self.convert_map = {k: v for k, v in FHs}

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            sentence = ins[self.field_name].text
            new_sentence = [None] * len(sentence)
            for idx, char in enumerate(sentence):
                if char in self.convert_map:
                    char = self.convert_map[char]
                new_sentence[idx] = char
            ins[self.field_name].text = ''.join(new_sentence)
        return dataset
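
As a quick illustration of what the conversion table does, the sketch below builds the processor and applies its convert_map directly to a raw string, bypassing the DataSet plumbing that process() expects; it assumes fastNLP is importable, and the 'raw_sentence' field name is made up:

proc = FullSpaceToHalfSpaceProcessor('raw_sentence')
text = u"ＦastＮＬＰ　２０１８"   # full-width letters, digits and an ideographic space
half = ''.join(proc.convert_map.get(c, c) for c in text)
print(half)   # -> "FastNLP 2018"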


class SpeicalSpanProcessor(Processor):
    # This class replaces special spans in a sentence with their corresponding content.
    def __init__(self, field_name, new_added_field_name=None):
        super(SpeicalSpanProcessor, self).__init__(field_name, new_added_field_name)

        self.span_converters = []


    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            sentence = ins[self.field_name].text
            for span_converter in self.span_converters:
                sentence = span_converter.find_certain_span_and_replace(sentence)
            if self.new_added_field_name != self.field_name:
                new_text_field = TextField(sentence, is_target=False)
                ins[self.new_added_field_name] = new_text_field
            else:
                ins[self.field_name].text = sentence

        return dataset

    def add_span_converter(self, converter):
        assert isinstance(converter, SpanConverterBase), "Only SpanConverterBase is allowed, not {}."\
            .format(type(converter))
        self.span_converters.append(converter)



class CWSCharSegProcessor(Processor):
    def __init__(self, field_name, new_added_field_name):
        super(CWSCharSegProcessor, self).__init__(field_name, new_added_field_name)

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            sentence = ins[self.field_name].text
            chars = self._split_sent_into_chars(sentence)
            new_token_field = TokenListFiled(chars, is_target=False)
            ins[self.new_added_field_name] = new_token_field

        return dataset

    def _split_sent_into_chars(self, sentence):
        sp_tag_match_iter = re.finditer(_SPECIAL_TAG_PATTERN, sentence)
        sp_spans = [match_span.span() for match_span in sp_tag_match_iter]
        sp_span_idx = 0
        in_span_flag = False
        chars = []
        num_spans = len(sp_spans)
        for idx, char in enumerate(sentence):
            if sp_span_idx < num_spans and idx == sp_spans[sp_span_idx][0]:
                in_span_flag = True
            elif in_span_flag and sp_span_idx < num_spans and idx == sp_spans[sp_span_idx][1] - 1:
                # keep a whole special tag (matched by _SPECIAL_TAG_PATTERN) as a single token
                chars.append(sentence[sp_spans[sp_span_idx][0]:sp_spans[sp_span_idx][1]])
                in_span_flag = False
                sp_span_idx += 1
            elif not in_span_flag:
                # TODO: need to think carefully about how to handle spaces
                if char != ' ':
                    chars.append(char)
                else:
                    pass
        return chars


class CWSTagProcessor(Processor):
    def __init__(self, field_name, new_added_field_name=None):
        super(CWSTagProcessor, self).__init__(field_name, new_added_field_name)

    def _generate_tag(self, sentence):
        sp_tag_match_iter = re.finditer(_SPECIAL_TAG_PATTERN, sentence)
        sp_spans = [match_span.span() for match_span in sp_tag_match_iter]
        sp_span_idx = 0
        in_span_flag = False
        tag_list = []
        word_len = 0
        num_spans = len(sp_spans)
        for idx, char in enumerate(sentence):
            if sp_span_idx < num_spans and idx == sp_spans[sp_span_idx][0]:
                in_span_flag = True
            elif in_span_flag and sp_span_idx < num_spans and idx == sp_spans[sp_span_idx][1] - 1:
                word_len += 1
                in_span_flag = False
                sp_span_idx += 1
            elif not in_span_flag:
                if char == ' ':
                    if word_len != 0:
                        tag_list.extend(self._tags_from_word_len(word_len))
                    word_len = 0
                else:
                    word_len += 1
            else:
                pass
        if word_len != 0:
            tag_list.extend(self._tags_from_word_len(word_len))

        return tag_list

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            sentence = ins[self.field_name].text
            tag_list = self._generate_tag(sentence)
            new_tag_field = SeqLabelField(tag_list)
            ins[self.new_added_field_name] = new_tag_field
        return dataset

    def _tags_from_word_len(self, word_len):
        raise NotImplementedError


class CWSSegAppTagProcessor(CWSTagProcessor):
    def __init__(self, field_name, new_added_field_name=None):
        super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name)

    def _tags_from_word_len(self, word_len):
        tag_list = []
        for _ in range(word_len - 1):
            tag_list.append(0)
        tag_list.append(1)
        return tag_list
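
In this seg-app scheme every character of a word except the last gets tag 0 ("append") and the last character of a word gets tag 1 ("segment"). A small hedged example, assuming fastNLP is importable; the field names are made up:

proc = CWSSegAppTagProcessor('raw_sentence', 'tags')
print(proc._generate_tag('中国 人民'))   # -> [0, 1, 0, 1]
print(proc._tags_from_word_len(3))        # -> [0, 0, 1]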


class BigramProcessor(Processor):
    def __init__(self, field_name, new_added_field_name=None):

        super(BigramProcessor, self).__init__(field_name, new_added_field_name)

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))

        for ins in dataset:
            characters = ins[self.field_name].content
            bigrams = self._generate_bigram(characters)
            new_token_field = TokenListFiled(bigrams)
            ins[self.new_added_field_name] = new_token_field

        return dataset


    def _generate_bigram(self, characters):
        raise NotImplementedError


class Pre2Post2BigramProcessor(BigramProcessor):
    def __init__(self, field_name, new_added_field_name=None):

        super(Pre2Post2BigramProcessor, self).__init__(field_name, new_added_field_name)

    def _generate_bigram(self, characters):
        bigrams = []
        characters = ['<SOS>', '<SOS>'] + characters + ['<EOS>', '<EOS>']
        for idx in range(2, len(characters) - 2):
            cur_char = characters[idx]
            pre_pre_char = characters[idx - 2]
            pre_char = characters[idx - 1]
            post_char = characters[idx + 1]
            post_post_char = characters[idx + 2]
            pre_pre_cur_bigram = pre_pre_char + cur_char
            pre_cur_bigram = pre_char + cur_char
            cur_post_bigram = cur_char + post_char
            cur_post_post_bigram = cur_char + post_post_char
            # 8 context features per character: the four surrounding characters and four bigrams
            bigrams.extend([pre_pre_char, pre_char, post_char, post_post_char,
                            pre_pre_cur_bigram, pre_cur_bigram,
                            cur_post_bigram, cur_post_post_bigram])
        return bigrams
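
A hedged example of the feature layout this produces, assuming fastNLP is importable; the field names are made up:

proc = Pre2Post2BigramProcessor('chars', 'bigrams')
print(proc._generate_bigram(['天', '气']))
# -> ['<SOS>', '<SOS>', '气', '<EOS>', '<SOS>天', '<SOS>天', '天气', '天<EOS>',
#     '<SOS>', '天', '<EOS>', '<EOS>', '<SOS>气', '天气', '气<EOS>', '气<EOS>']

Since each character yields 8 features, num_bigram_per_char in CWSBiLSTMEncoder would be 8 when this processor is used.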


# A vocabulary needs to be built here, but there is a problem:
# (1) with the Processor approach, what gets returned in this case is not a dataset, so building the
#     vocabulary is implemented in another way rather than through a Processor.
class IndexProcessor(Processor):
    def __init__(self, vocab, field_name):

        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

        super(IndexProcessor, self).__init__(field_name, None)
        self.vocab = vocab

    def set_vocab(self, vocab):
        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

        self.vocab = vocab

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            tokens = ins[self.field_name].content
            index = [self.vocab.to_index(token) for token in tokens]
            ins[self.field_name]._index = index

        return dataset


class VocabProcessor(Processor):
    def __init__(self, field_name):

        super(VocabProcessor, self).__init__(field_name, None)
        self.vocab = Vocabulary()

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            tokens = ins[self.field_name].content
            self.vocab.update(tokens)

    def get_vocab(self):
        self.vocab.build_vocab()
        return self.vocab

+3  -0   reproduction/chinese_word_segment/train_context.py

@@ -0,0 +1,3 @@




+86  -0   reproduction/chinese_word_segment/utils.py

@@ -0,0 +1,86 @@

import torch


def seq_lens_to_mask(seq_lens):
    batch_size = seq_lens.size(0)
    max_len = seq_lens.max()

    indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
    masks = indexes.lt(seq_lens.unsqueeze(1))

    return masks
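
A quick check of the mask shape and values this produces, assuming PyTorch is installed:

# seq_lens_to_mask turns lengths into a (batch, max_len) mask of valid positions
mask = seq_lens_to_mask(torch.LongTensor([3, 1]))
# [[1, 1, 1],
#  [1, 0, 0]]   (bool or uint8 dtype, depending on the PyTorch version)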


def cut_long_training_sentences(sentences, max_sample_length=200):
    cutted_sentence = []
    for sent in sentences:
        sent_no_space = sent.replace(' ', '')
        if len(sent_no_space) > max_sample_length:
            parts = sent.strip().split()
            new_line = ''
            length = 0
            for part in parts:
                length += len(part)
                new_line += part + ' '
                if length > max_sample_length:
                    new_line = new_line[:-1]
                    cutted_sentence.append(new_line)
                    length = 0
                    new_line = ''
            if new_line != '':
                cutted_sentence.append(new_line[:-1])
        else:
            cutted_sentence.append(sent)
    return cutted_sentence


from torch import nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    r"""
    This criterion is an implementation of Focal Loss, which is proposed in
    "Focal Loss for Dense Object Detection".

        Loss(x, class) = - \alpha (1 - softmax(x)[class])^gamma \log(softmax(x)[class])

    The losses are averaged across observations for each minibatch.

    Args:
        alpha(1D Tensor, Variable): the scalar factor for this criterion
        gamma(float, double): gamma > 0; reduces the relative loss for well-classified examples (p > .5),
            putting more focus on hard, misclassified examples
        size_average(bool): by default, the losses are averaged over observations for each minibatch.
            However, if size_average is set to False, the losses are instead summed for each minibatch.
    """

    def __init__(self, class_num, gamma=2, size_average=True, reduce=False):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
        self.reduce = reduce

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs, dim=-1)

        # one-hot mask selecting the probability of the target class for each sample
        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask.requires_grad = True
        ids = targets.view(-1, 1)
        class_mask = class_mask.scatter(1, ids.data, 1.)

        probs = (P * class_mask).sum(1).view(-1, 1)

        log_p = probs.log()

        batch_loss = - (torch.pow((1 - probs), self.gamma)) * log_p
        if self.reduce:
            if self.size_average:
                loss = batch_loss.mean()
            else:
                loss = batch_loss.sum()
            return loss
        return batch_loss
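
A minimal usage sketch with made-up logits and targets, assuming PyTorch is installed:

# FocalLoss expects raw logits of shape (N, C) and integer class targets of shape (N,)
criterion = FocalLoss(class_num=2, gamma=2, reduce=True)
logits = torch.randn(4, 2)
targets = torch.LongTensor([0, 1, 1, 0])
loss = criterion(logits, targets)   # scalar, since reduce=True and size_average=True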
