From 64a9bacbc25d3890b6112c512e5823f4a4e3e338 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Sat, 10 Nov 2018 16:50:56 +0800
Subject: [PATCH 1/3] fix crf

---
 fastNLP/modules/decoder/CRF.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py
index 11cde48a..e24f4d27 100644
--- a/fastNLP/modules/decoder/CRF.py
+++ b/fastNLP/modules/decoder/CRF.py
@@ -89,8 +89,9 @@ class ConditionalRandomField(nn.Module):
         score = score.sum(0) + emit_score[-1]
         if self.include_start_end_trans:
             st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
-            last_idx = mask.long().sum(0)
+            last_idx = mask.long().sum(0) - 1
             ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
+            print(score.size(), st_scores.size(), ed_scores.size())
             score += st_scores + ed_scores
         # return [B,]
         return score
@@ -104,8 +105,8 @@ class ConditionalRandomField(nn.Module):
         :return:FloatTensor, batch_size
         """
         feats = feats.transpose(0, 1)
-        tags = tags.transpose(0, 1)
-        mask = mask.transpose(0, 1)
+        tags = tags.transpose(0, 1).long()
+        mask = mask.transpose(0, 1).float()
 
         all_path_score = self._normalizer_likelihood(feats, mask)
         gold_path_score = self._glod_score(feats, tags, mask)
@@ -156,4 +157,4 @@ class ConditionalRandomField(nn.Module):
 
         if get_score:
             return ans_score, ans.transpose(0, 1)
-        return ans.transpose(0, 1)
\ No newline at end of file
+        return ans.transpose(0, 1)

From 26e3abdf58c1b4b7d9d40826cc67b4a448ef9ea3 Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Sat, 10 Nov 2018 16:58:27 +0800
Subject: [PATCH 2/3] - Fix the pos tag training script so that it runs
 - Create converter.py under api
 - Add an init option to Pipeline so that processors can be added all at once
 - Delete pos_tagger.py
 - Improve overall code style
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/api/converter.py                    | 182 ++++++++++++++++++++
 fastNLP/api/pipeline.py                     |  16 +-
 fastNLP/api/pos_tagger.py                   |  44 -----
 fastNLP/api/processor.py                    |  27 ++-
 fastNLP/core/batch.py                       |   3 -
 fastNLP/core/dataset.py                     |  64 +++----
 fastNLP/core/instance.py                    |   4 +-
 fastNLP/loader/dataset_loader.py            |   5 +-
 fastNLP/models/sequence_modeling.py         |   8 +-
 fastNLP/modules/decoder/CRF.py              |  24 +--
 reproduction/pos_tag_model/pos_tag.cfg      |   8 +-
 reproduction/pos_tag_model/train_pos_tag.py | 154 ++++++-----------
 12 files changed, 330 insertions(+), 209 deletions(-)
 create mode 100644 fastNLP/api/converter.py
 delete mode 100644 fastNLP/api/pos_tagger.py

diff --git a/fastNLP/api/converter.py b/fastNLP/api/converter.py
new file mode 100644
index 00000000..9ce24749
--- /dev/null
+++ b/fastNLP/api/converter.py
@@ -0,0 +1,182 @@
+import re
+
+
+class SpanConverter:
+    def __init__(self, replace_tag, pattern):
+        super(SpanConverter, self).__init__()
+
+        self.replace_tag = replace_tag
+        self.pattern = pattern
+
+    def find_certain_span_and_replace(self, sentence):
+        replaced_sentence = ''
+        prev_end = 0
+        for match in re.finditer(self.pattern, sentence):
+            start, end = match.span()
+            span = sentence[start:end]
+            replaced_sentence += sentence[prev_end:start] + \
+                                 self.span_to_special_tag(span)
+            prev_end =
end + replaced_sentence += sentence[prev_end:] + + return replaced_sentence + + def span_to_special_tag(self, span): + + return self.replace_tag + + def find_certain_span(self, sentence): + spans = [] + for match in re.finditer(self.pattern, sentence): + spans.append(match.span()) + return spans + + +class AlphaSpanConverter(SpanConverter): + def __init__(self): + replace_tag = '' + # 理想状态下仅处理纯为字母的情况, 但不处理<[a-zA-Z]+>(因为这应该是特殊的tag). + pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%.!<\\-"])' + + super(AlphaSpanConverter, self).__init__(replace_tag, pattern) + + +class DigitSpanConverter(SpanConverter): + def __init__(self): + replace_tag = '' + pattern = '\d[\d\\.]*(?=[\u4e00-\u9fff ,%.!<-])' + + super(DigitSpanConverter, self).__init__(replace_tag, pattern) + + def span_to_special_tag(self, span): + # return self.special_tag + if span[0] == '0' and len(span) > 2: + return '' + decimal_point_count = 0 # one might have more than one decimal pointers + for idx, char in enumerate(span): + if char == '.' or char == '﹒' or char == '·': + decimal_point_count += 1 + if span[-1] == '.' or span[-1] == '﹒' or span[ + -1] == '·': # last digit being decimal point means this is not a number + if decimal_point_count == 1: + return span + else: + return '' + if decimal_point_count == 1: + return '' + elif decimal_point_count > 1: + return '' + else: + return '' + + +class TimeConverter(SpanConverter): + def __init__(self): + replace_tag = '' + pattern = '\d+[::∶][\d::∶]+(?=[\u4e00-\u9fff ,%.!<-])' + + super().__init__(replace_tag, pattern) + + +class MixNumAlphaConverter(SpanConverter): + def __init__(self): + replace_tag = '' + pattern = None + + super().__init__(replace_tag, pattern) + + def find_certain_span_and_replace(self, sentence): + replaced_sentence = '' + start = 0 + matching_flag = False + number_flag = False + alpha_flag = False + link_flag = False + slash_flag = False + bracket_flag = False + for idx in range(len(sentence)): + if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): + if not matching_flag: + replaced_sentence += sentence[start:idx] + start = idx + if re.match('[0-9]', sentence[idx]): + number_flag = True + elif re.match('[\'′&\\-]', sentence[idx]): + link_flag = True + elif re.match('/', sentence[idx]): + slash_flag = True + elif re.match('[\\(\\)]', sentence[idx]): + bracket_flag = True + else: + alpha_flag = True + matching_flag = True + elif re.match('[\\.]', sentence[idx]): + pass + else: + if matching_flag: + if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ + or (slash_flag and alpha_flag) or (link_flag and number_flag) \ + or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): + span = sentence[start:idx] + start = idx + replaced_sentence += self.span_to_special_tag(span) + matching_flag = False + number_flag = False + alpha_flag = False + link_flag = False + slash_flag = False + bracket_flag = False + + replaced_sentence += sentence[start:] + return replaced_sentence + + def find_certain_span(self, sentence): + spans = [] + start = 0 + matching_flag = False + number_flag = False + alpha_flag = False + link_flag = False + slash_flag = False + bracket_flag = False + for idx in range(len(sentence)): + if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): + if not matching_flag: + start = idx + if re.match('[0-9]', sentence[idx]): + number_flag = True + elif re.match('[\'′&\\-]', sentence[idx]): + link_flag = True + elif re.match('/', sentence[idx]): + slash_flag = True + elif re.match('[\\(\\)]', sentence[idx]): + bracket_flag = True + else: + 
alpha_flag = True + matching_flag = True + elif re.match('[\\.]', sentence[idx]): + pass + else: + if matching_flag: + if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ + or (slash_flag and alpha_flag) or (link_flag and number_flag) \ + or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): + spans.append((start, idx)) + start = idx + + matching_flag = False + number_flag = False + alpha_flag = False + link_flag = False + slash_flag = False + bracket_flag = False + + return spans + + +class EmailConverter(SpanConverter): + def __init__(self): + replaced_tag = "" + pattern = '[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\\-"$])' + + super(EmailConverter, self).__init__(replaced_tag, pattern) diff --git a/fastNLP/api/pipeline.py b/fastNLP/api/pipeline.py index 745c8874..aea4797f 100644 --- a/fastNLP/api/pipeline.py +++ b/fastNLP/api/pipeline.py @@ -1,17 +1,25 @@ from fastNLP.api.processor import Processor - class Pipeline: - def __init__(self): + """ + Pipeline takes a DataSet object as input, runs multiple processors sequentially, and + outputs a DataSet object. + """ + + def __init__(self, processors=None): self.pipeline = [] + if isinstance(processors, list): + for proc in processors: + assert isinstance(proc, Processor), "Must be a Processor, not {}.".format(type(processor)) + self.pipeline = processors def add_processor(self, processor): assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor)) self.pipeline.append(processor) def process(self, dataset): - assert len(self.pipeline)!=0, "You need to add some processor first." + assert len(self.pipeline) != 0, "You need to add some processor first." for proc_name, proc in self.pipeline: dataset = proc(dataset) @@ -19,4 +27,4 @@ class Pipeline: return dataset def __call__(self, *args, **kwargs): - return self.process(*args, **kwargs) \ No newline at end of file + return self.process(*args, **kwargs) diff --git a/fastNLP/api/pos_tagger.py b/fastNLP/api/pos_tagger.py deleted file mode 100644 index fbd689c1..00000000 --- a/fastNLP/api/pos_tagger.py +++ /dev/null @@ -1,44 +0,0 @@ -import pickle - -import numpy as np - -from fastNLP.core.dataset import DataSet -from fastNLP.loader.model_loader import ModelLoader -from fastNLP.core.predictor import Predictor - - -class POS_tagger: - def __init__(self): - pass - - def predict(self, query): - """ - :param query: List[str] - :return answer: List[str] - - """ - # TODO: 根据query 构建DataSet - pos_dataset = DataSet() - pos_dataset["text_field"] = np.array(query) - - # 加载pipeline和model - pipeline = self.load_pipeline("./xxxx") - - # 将DataSet作为参数运行 pipeline - pos_dataset = pipeline(pos_dataset) - - # 加载模型 - model = ModelLoader().load_pytorch("./xxx") - - # 调 predictor - predictor = Predictor() - output = predictor.predict(model, pos_dataset) - - # TODO: 转成最终输出 - return None - - @staticmethod - def load_pipeline(path): - with open(path, "r") as fp: - pipeline = pickle.load(fp) - return pipeline diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py index 3f8cc057..391e781b 100644 --- a/fastNLP/api/processor.py +++ b/fastNLP/api/processor.py @@ -1,7 +1,7 @@ - from fastNLP.core.dataset import DataSet from fastNLP.core.vocabulary import Vocabulary + class Processor: def __init__(self, field_name, new_added_field_name): self.field_name = field_name @@ -10,15 +10,18 @@ class Processor: else: self.new_added_field_name = new_added_field_name - def process(self): + def process(self, *args, **kwargs): pass def __call__(self, *args, 
**kwargs): return self.process(*args, **kwargs) - class FullSpaceToHalfSpaceProcessor(Processor): + """全角转半角,以字符为处理单元 + + """ + def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True, change_space=True): super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None) @@ -64,11 +67,12 @@ class FullSpaceToHalfSpaceProcessor(Processor): if self.change_space: FHs += FH_SPACE self.convert_map = {k: v for k, v in FHs} + def process(self, dataset): assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) for ins in dataset: sentence = ins[self.field_name] - new_sentence = [None]*len(sentence) + new_sentence = [None] * len(sentence) for idx, char in enumerate(sentence): if char in self.convert_map: char = self.convert_map[char] @@ -98,7 +102,7 @@ class IndexerProcessor(Processor): index = [self.vocab.to_index(token) for token in tokens] ins[self.new_added_field_name] = index - dataset.set_need_tensor(**{self.new_added_field_name:True}) + dataset.set_need_tensor(**{self.new_added_field_name: True}) if self.delete_old_field: dataset.delete_field(self.field_name) @@ -122,3 +126,16 @@ class VocabProcessor(Processor): def get_vocab(self): self.vocab.build_vocab() return self.vocab + + +class SeqLenProcessor(Processor): + def __init__(self, field_name, new_added_field_name='seq_lens'): + super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) + + def process(self, dataset): + assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) + for ins in dataset: + length = len(ins[self.field_name]) + ins[self.new_added_field_name] = length + dataset.set_need_tensor(**{self.new_added_field_name: True}) + return dataset diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 856a6eac..bc19ffb2 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -1,5 +1,3 @@ -from collections import defaultdict - import torch @@ -68,4 +66,3 @@ class Batch(object): self.curidx = endidx return batch_x, batch_y - diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index e3162356..0071e443 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -1,23 +1,27 @@ -import random -import sys, os -sys.path.append('../..') -sys.path = [os.path.join(os.path.dirname(__file__), '../..')] + sys.path - -from collections import defaultdict -from copy import deepcopy -import numpy as np - -from fastNLP.core.field import TextField, LabelField -from fastNLP.core.instance import Instance -from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.fieldarray import FieldArray _READERS = {} + +def construct_dataset(sentences): + """Construct a data set from a list of sentences. + + :param sentences: list of str + :return dataset: a DataSet object + """ + dataset = DataSet() + for sentence in sentences: + instance = Instance() + instance['raw_sentence'] = sentence + dataset.append(instance) + return dataset + + class DataSet(object): """A DataSet object is a list of Instance objects. 
""" + class DataSetIter(object): def __init__(self, dataset): self.dataset = dataset @@ -34,13 +38,12 @@ class DataSet(object): def __setitem__(self, name, val): if name not in self.dataset: - new_fields = [None]*len(self.dataset) + new_fields = [None] * len(self.dataset) self.dataset.add_field(name, new_fields) self.dataset[name][self.idx] = val def __repr__(self): - # TODO - pass + return " ".join([repr(self.dataset[name][self.idx]) for name in self.dataset]) def __init__(self, instance=None): self.field_arrays = {} @@ -72,7 +75,7 @@ class DataSet(object): self.field_arrays[name].append(field) def add_field(self, name, fields): - if len(self.field_arrays)!=0: + if len(self.field_arrays) != 0: assert len(self) == len(fields) self.field_arrays[name] = FieldArray(name, fields) @@ -90,27 +93,10 @@ class DataSet(object): return len(field) def get_length(self): - """Fetch lengths of all fields in all instances in a dataset. - - :return lengths: dict of (str: list). The str is the field name. - The list contains lengths of this field in all instances. - - """ - pass - - def shuffle(self): - pass - - def split(self, ratio, shuffle=True): - """Train/dev splitting - - :param ratio: float, between 0 and 1. The ratio of development set in origin data set. - :param shuffle: bool, whether shuffle the data set before splitting. Default: True. - :return train_set: a DataSet object, representing the training set - dev_set: a DataSet object, representing the validation set + """The same as __len__ """ - pass + return len(self) def rename_field(self, old_name, new_name): """rename a field @@ -118,7 +104,7 @@ class DataSet(object): if old_name in self.field_arrays: self.field_arrays[new_name] = self.field_arrays.pop(old_name) else: - raise KeyError + raise KeyError("{} is not a valid name. ".format(old_name)) return self def set_is_target(self, **fields): @@ -150,6 +136,7 @@ class DataSet(object): data = _READERS[name]().load(*args, **kwargs) self.extend(data) return self + return _read else: return object.__getattribute__(self, name) @@ -159,18 +146,21 @@ class DataSet(object): """decorator to add dataloader support """ assert isinstance(method_name, str) + def wrapper(read_cls): _READERS[method_name] = read_cls return read_cls + return wrapper if __name__ == '__main__': from fastNLP.core.instance import Instance + ins = Instance(test='test0') dataset = DataSet([ins]) for _iter in dataset: print(_iter['test']) _iter['test'] = 'abc' print(_iter['test']) - print(dataset.field_arrays) \ No newline at end of file + print(dataset.field_arrays) diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py index a2686da8..12de4efa 100644 --- a/fastNLP/core/instance.py +++ b/fastNLP/core/instance.py @@ -1,4 +1,4 @@ -import torch + class Instance(object): """An instance which consists of Fields is an example in the DataSet. 
@@ -35,4 +35,4 @@ class Instance(object): return self.add_field(name, field) def __repr__(self): - return self.fields.__repr__() \ No newline at end of file + return self.fields.__repr__() diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py index 4ba121dd..7537c638 100644 --- a/fastNLP/loader/dataset_loader.py +++ b/fastNLP/loader/dataset_loader.py @@ -1,9 +1,9 @@ import os -from fastNLP.loader.base_loader import BaseLoader from fastNLP.core.dataset import DataSet -from fastNLP.core.instance import Instance from fastNLP.core.field import * +from fastNLP.core.instance import Instance +from fastNLP.loader.base_loader import BaseLoader def convert_seq_dataset(data): @@ -393,6 +393,7 @@ class PeopleDailyCorpusLoader(DataSetLoader): sent_words.append(token) pos_tag_examples.append([sent_words, sent_pos_tag]) ner_examples.append([sent_words, sent_ner]) + # List[List[List[str], List[str]]] return pos_tag_examples, ner_examples def convert(self, data): diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py index 11e49ee1..822c9286 100644 --- a/fastNLP/models/sequence_modeling.py +++ b/fastNLP/models/sequence_modeling.py @@ -44,6 +44,9 @@ class SeqLabeling(BaseModel): :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. If truth is not None, return loss, a scalar. Used in training. """ + assert word_seq.shape[0] == word_seq_origin_len.shape[0] + if truth is not None: + assert truth.shape == word_seq.shape self.mask = self.make_mask(word_seq, word_seq_origin_len) x = self.Embedding(word_seq) @@ -80,7 +83,7 @@ class SeqLabeling(BaseModel): batch_size, max_len = x.size(0), x.size(1) mask = seq_mask(seq_len, max_len) mask = mask.byte().view(batch_size, max_len) - mask = mask.to(x) + mask = mask.to(x).float() return mask def decode(self, x, pad=True): @@ -130,6 +133,9 @@ class AdvSeqLabel(SeqLabeling): :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. If truth is not None, return loss, a scalar. Used in training. 
""" + word_seq = word_seq.long() + word_seq_origin_len = word_seq_origin_len.long() + truth = truth.long() self.mask = self.make_mask(word_seq, word_seq_origin_len) batch_size = word_seq.size(0) diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index cd68d35d..0358bf9e 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -3,6 +3,7 @@ from torch import nn from fastNLP.modules.utils import initial_parameter + def log_sum_exp(x, dim=-1): max_value, _ = x.max(dim=dim, keepdim=True) res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value @@ -20,7 +21,7 @@ def seq_len_to_byte_mask(seq_lens): class ConditionalRandomField(nn.Module): - def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None): + def __init__(self, tag_size, include_start_end_trans=False, initial_method=None): """ :param tag_size: int, num of tags :param include_start_end_trans: bool, whether to include start/end tag @@ -38,6 +39,7 @@ class ConditionalRandomField(nn.Module): # self.reset_parameter() initial_parameter(self, initial_method) + def reset_parameter(self): nn.init.xavier_normal_(self.trans_m) if self.include_start_end_trans: @@ -81,15 +83,15 @@ class ConditionalRandomField(nn.Module): seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) # trans_socre [L-1, B] - trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :] + trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]] * mask[1:, :] # emit_score [L, B] - emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask + emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags] * mask # score [L-1, B] - score = trans_score + emit_score[:seq_len-1, :] + score = trans_score + emit_score[:seq_len - 1, :] score = score.sum(0) + emit_score[-1] if self.include_start_end_trans: st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] - last_idx = masks.long().sum(0) + last_idx = mask.long().sum(0) ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] score += st_scores + ed_scores # return [B,] @@ -120,14 +122,14 @@ class ConditionalRandomField(nn.Module): :return: scores, paths """ batch_size, seq_len, n_tags = data.size() - data = data.transpose(0, 1).data # L, B, H - mask = mask.transpose(0, 1).data.float() # L, B + data = data.transpose(0, 1).data # L, B, H + mask = mask.transpose(0, 1).data.float() # L, B # dp vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) vscore = data[0] if self.include_start_end_trans: - vscore += self.start_scores.view(1. -1) + vscore += self.start_scores.view(1. 
- 1) for i in range(1, seq_len): prev_score = vscore.view(batch_size, n_tags, 1) cur_score = data[i].view(batch_size, 1, n_tags) @@ -145,15 +147,15 @@ class ConditionalRandomField(nn.Module): seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device) lens = (mask.long().sum(0) - 1) # idxes [L, B], batched idx from seq_len-1 to 0 - idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len + idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len ans = data.new_empty((seq_len, batch_size), dtype=torch.long) ans_score, last_tags = vscore.max(1) ans[idxes[0], batch_idx] = last_tags for i in range(seq_len - 1): last_tags = vpath[idxes[i], batch_idx, last_tags] - ans[idxes[i+1], batch_idx] = last_tags + ans[idxes[i + 1], batch_idx] = last_tags if get_score: return ans_score, ans.transpose(0, 1) - return ans.transpose(0, 1) \ No newline at end of file + return ans.transpose(0, 1) diff --git a/reproduction/pos_tag_model/pos_tag.cfg b/reproduction/pos_tag_model/pos_tag.cfg index eb5e315d..2e1f37b6 100644 --- a/reproduction/pos_tag_model/pos_tag.cfg +++ b/reproduction/pos_tag_model/pos_tag.cfg @@ -1,10 +1,12 @@ [train] -epochs = 30 -batch_size = 64 +epochs = 5 +batch_size = 2 pickle_path = "./save/" -validate = true +validate = false save_best_dev = true model_saved_path = "./save/" + +[model] rnn_hidden_units = 100 word_emb_dim = 100 use_crf = true diff --git a/reproduction/pos_tag_model/train_pos_tag.py b/reproduction/pos_tag_model/train_pos_tag.py index fb077fe3..027358ef 100644 --- a/reproduction/pos_tag_model/train_pos_tag.py +++ b/reproduction/pos_tag_model/train_pos_tag.py @@ -1,130 +1,88 @@ import os -import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) +import torch +from fastNLP.api.pipeline import Pipeline +from fastNLP.api.processor import VocabProcessor, IndexerProcessor, SeqLenProcessor +from fastNLP.core.dataset import DataSet +from fastNLP.core.instance import Instance +from fastNLP.core.trainer import Trainer from fastNLP.loader.config_loader import ConfigLoader, ConfigSection -from fastNLP.core.trainer import SeqLabelTrainer -from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader, BaseLoader -from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle -from fastNLP.saver.model_saver import ModelSaver -from fastNLP.loader.model_loader import ModelLoader -from fastNLP.core.tester import SeqLabelTester +from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader from fastNLP.models.sequence_modeling import AdvSeqLabel -from fastNLP.core.predictor import SeqLabelInfer -# not in the file's dir -if len(os.path.dirname(__file__)) != 0: - os.chdir(os.path.dirname(__file__)) -datadir = "/home/zyfeng/data/" cfgfile = './pos_tag.cfg' -data_name = "CWS_POS_TAG_NER_people_daily.txt" +datadir = "/home/zyfeng/fastnlp_0.2.0/test/data_for_tests/" +data_name = "people_daily_raw.txt" pos_tag_data_path = os.path.join(datadir, data_name) pickle_path = "save" data_infer_path = os.path.join(datadir, "infer.utf8") -def infer(): - # Config Loader - test_args = ConfigSection() - ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args}) - - # fetch dictionary size and number of labels from pickle files - word2index = load_pickle(pickle_path, "word2id.pkl") - test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "class2id.pkl") - test_args["num_classes"] = len(index2label) - - # Define the same model - model = AdvSeqLabel(test_args) - - try: - ModelLoader.load_pytorch(model, "./save/saved_model.pkl") - 
print('model loaded!') - except Exception as e: - print('cannot load model!') - raise - - # Data Loader - raw_data_loader = BaseLoader(data_infer_path) - infer_data = raw_data_loader.load_lines() - print('data loaded') - - # Inference interface - infer = SeqLabelInfer(pickle_path) - results = infer.predict(model, infer_data) - - print(results) - print("Inference finished!") - - -def train(): +def train(): # load config - trainer_args = ConfigSection() - model_args = ConfigSection() - ConfigLoader().load_config(cfgfile, {"train": train_args, "test": test_args}) + train_param = ConfigSection() + model_param = ConfigSection() + ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param}) + print("config loaded") # Data Loader loader = PeopleDailyCorpusLoader() - train_data, _ = loader.load() - - # TODO: define processors - - # define pipeline - pp = Pipeline() - # TODO: pp.add_processor() - - # run the pipeline, get data_set - train_data = pp(train_data) + train_data, _ = loader.load(os.path.join(datadir, data_name)) + print("data loaded") + + dataset = DataSet() + for data in train_data: + instance = Instance() + instance["words"] = data[0] + instance["tag"] = data[1] + dataset.append(instance) + print("dataset transformed") + + # processor_1 = FullSpaceToHalfSpaceProcessor('words') + # processor_1(dataset) + word_vocab_proc = VocabProcessor('words') + tag_vocab_proc = VocabProcessor("tag") + word_vocab_proc(dataset) + tag_vocab_proc(dataset) + word_indexer = IndexerProcessor(word_vocab_proc.get_vocab(), 'words', 'word_seq', delete_old_field=True) + word_indexer(dataset) + tag_indexer = IndexerProcessor(tag_vocab_proc.get_vocab(), 'tag', 'truth', delete_old_field=True) + tag_indexer(dataset) + seq_len_proc = SeqLenProcessor("word_seq", "word_seq_origin_len") + seq_len_proc(dataset) + + print("processors defined") + # dataset.set_is_target(tag_ids=True) + model_param["vocab_size"] = len(word_vocab_proc.get_vocab()) + model_param["num_classes"] = len(tag_vocab_proc.get_vocab()) + print("vocab_size={} num_classes={}".format(len(word_vocab_proc.get_vocab()), len(tag_vocab_proc.get_vocab()))) # define a model - model = AdvSeqLabel(train_args) + model = AdvSeqLabel(model_param) # call trainer to train - trainer = SeqLabelTrainer(train_args) - trainer.train(model, data_train, data_dev) - - # save model - ModelSaver("./saved_model.pkl").save_pytorch(model, param_only=False) - - # TODO:save pipeline + trainer = Trainer(**train_param.data) + trainer.train(model, dataset) + # save model & pipeline + pp = Pipeline([word_vocab_proc, word_indexer, seq_len_proc]) + save_dict = {"pipeline": pp, "model": model} + torch.save(save_dict, "model_pp.pkl") def test(): - # Config Loader - test_args = ConfigSection() - ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args}) - - # fetch dictionary size and number of labels from pickle files - word2index = load_pickle(pickle_path, "word2id.pkl") - test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "class2id.pkl") - test_args["num_classes"] = len(index2label) - - # load dev data - dev_data = load_pickle(pickle_path, "data_dev.pkl") - - # Define the same model - model = AdvSeqLabel(test_args) + pass - # Dump trained parameters into the model - ModelLoader.load_pytorch(model, "./save/saved_model.pkl") - print("model loaded!") - # Tester - tester = SeqLabelTester(**test_args.data) - - # Start testing - tester.test(model, dev_data) - - # print test results - print(tester.show_metrics()) - print("model 
tested!") +def infer(): + pass if __name__ == "__main__": + train() + """ import argparse parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') @@ -139,3 +97,5 @@ if __name__ == "__main__": else: print('no mode specified for model!') parser.print_help() + +""" From 5e84ca618e68e3f88c645f33a221ef9ff39740f8 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Sat, 10 Nov 2018 17:04:37 +0800 Subject: [PATCH 3/3] merge and update --- fastNLP/api/pos_tagger.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 fastNLP/api/pos_tagger.py diff --git a/fastNLP/api/pos_tagger.py b/fastNLP/api/pos_tagger.py deleted file mode 100644 index e69de29b..00000000