diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py
new file mode 100644
index 00000000..202f782f
--- /dev/null
+++ b/fastNLP/api/api.py
@@ -0,0 +1,11 @@
+
+
+class API:
+    def __init__(self):
+        pass
+
+    def predict(self):
+        pass
+
+    def load(self):
+        pass
\ No newline at end of file
diff --git a/fastNLP/api/pipeline.py b/fastNLP/api/pipeline.py
index b5c4cc7a..745c8874 100644
--- a/fastNLP/api/pipeline.py
+++ b/fastNLP/api/pipeline.py
@@ -8,7 +8,6 @@ class Pipeline:
 
     def add_processor(self, processor):
         assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor))
-        processor_name = type(processor)
         self.pipeline.append(processor)
 
     def process(self, dataset):
diff --git a/reproduction/chinese_word_segment/model/__init__.py b/reproduction/chinese_word_segment/model/__init__.py
new file mode 100644
index 00000000..e69de29b
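As a rough usage sketch of the Pipeline/Processor API touched above (not part of this diff): processors are appended in order and then applied to a DataSet. The no-argument Pipeline() call, the field names 'raw_sentence' and 'chars_list', and the dataset variable are assumptions for illustration; the two processors come from the cws_processor.py file further down in this diff.

    from fastNLP.api.pipeline import Pipeline
    from reproduction.chinese_word_segment.process.cws_processor import FullSpaceToHalfSpaceProcessor, CWSCharSegProcessor

    # compose a small preprocessing pipeline; field names are hypothetical
    pipe = Pipeline()
    pipe.add_processor(FullSpaceToHalfSpaceProcessor('raw_sentence'))
    pipe.add_processor(CWSCharSegProcessor('raw_sentence', 'chars_list'))
    # dataset = pipe.process(dataset)  # `dataset` would be a fastNLP DataSet built elsewhere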
diff --git a/reproduction/chinese_word_segment/model/cws_model.py b/reproduction/chinese_word_segment/model/cws_model.py
new file mode 100644
index 00000000..dfcfcafe
--- /dev/null
+++ b/reproduction/chinese_word_segment/model/cws_model.py
@@ -0,0 +1,135 @@
+
+from torch import nn
+import torch
+import torch.nn.functional as F
+
+from fastNLP.modules.decoder.MLP import MLP
+from fastNLP.models.base_model import BaseModel
+from reproduction.chinese_word_segment.utils import seq_lens_to_mask
+
+class CWSBiLSTMEncoder(BaseModel):
+    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
+                 hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1):
+        super().__init__()
+
+        self.input_size = 0
+        self.num_bigram_per_char = num_bigram_per_char
+        self.bidirectional = bidirectional
+        self.num_layers = num_layers
+        self.embed_drop_p = embed_drop_p
+        if self.bidirectional:
+            self.hidden_size = hidden_size//2
+            self.num_directions = 2
+        else:
+            self.hidden_size = hidden_size
+            self.num_directions = 1
+
+        if bigram_vocab_num is not None:
+            assert num_bigram_per_char is not None, "Specify num_bigram_per_char."
+
+        if vocab_num is not None:
+            self.char_embedding = nn.Embedding(num_embeddings=vocab_num, embedding_dim=embed_dim)
+            self.input_size += embed_dim
+
+        if bigram_vocab_num is not None:
+            self.bigram_embedding = nn.Embedding(num_embeddings=bigram_vocab_num, embedding_dim=bigram_embed_dim)
+            self.input_size += self.num_bigram_per_char*bigram_embed_dim
+
+        # num_criterion is never assigned in this class; only build the criterion
+        # embeddings when a subclass sets it before calling this __init__.
+        if getattr(self, 'num_criterion', None) is not None:
+            if bidirectional:
+                self.backward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion,
+                                                                 embedding_dim=self.hidden_size)
+            self.forward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion,
+                                                            embedding_dim=self.hidden_size)
+
+        if self.embed_drop_p is not None:
+            self.embedding_drop = nn.Dropout(p=self.embed_drop_p)
+
+        self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, bidirectional=self.bidirectional,
+                            batch_first=True, num_layers=self.num_layers)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for name, param in self.named_parameters():
+            if 'bias_hh' in name:
+                nn.init.constant_(param, 0)
+            elif 'bias_ih' in name:
+                nn.init.constant_(param, 1)
+            else:
+                nn.init.xavier_uniform_(param)
+
+    def init_embedding(self, embedding, embed_name):
+        if embed_name == 'bigram':
+            self.bigram_embedding.weight.data = torch.from_numpy(embedding)
+        elif embed_name == 'char':
+            self.char_embedding.weight.data = torch.from_numpy(embedding)
+
+    def forward(self, chars, bigrams=None, seq_lens=None):
+
+        batch_size, max_len = chars.size()
+
+        x_tensor = self.char_embedding(chars)
+
+        if bigrams is not None:
+            bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
+            x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)
+
+        # optional dropout on the concatenated char/bigram embeddings
+        if self.embed_drop_p is not None:
+            x_tensor = self.embedding_drop(x_tensor)
+
+        sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True)
+        packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True)
+
+        outputs, _ = self.lstm(packed_x)
+        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
+
+        _, desorted_indices = torch.sort(sorted_indices, descending=False)
+        outputs = outputs[desorted_indices]
+
+        return outputs
+
+
+class CWSBiLSTMSegApp(BaseModel):
+    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
+                 hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=2):
+        super(CWSBiLSTMSegApp, self).__init__()
+
+        self.tag_size = tag_size
+
+        self.encoder_model = CWSBiLSTMEncoder(vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char,
+                                              hidden_size, bidirectional, embed_drop_p, num_layers)
+
+        size_layer = [hidden_size, 100, tag_size]
+        self.decoder_model = MLP(size_layer)
+
+    def forward(self, **kwargs):
+        chars = kwargs['chars']
+        if 'bigrams' in kwargs:
+            bigrams = kwargs['bigrams']
+        else:
+            bigrams = None
+        seq_lens = kwargs['seq_lens']
+
+        feats = self.encoder_model(chars, bigrams, seq_lens)
+        probs = self.decoder_model(feats)
+
+        pred_dict = {}
+        pred_dict['seq_lens'] = seq_lens
+        pred_dict['pred_prob'] = probs
+
+        return pred_dict
+
+    def loss_fn(self, pred_dict, true_dict):
+        seq_lens = pred_dict['seq_lens']
+        masks = seq_lens_to_mask(seq_lens).float()
+
+        pred_prob = pred_dict['pred_prob']
+        true_y = true_dict['tags']
+
+        # TODO: the loss is hard-coded here for now
+        loss = F.cross_entropy(pred_prob.view(-1, self.tag_size),
+                               true_y.view(-1), reduction='none')*masks.view(-1)/torch.sum(masks)
+
+        return loss
+
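To make the masking in CWSBiLSTMSegApp.loss_fn concrete, here is a small self-contained PyTorch sketch (not part of the diff; all tensor values are made up) that builds the same kind of length mask as seq_lens_to_mask and uses it to ignore padded positions in a per-token cross-entropy. Unlike loss_fn, which returns the normalised per-token vector, this sketch also sums it into a scalar.

    import torch
    import torch.nn.functional as F

    batch_size, max_len, tag_size = 2, 5, 2
    pred_prob = torch.randn(batch_size, max_len, tag_size)    # decoder scores, like pred_dict['pred_prob']
    tags = torch.randint(0, tag_size, (batch_size, max_len))  # gold tags, like true_dict['tags']
    seq_lens = torch.tensor([5, 3])                           # the second sample has 2 padded positions

    # same idea as seq_lens_to_mask: position < length -> real token
    masks = (torch.arange(max_len).unsqueeze(0) < seq_lens.unsqueeze(1)).float()

    per_token = F.cross_entropy(pred_prob.view(-1, tag_size), tags.view(-1), reduction='none')
    loss = (per_token * masks.view(-1)).sum() / masks.sum()   # average over real tokens only
    print(loss)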
diff --git a/reproduction/chinese_word_segment/process/__init__.py b/reproduction/chinese_word_segment/process/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/chinese_word_segment/process/cws_processor.py
new file mode 100644
index 00000000..1f7c0fc1
--- /dev/null
+++ b/reproduction/chinese_word_segment/process/cws_processor.py
@@ -0,0 +1,283 @@
+
+import re
+
+
+from fastNLP.core.field import SeqLabelField
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.dataset import DataSet
+
+from fastNLP.api.processor import Processor
+
+
+_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>'
+
+class FullSpaceToHalfSpaceProcessor(Processor):
+    def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
+                 change_space=True):
+        super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)
+
+        self.change_alpha = change_alpha
+        self.change_digit = change_digit
+        self.change_punctuation = change_punctuation
+        self.change_space = change_space
+
+        FH_SPACE = [(u"　", u" ")]
+        FH_NUM = [
+            (u"０", u"0"), (u"１", u"1"), (u"２", u"2"), (u"３", u"3"), (u"４", u"4"),
+            (u"５", u"5"), (u"６", u"6"), (u"７", u"7"), (u"８", u"8"), (u"９", u"9")]
+        FH_ALPHA = [
+            (u"ａ", u"a"), (u"ｂ", u"b"), (u"ｃ", u"c"), (u"ｄ", u"d"), (u"ｅ", u"e"),
+            (u"ｆ", u"f"), (u"ｇ", u"g"), (u"ｈ", u"h"), (u"ｉ", u"i"), (u"ｊ", u"j"),
+            (u"ｋ", u"k"), (u"ｌ", u"l"), (u"ｍ", u"m"), (u"ｎ", u"n"), (u"ｏ", u"o"),
+            (u"ｐ", u"p"), (u"ｑ", u"q"), (u"ｒ", u"r"), (u"ｓ", u"s"), (u"ｔ", u"t"),
+            (u"ｕ", u"u"), (u"ｖ", u"v"), (u"ｗ", u"w"), (u"ｘ", u"x"), (u"ｙ", u"y"),
+            (u"ｚ", u"z"),
+            (u"Ａ", u"A"), (u"Ｂ", u"B"), (u"Ｃ", u"C"), (u"Ｄ", u"D"), (u"Ｅ", u"E"),
+            (u"Ｆ", u"F"), (u"Ｇ", u"G"), (u"Ｈ", u"H"), (u"Ｉ", u"I"), (u"Ｊ", u"J"),
+            (u"Ｋ", u"K"), (u"Ｌ", u"L"), (u"Ｍ", u"M"), (u"Ｎ", u"N"), (u"Ｏ", u"O"),
+            (u"Ｐ", u"P"), (u"Ｑ", u"Q"), (u"Ｒ", u"R"), (u"Ｓ", u"S"), (u"Ｔ", u"T"),
+            (u"Ｕ", u"U"), (u"Ｖ", u"V"), (u"Ｗ", u"W"), (u"Ｘ", u"X"), (u"Ｙ", u"Y"),
+            (u"Ｚ", u"Z")]
+        # Use punctuation conversion with caution: "５.１２特大地震" may end up as "5.12特大地震" after conversion.
+        FH_PUNCTUATION = [
+            (u'％', u'%'), (u'！', u'!'), (u'＂', u'\"'), (u'＇', u'\''), (u'＃', u'#'),
+            (u'￥', u'$'), (u'＆', u'&'), (u'（', u'('), (u'）', u')'), (u'＊', u'*'),
+            (u'＋', u'+'), (u'，', u','), (u'－', u'-'), (u'．', u'.'), (u'／', u'/'),
+            (u'：', u':'), (u'；', u';'), (u'＜', u'<'), (u'＝', u'='), (u'＞', u'>'),
+            (u'？', u'?'), (u'＠', u'@'), (u'［', u'['), (u'］', u']'), (u'＼', u'\\'),
+            (u'＾', u'^'), (u'＿', u'_'), (u'｀', u'`'), (u'～', u'~'), (u'｛', u'{'),
+            (u'｝', u'}'), (u'｜', u'|')]
+        FHs = []
+        if self.change_alpha:
+            FHs = FH_ALPHA
+        if self.change_digit:
+            FHs += FH_NUM
+        if self.change_punctuation:
+            FHs += FH_PUNCTUATION
+        if self.change_space:
+            FHs += FH_SPACE
+        self.convert_map = {k: v for k, v in FHs}
+
+    def process(self, dataset):
+        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        for ins in dataset:
+            sentence = ins[self.field_name].text
+            new_sentence = [None]*len(sentence)
+            for idx, char in enumerate(sentence):
+                if char in self.convert_map:
+                    char = self.convert_map[char]
+                new_sentence[idx] = char
+            ins[self.field_name].text = ''.join(new_sentence)
+        return dataset
+
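As a quick, self-contained illustration of what the convert_map built above does (the sample string and the tiny map are made up; the real processor covers all full-width letters, digits, punctuation and the ideographic space):

    # full-width characters map onto their half-width (ASCII) counterparts
    convert_map = {'０': '0', '１': '1', '２': '2', 'Ａ': 'A', '　': ' '}
    sentence = '２０１Ａ　ok'
    print(''.join(convert_map.get(ch, ch) for ch in sentence))  # -> '201A ok'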
+
+class SpeicalSpanProcessor(Processor):
+    # This class replaces the special spans in a sentence with their corresponding content.
+    def __init__(self, field_name, new_added_field_name=None):
+        super(SpeicalSpanProcessor, self).__init__(field_name, new_added_field_name)
+
+        self.span_converters = []
+
+    def process(self, dataset):
+        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        for ins in dataset:
+            sentence = ins[self.field_name].text
+            for span_converter in self.span_converters:
+                sentence = span_converter.find_certain_span_and_replace(sentence)
+            if self.new_added_field_name != self.field_name:
+                new_text_field = TextField(sentence, is_target=False)
+                ins[self.new_added_field_name] = new_text_field
+            else:
+                ins[self.field_name].text = sentence
+
+        return dataset
+
+    def add_span_converter(self, converter):
+        assert isinstance(converter, SpanConverterBase), "Only SpanConverterBase is allowed, not {}."\
+            .format(type(converter))
+        self.span_converters.append(converter)
+
+
+class CWSCharSegProcessor(Processor):
+    def __init__(self, field_name, new_added_field_name):
+        super(CWSCharSegProcessor, self).__init__(field_name, new_added_field_name)
+
+    def process(self, dataset):
+        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        for ins in dataset:
+            sentence = ins[self.field_name].text
+            chars = self._split_sent_into_chars(sentence)
+            new_token_field = TokenListFiled(chars, is_target=False)
+            ins[self.new_added_field_name] = new_token_field
+
+        return dataset
+
+    def _split_sent_into_chars(self, sentence):
+        sp_tag_match_iter = re.finditer(_SPECIAL_TAG_PATTERN, sentence)
+        sp_spans = [match_span.span() for match_span in sp_tag_match_iter]
+        sp_span_idx = 0
+        in_span_flag = False
+        chars = []
+        num_spans = len(sp_spans)
+        for idx, char in enumerate(sentence):
+            if sp_span_idx < num_spans and idx == sp_spans[sp_span_idx][0]:
+                # a special <tag> span starts here; emit the whole span as a single token
+                in_span_flag = True
+            elif in_span_flag and sp_span_idx < num_spans and idx == sp_spans[sp_span_idx][1] - 1:
+                chars.append(sentence[sp_spans[sp_span_idx][0]:sp_spans[sp_span_idx][1]])
+                in_span_flag = False
+                sp_span_idx += 1
+            elif not in_span_flag:
+                chars.append(char)
+
+        return chars
+
+    def _generate_bigram(self, characters):
+        # for every character, collect its two preceding/following characters and the four
+        # bigrams they form with it (eight context features per character)
+        bigrams = []
+        characters = ['<SOS>', '<SOS>'] + characters + ['<EOS>', '<EOS>']
+        for idx in range(2, len(characters)-2):
+            cur_char = characters[idx]
+            pre_pre_char = characters[idx-2]
+            pre_char = characters[idx-1]
+            post_char = characters[idx+1]
+            post_post_char = characters[idx+2]
+            pre_pre_cur_bigram = pre_pre_char + cur_char
+            pre_cur_bigram = pre_char + cur_char
+            cur_post_bigram = cur_char + post_char
+            cur_post_post_bigram = cur_char + post_post_char
+            bigrams.extend([pre_pre_char, pre_char, post_char, post_post_char,
+                            pre_pre_cur_bigram, pre_cur_bigram,
+                            cur_post_bigram, cur_post_post_bigram])
+        return bigrams
+
+
+# At this point a vocabulary has to be built, but there is a problem:
+# (1) if this is done through a Processor, what is returned here is not a dataset. So vocabulary building is
+#     implemented in another way rather than through a Processor.
+class IndexProcessor(Processor):
+    def __init__(self, vocab, field_name):
+
+        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
+
+        super(IndexProcessor, self).__init__(field_name, None)
+        self.vocab = vocab
+
+    def set_vocab(self, vocab):
+        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
+
+        self.vocab = vocab
+
+    def process(self, dataset):
+        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        for ins in dataset:
+            tokens = ins[self.field_name].content
+            index = [self.vocab.to_index(token) for token in tokens]
+            ins[self.field_name]._index = index
+
+        return dataset
+
+
+class VocabProcessor(Processor):
+    def __init__(self, field_name):
+
+        super(VocabProcessor, self).__init__(field_name, None)
+        self.vocab = Vocabulary()
+
+    def process(self, dataset):
+        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        for ins in dataset:
+            tokens = ins[self.field_name].content
+            self.vocab.update(tokens)
+
+    def get_vocab(self):
+        self.vocab.build_vocab()
+        return self.vocab
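For intuition on the eight context features per character produced by the bigram generation above (and on why num_bigram_per_char would be 8 for this scheme), here is a self-contained sketch for a single position; the '<pad>' boundary symbols and the example string are made up:

    chars = list('上海市')
    ctx = ['<pad>', '<pad>'] + chars + ['<pad>', '<pad>']  # placeholder boundary symbols
    idx = 3                                                # position of '海' in the padded list
    features = [
        ctx[idx - 2], ctx[idx - 1], ctx[idx + 1], ctx[idx + 2],   # two characters before and after
        ctx[idx - 2] + ctx[idx], ctx[idx - 1] + ctx[idx],         # bigrams ending at the character
        ctx[idx] + ctx[idx + 1], ctx[idx] + ctx[idx + 2],         # bigrams starting at the character
    ]
    print(len(features), features)  # 8 features per character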
diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py
new file mode 100644
index 00000000..b28b04f6
--- /dev/null
+++ b/reproduction/chinese_word_segment/train_context.py
@@ -0,0 +1,3 @@
+
+
+
diff --git a/reproduction/chinese_word_segment/utils.py b/reproduction/chinese_word_segment/utils.py
new file mode 100644
index 00000000..92cd19d1
--- /dev/null
+++ b/reproduction/chinese_word_segment/utils.py
@@ -0,0 +1,86 @@
+
+import torch
+
+
+def seq_lens_to_mask(seq_lens):
+    batch_size = seq_lens.size(0)
+    max_len = seq_lens.max()
+
+    indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
+    masks = indexes.lt(seq_lens.unsqueeze(1))
+
+    return masks
+
+
+def cut_long_training_sentences(sentences, max_sample_length=200):
+    cutted_sentence = []
+    for sent in sentences:
+        sent_no_space = sent.replace(' ', '')
+        if len(sent_no_space) > max_sample_length:
+            parts = sent.strip().split()
+            new_line = ''
+            length = 0
+            for part in parts:
+                length += len(part)
+                new_line += part + ' '
+                if length > max_sample_length:
+                    new_line = new_line[:-1]
+                    cutted_sentence.append(new_line)
+                    length = 0
+                    new_line = ''
+            if new_line != '':
+                cutted_sentence.append(new_line[:-1])
+        else:
+            cutted_sentence.append(sent)
+    return cutted_sentence
+
+
+from torch import nn
+import torch.nn.functional as F
+
+class FocalLoss(nn.Module):
+    r"""
+    This criterion is an implementation of Focal Loss, which is proposed in
+    Focal Loss for Dense Object Detection.
+
+        Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
+
+    The losses are averaged across observations for each minibatch.
+    Args:
+        gamma(float, double): gamma > 0; reduces the relative loss for well-classified examples (p > .5),
+                              putting more focus on hard, misclassified examples
+        size_average(bool): By default, the losses are averaged over observations for each minibatch.
+                            However, if the field size_average is set to False, the losses are
+                            instead summed for each minibatch.
+    """
+
+    def __init__(self, class_num, gamma=2, size_average=True, reduce=False):
+        super(FocalLoss, self).__init__()
+        self.gamma = gamma
+        self.class_num = class_num
+        self.size_average = size_average
+        self.reduce = reduce
+
+    def forward(self, inputs, targets):
+        N = inputs.size(0)
+        C = inputs.size(1)
+        P = F.softmax(inputs, dim=-1)
+
+        class_mask = inputs.data.new(N, C).fill_(0)
+        class_mask.requires_grad = True
+        ids = targets.view(-1, 1)
+        class_mask = class_mask.scatter(1, ids.data, 1.)
+
+        probs = (P * class_mask).sum(1).view(-1, 1)
+
+        log_p = probs.log()
+
+        batch_loss = - (torch.pow((1 - probs), self.gamma)) * log_p
+        if self.reduce:
+            if self.size_average:
+                loss = batch_loss.mean()
+            else:
+                loss = batch_loss.sum()
+            return loss
+        return batch_loss
\ No newline at end of file
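A rough sanity check for the FocalLoss added above (assuming the repository root is on PYTHONPATH so the module imports; the logits are made up): with gamma=0 the per-sample values should equal plain cross-entropy, and with gamma=2 the confidently-correct example should be down-weighted much more strongly than the uncertain one.

    import torch
    import torch.nn.functional as F
    from reproduction.chinese_word_segment.utils import FocalLoss

    logits = torch.tensor([[4.0, 0.0],    # confident and correct
                           [0.2, 0.0]])   # uncertain
    targets = torch.tensor([0, 0])

    ce = F.cross_entropy(logits, targets, reduction='none')
    fl0 = FocalLoss(class_num=2, gamma=0)(logits, targets).view(-1)  # matches cross-entropy
    fl2 = FocalLoss(class_num=2, gamma=2)(logits, targets).view(-1)  # easy example strongly down-weighted
    print(ce, fl0, fl2)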