- Create converter.py under api
- Add an initializer to Pipeline so a batch of processors can be added in one go
- Remove pos_tagger.py
- Clean up overall code style
@@ -0,0 +1,182 @@
import re


class SpanConverter:
    def __init__(self, replace_tag, pattern):
        super(SpanConverter, self).__init__()

        self.replace_tag = replace_tag
        self.pattern = pattern

    def find_certain_span_and_replace(self, sentence):
        replaced_sentence = ''
        prev_end = 0
        for match in re.finditer(self.pattern, sentence):
            start, end = match.span()
            span = sentence[start:end]
            replaced_sentence += sentence[prev_end:start] + \
                                 self.span_to_special_tag(span)
            prev_end = end
        replaced_sentence += sentence[prev_end:]

        return replaced_sentence

    def span_to_special_tag(self, span):
        return self.replace_tag

    def find_certain_span(self, sentence):
        spans = []
        for match in re.finditer(self.pattern, sentence):
            spans.append(match.span())
        return spans


class AlphaSpanConverter(SpanConverter):
    def __init__(self):
        replace_tag = '<ALPHA>'
        # Ideally this should only match pure-alphabetic spans, while skipping
        # <[a-zA-Z]+> (those should be treated as special tags).
        pattern = r'[a-zA-Z]+(?=[\u4e00-\u9fff ,%.!<\-"])'

        super(AlphaSpanConverter, self).__init__(replace_tag, pattern)


class DigitSpanConverter(SpanConverter):
    def __init__(self):
        replace_tag = '<NUM>'
        pattern = r'\d[\d.]*(?=[\u4e00-\u9fff ,%.!<-])'

        super(DigitSpanConverter, self).__init__(replace_tag, pattern)

    def span_to_special_tag(self, span):
        if span[0] == '0' and len(span) > 2:
            return '<NUM>'
        decimal_point_count = 0  # a span may contain more than one decimal point
        for char in span:
            if char in ('.', '﹒', '·'):
                decimal_point_count += 1
        if span[-1] in ('.', '﹒', '·'):
            # a span ending with a decimal point is not a number
            if decimal_point_count == 1:
                return span
            else:
                return '<UNKDGT>'
        if decimal_point_count == 1:
            return '<DEC>'
        elif decimal_point_count > 1:
            return '<UNKDGT>'
        else:
            return '<NUM>'


class TimeConverter(SpanConverter):
    def __init__(self):
        replace_tag = '<TOC>'
        pattern = r'\d+[::∶][\d::∶]+(?=[\u4e00-\u9fff ,%.!<-])'

        super().__init__(replace_tag, pattern)


class MixNumAlphaConverter(SpanConverter):
    def __init__(self):
        replace_tag = '<MIX>'
        pattern = None

        super().__init__(replace_tag, pattern)

    def find_certain_span_and_replace(self, sentence):
        replaced_sentence = ''
        start = 0
        matching_flag = False
        number_flag = False
        alpha_flag = False
        link_flag = False
        slash_flag = False
        bracket_flag = False
        for idx in range(len(sentence)):
            if re.match(r'[0-9a-zA-Z/\(\)\'′&\-]', sentence[idx]):
                if not matching_flag:
                    replaced_sentence += sentence[start:idx]
                    start = idx
                if re.match('[0-9]', sentence[idx]):
                    number_flag = True
                elif re.match(r'[\'′&\-]', sentence[idx]):
                    link_flag = True
                elif re.match('/', sentence[idx]):
                    slash_flag = True
                elif re.match(r'[\(\)]', sentence[idx]):
                    bracket_flag = True
                else:
                    alpha_flag = True
                matching_flag = True
            elif re.match(r'[\.]', sentence[idx]):
                pass
            else:
                if matching_flag:
                    if (number_flag and alpha_flag) or (link_flag and alpha_flag) \
                            or (slash_flag and alpha_flag) or (link_flag and number_flag) \
                            or (number_flag and bracket_flag) or (bracket_flag and alpha_flag):
                        span = sentence[start:idx]
                        start = idx
                        replaced_sentence += self.span_to_special_tag(span)
                    matching_flag = False
                    number_flag = False
                    alpha_flag = False
                    link_flag = False
                    slash_flag = False
                    bracket_flag = False

        replaced_sentence += sentence[start:]
        return replaced_sentence

    def find_certain_span(self, sentence):
        spans = []
        start = 0
        matching_flag = False
        number_flag = False
        alpha_flag = False
        link_flag = False
        slash_flag = False
        bracket_flag = False
        for idx in range(len(sentence)):
            if re.match(r'[0-9a-zA-Z/\(\)\'′&\-]', sentence[idx]):
                if not matching_flag:
                    start = idx
                if re.match('[0-9]', sentence[idx]):
                    number_flag = True
                elif re.match(r'[\'′&\-]', sentence[idx]):
                    link_flag = True
                elif re.match('/', sentence[idx]):
                    slash_flag = True
                elif re.match(r'[\(\)]', sentence[idx]):
                    bracket_flag = True
                else:
                    alpha_flag = True
                matching_flag = True
            elif re.match(r'[\.]', sentence[idx]):
                pass
            else:
                if matching_flag:
                    if (number_flag and alpha_flag) or (link_flag and alpha_flag) \
                            or (slash_flag and alpha_flag) or (link_flag and number_flag) \
                            or (number_flag and bracket_flag) or (bracket_flag and alpha_flag):
                        spans.append((start, idx))
                        start = idx
                    matching_flag = False
                    number_flag = False
                    alpha_flag = False
                    link_flag = False
                    slash_flag = False
                    bracket_flag = False

        return spans


class EmailConverter(SpanConverter):
    def __init__(self):
        replace_tag = '<EML>'
        pattern = r'[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\-"$])'

        super(EmailConverter, self).__init__(replace_tag, pattern)
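For context, here is a minimal usage sketch of the converters above. The sample sentence and the printed result are illustrative only, and the sketch assumes the new module is importable as fastNLP.api.converter (the commit message says the file lands under api):

# Hypothetical usage of the new converters; the example sentence is made up.
from fastNLP.api.converter import AlphaSpanConverter, DigitSpanConverter

sentence = '今天ABC股价上涨3.5%'
for conv in (AlphaSpanConverter(), DigitSpanConverter()):
    sentence = conv.find_certain_span_and_replace(sentence)
print(sentence)  # expected: '今天<ALPHA>股价上涨<DEC>%'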
@@ -1,17 +1,25 @@
 from fastNLP.api.processor import Processor


 class Pipeline:
+    """
+    Pipeline takes a DataSet object as input, runs multiple processors sequentially, and
+    outputs a DataSet object.
+    """

-    def __init__(self):
+    def __init__(self, processors=None):
         self.pipeline = []
+        if isinstance(processors, list):
+            for proc in processors:
+                assert isinstance(proc, Processor), "Must be a Processor, not {}.".format(type(proc))
+            self.pipeline = processors

     def add_processor(self, processor):
         assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor))
         self.pipeline.append(processor)

     def process(self, dataset):
-        assert len(self.pipeline)!=0, "You need to add some processor first."
+        assert len(self.pipeline) != 0, "You need to add some processor first."
         for proc in self.pipeline:  # pipeline stores bare Processors, so no (name, proc) unpacking
             dataset = proc(dataset)
@@ -19,4 +27,4 @@ class Pipeline:
         return dataset

     def __call__(self, *args, **kwargs):
         return self.process(*args, **kwargs)
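A short sketch of the new one-shot initialization, using processors defined in this commit's processor changes (the variable names and field names are illustrative):

from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import VocabProcessor, SeqLenProcessor

vocab_proc = VocabProcessor('words')
seq_len_proc = SeqLenProcessor('word_seq', 'word_seq_origin_len')

# equivalent to creating an empty Pipeline and calling add_processor() twice
pp = Pipeline([vocab_proc, seq_len_proc])
# dataset = pp(dataset)  # __call__ forwards to process(), running processors in order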
@@ -1,44 +0,0 @@
import pickle

import numpy as np

from fastNLP.core.dataset import DataSet
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.predictor import Predictor


class POS_tagger:
    def __init__(self):
        pass

    def predict(self, query):
        """
        :param query: List[str]
        :return answer: List[str]
        """
        # TODO: build a DataSet from the query
        pos_dataset = DataSet()
        pos_dataset["text_field"] = np.array(query)

        # load the pipeline and the model
        pipeline = self.load_pipeline("./xxxx")

        # run the pipeline on the DataSet
        pos_dataset = pipeline(pos_dataset)

        # load the model
        model = ModelLoader().load_pytorch("./xxx")

        # call the predictor
        predictor = Predictor()
        output = predictor.predict(model, pos_dataset)

        # TODO: convert to the final output
        return None

    @staticmethod
    def load_pipeline(path):
        with open(path, "rb") as fp:  # pickle requires binary mode
            pipeline = pickle.load(fp)
        return pipeline
@@ -1,7 +1,7 @@
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.vocabulary import Vocabulary


 class Processor:
     def __init__(self, field_name, new_added_field_name):
         self.field_name = field_name
@@ -10,15 +10,18 @@ class Processor:
         else:
             self.new_added_field_name = new_added_field_name

-    def process(self):
+    def process(self, *args, **kwargs):
         pass

     def __call__(self, *args, **kwargs):
         return self.process(*args, **kwargs)


 class FullSpaceToHalfSpaceProcessor(Processor):
+    """Convert full-width characters to their half-width forms, one character at a time."""
+
     def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
                  change_space=True):
         super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)
@@ -64,11 +67,12 @@ class FullSpaceToHalfSpaceProcessor(Processor):
         if self.change_space:
             FHs += FH_SPACE
         self.convert_map = {k: v for k, v in FHs}

     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
         for ins in dataset:
             sentence = ins[self.field_name]
-            new_sentence = [None]*len(sentence)
+            new_sentence = [None] * len(sentence)
             for idx, char in enumerate(sentence):
                 if char in self.convert_map:
                     char = self.convert_map[char]
@@ -98,7 +102,7 @@ class IndexerProcessor(Processor):
             index = [self.vocab.to_index(token) for token in tokens]
             ins[self.new_added_field_name] = index

-        dataset.set_need_tensor(**{self.new_added_field_name:True})
+        dataset.set_need_tensor(**{self.new_added_field_name: True})

         if self.delete_old_field:
             dataset.delete_field(self.field_name)
@@ -122,3 +126,16 @@ class VocabProcessor(Processor):
     def get_vocab(self):
         self.vocab.build_vocab()
         return self.vocab


+class SeqLenProcessor(Processor):
+    def __init__(self, field_name, new_added_field_name='seq_lens'):
+        super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
+
+    def process(self, dataset):
+        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        for ins in dataset:
+            length = len(ins[self.field_name])
+            ins[self.new_added_field_name] = length
+        dataset.set_need_tensor(**{self.new_added_field_name: True})
+        return dataset
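To make the contract of the new SeqLenProcessor concrete, a small sketch, assuming a DataSet whose 'word_seq' field already holds token index lists (as IndexerProcessor produces above); the sample values are made up:

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

ins = Instance()
ins['word_seq'] = [4, 8, 15, 16]

dataset = DataSet()
dataset.append(ins)

proc = SeqLenProcessor('word_seq')  # writes into 'seq_lens' by default
dataset = proc(dataset)
# every instance now has a 'seq_lens' field holding len(ins['word_seq']),
# and 'seq_lens' is flagged via set_need_tensor for batching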
@@ -1,5 +1,3 @@
-from collections import defaultdict
-
 import torch
@@ -68,4 +66,3 @@ class Batch(object):
         self.curidx = endidx

         return batch_x, batch_y
@@ -1,23 +1,27 @@
-import random
 import sys, os

-sys.path.append('../..')
+sys.path = [os.path.join(os.path.dirname(__file__), '../..')] + sys.path

 from collections import defaultdict
 from copy import deepcopy
 import numpy as np

 from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
 from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.core.fieldarray import FieldArray

 _READERS = {}


+def construct_dataset(sentences):
+    """Construct a data set from a list of sentences.
+
+    :param sentences: list of str
+    :return dataset: a DataSet object
+    """
+    dataset = DataSet()
+    for sentence in sentences:
+        instance = Instance()
+        instance['raw_sentence'] = sentence
+        dataset.append(instance)
+    return dataset


 class DataSet(object):
     """A DataSet object is a list of Instance objects.
     """

     class DataSetIter(object):
         def __init__(self, dataset):
             self.dataset = dataset
@@ -34,13 +38,12 @@ class DataSet(object):
         def __setitem__(self, name, val):
             if name not in self.dataset:
-                new_fields = [None]*len(self.dataset)
+                new_fields = [None] * len(self.dataset)
                 self.dataset.add_field(name, new_fields)
             self.dataset[name][self.idx] = val

         def __repr__(self):
-            # TODO
-            pass
+            return " ".join([repr(self.dataset[name][self.idx]) for name in self.dataset])

     def __init__(self, instance=None):
         self.field_arrays = {}
@@ -72,7 +75,7 @@ class DataSet(object):
         self.field_arrays[name].append(field)

     def add_field(self, name, fields):
-        if len(self.field_arrays)!=0:
+        if len(self.field_arrays) != 0:
             assert len(self) == len(fields)
         self.field_arrays[name] = FieldArray(name, fields)
@@ -90,27 +93,10 @@ class DataSet(object):
         return len(field)

     def get_length(self):
-        """Fetch lengths of all fields in all instances in a dataset.
-
-        :return lengths: dict of (str: list). The str is the field name.
-            The list contains lengths of this field in all instances.
-        """
-        pass
-
-    def shuffle(self):
-        pass
-
-    def split(self, ratio, shuffle=True):
-        """Train/dev splitting
-
-        :param ratio: float, between 0 and 1. The ratio of development set in origin data set.
-        :param shuffle: bool, whether shuffle the data set before splitting. Default: True.
-        :return train_set: a DataSet object, representing the training set
-                dev_set: a DataSet object, representing the validation set
-        """
-        pass
+        """The same as __len__
+        """
+        return len(self)

     def rename_field(self, old_name, new_name):
         """rename a field
@@ -118,7 +104,7 @@ class DataSet(object):
         if old_name in self.field_arrays:
             self.field_arrays[new_name] = self.field_arrays.pop(old_name)
         else:
-            raise KeyError
+            raise KeyError("{} is not a valid name. ".format(old_name))
         return self

     def set_is_target(self, **fields):
@@ -150,6 +136,7 @@ class DataSet(object):
                 data = _READERS[name]().load(*args, **kwargs)
                 self.extend(data)
                 return self

             return _read
         else:
             return object.__getattribute__(self, name)
@@ -159,18 +146,21 @@ class DataSet(object):
         """decorator to add dataloader support
         """
         assert isinstance(method_name, str)

         def wrapper(read_cls):
             _READERS[method_name] = read_cls
             return read_cls

         return wrapper


 if __name__ == '__main__':
     from fastNLP.core.instance import Instance

     ins = Instance(test='test0')
     dataset = DataSet([ins])

     for _iter in dataset:
         print(_iter['test'])
         _iter['test'] = 'abc'
         print(_iter['test'])
     print(dataset.field_arrays)
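The newly added construct_dataset helper wraps raw sentences into a DataSet; a quick sketch (the sample sentences are made up):

from fastNLP.core.dataset import construct_dataset

dataset = construct_dataset(["第一句话。", "第二句话。"])
print(len(dataset))  # 2 -- one Instance per sentence, stored in the 'raw_sentence' field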
@@ -1,4 +1,4 @@
-import torch

 class Instance(object):
     """An instance which consists of Fields is an example in the DataSet.
@@ -35,4 +35,4 @@ class Instance(object):
         return self.add_field(name, field)

     def __repr__(self):
         return self.fields.__repr__()
@@ -1,9 +1,9 @@
 import os

-from fastNLP.loader.base_loader import BaseLoader
 from fastNLP.core.dataset import DataSet
-from fastNLP.core.instance import Instance
 from fastNLP.core.field import *
+from fastNLP.core.instance import Instance
+from fastNLP.loader.base_loader import BaseLoader


 def convert_seq_dataset(data):
@@ -393,6 +393,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
                     sent_words.append(token)
             pos_tag_examples.append([sent_words, sent_pos_tag])
             ner_examples.append([sent_words, sent_ner])
+        # List[List[List[str], List[str]]]
         return pos_tag_examples, ner_examples

     def convert(self, data):
@@ -44,6 +44,9 @@ class SeqLabeling(BaseModel):
         :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
                    If truth is not None, return loss, a scalar. Used in training.
         """
+        assert word_seq.shape[0] == word_seq_origin_len.shape[0]
+        if truth is not None:
+            assert truth.shape == word_seq.shape
         self.mask = self.make_mask(word_seq, word_seq_origin_len)

         x = self.Embedding(word_seq)
@@ -80,7 +83,7 @@ class SeqLabeling(BaseModel):
         batch_size, max_len = x.size(0), x.size(1)
         mask = seq_mask(seq_len, max_len)
         mask = mask.byte().view(batch_size, max_len)
-        mask = mask.to(x)
+        mask = mask.to(x).float()
         return mask

     def decode(self, x, pad=True):
@@ -130,6 +133,9 @@ class AdvSeqLabel(SeqLabeling):
         :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
                    If truth is not None, return loss, a scalar. Used in training.
         """
+        word_seq = word_seq.long()
+        word_seq_origin_len = word_seq_origin_len.long()
+        truth = truth.long() if truth is not None else None  # guard added: the docstring allows truth=None at inference
         self.mask = self.make_mask(word_seq, word_seq_origin_len)

         batch_size = word_seq.size(0)
@@ -3,6 +3,7 @@ from torch import nn

 from fastNLP.modules.utils import initial_parameter


 def log_sum_exp(x, dim=-1):
     max_value, _ = x.max(dim=dim, keepdim=True)
     res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value
@@ -20,7 +21,7 @@ def seq_len_to_byte_mask(seq_lens):


 class ConditionalRandomField(nn.Module):
-    def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None):
+    def __init__(self, tag_size, include_start_end_trans=False, initial_method=None):
         """
         :param tag_size: int, num of tags
         :param include_start_end_trans: bool, whether to include start/end tag
@@ -38,6 +39,7 @@ class ConditionalRandomField(nn.Module):
         # self.reset_parameter()
         initial_parameter(self, initial_method)

     def reset_parameter(self):
         nn.init.xavier_normal_(self.trans_m)
         if self.include_start_end_trans:
@@ -81,15 +83,15 @@ class ConditionalRandomField(nn.Module):
         seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

         # trans_score [L-1, B]
-        trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :]
+        trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]] * mask[1:, :]
         # emit_score [L, B]
-        emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask
+        emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags] * mask
         # score [L-1, B]
-        score = trans_score + emit_score[:seq_len-1, :]
+        score = trans_score + emit_score[:seq_len - 1, :]
         score = score.sum(0) + emit_score[-1]
         if self.include_start_end_trans:
             st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
-            last_idx = masks.long().sum(0)
+            last_idx = mask.long().sum(0)
             ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
             score += st_scores + ed_scores
         # return [B,]
@@ -120,14 +122,14 @@ class ConditionalRandomField(nn.Module):
         :return: scores, paths
         """
         batch_size, seq_len, n_tags = data.size()
-        data = data.transpose(0, 1).data    # L, B, H
-        mask = mask.transpose(0, 1).data.float()    # L, B
+        data = data.transpose(0, 1).data  # L, B, H
+        mask = mask.transpose(0, 1).data.float()  # L, B

         # dp
         vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
         vscore = data[0]
         if self.include_start_end_trans:
-            vscore += self.start_scores.view(1. -1)
+            vscore += self.start_scores.view(1, -1)  # fixes the '1. -1' typo; the intent is a (1, n_tags) reshape
         for i in range(1, seq_len):
             prev_score = vscore.view(batch_size, n_tags, 1)
             cur_score = data[i].view(batch_size, 1, n_tags)
@@ -145,15 +147,15 @@ class ConditionalRandomField(nn.Module):
         seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device)
         lens = (mask.long().sum(0) - 1)
         # idxes [L, B], batched idx from seq_len-1 to 0
-        idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len
+        idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len

         ans = data.new_empty((seq_len, batch_size), dtype=torch.long)
         ans_score, last_tags = vscore.max(1)
         ans[idxes[0], batch_idx] = last_tags
         for i in range(seq_len - 1):
             last_tags = vpath[idxes[i], batch_idx, last_tags]
-            ans[idxes[i+1], batch_idx] = last_tags
+            ans[idxes[i + 1], batch_idx] = last_tags
         if get_score:
             return ans_score, ans.transpose(0, 1)
         return ans.transpose(0, 1)
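As an aside, the log_sum_exp helper at the top of this module uses the standard max-shift trick so exp() can never overflow. A quick numeric sanity check against torch.logsumexp; this snippet is an added illustration, not part of the commit:

import torch

x = torch.tensor([[1000.0, 1001.0, 999.0]])

# max-shift trick: log(sum(exp(x))) == max + log(sum(exp(x - max)))
max_value, _ = x.max(dim=-1, keepdim=True)
stable = torch.log(torch.sum(torch.exp(x - max_value), dim=-1, keepdim=True)) + max_value

print(stable)                               # tensor([[1001.4076]])
print(torch.logsumexp(x, dim=-1))           # tensor([1001.4076]) -- built-in reference
print(torch.log(torch.exp(x).sum(dim=-1)))  # tensor([inf]) -- the naive version overflows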
@@ -1,10 +1,12 @@
 [train]
-epochs = 30
-batch_size = 64
+epochs = 5
+batch_size = 2
 pickle_path = "./save/"
-validate = true
+validate = false
 save_best_dev = true
 model_saved_path = "./save/"

+[model]
 rnn_hidden_units = 100
 word_emb_dim = 100
 use_crf = true
@@ -1,130 +1,88 @@
 import os
+import sys
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
+
+import torch
+
+from fastNLP.api.pipeline import Pipeline
+from fastNLP.api.processor import VocabProcessor, IndexerProcessor, SeqLenProcessor
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
+from fastNLP.core.trainer import Trainer
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
-from fastNLP.core.trainer import SeqLabelTrainer
-from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader, BaseLoader
-from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
-from fastNLP.saver.model_saver import ModelSaver
-from fastNLP.loader.model_loader import ModelLoader
-from fastNLP.core.tester import SeqLabelTester
+from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader
 from fastNLP.models.sequence_modeling import AdvSeqLabel
-from fastNLP.core.predictor import SeqLabelInfer

+# chdir to the script's directory if it is run from somewhere else
+if len(os.path.dirname(__file__)) != 0:
+    os.chdir(os.path.dirname(__file__))

-datadir = "/home/zyfeng/data/"
 cfgfile = './pos_tag.cfg'
-data_name = "CWS_POS_TAG_NER_people_daily.txt"
+datadir = "/home/zyfeng/fastnlp_0.2.0/test/data_for_tests/"
+data_name = "people_daily_raw.txt"

 pos_tag_data_path = os.path.join(datadir, data_name)
 pickle_path = "save"
 data_infer_path = os.path.join(datadir, "infer.utf8")

-def infer():
-    # Config Loader
-    test_args = ConfigSection()
-    ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args})
-
-    # fetch dictionary size and number of labels from pickle files
-    word2index = load_pickle(pickle_path, "word2id.pkl")
-    test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "class2id.pkl")
-    test_args["num_classes"] = len(index2label)
-
-    # Define the same model
-    model = AdvSeqLabel(test_args)
-
-    try:
-        ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
-        print('model loaded!')
-    except Exception as e:
-        print('cannot load model!')
-        raise
-
-    # Data Loader
-    raw_data_loader = BaseLoader(data_infer_path)
-    infer_data = raw_data_loader.load_lines()
-    print('data loaded')
-
-    # Inference interface
-    infer = SeqLabelInfer(pickle_path)
-    results = infer.predict(model, infer_data)
-
-    print(results)
-    print("Inference finished!")


 def train():
     # load config
-    trainer_args = ConfigSection()
-    model_args = ConfigSection()
-    ConfigLoader().load_config(cfgfile, {"train": train_args, "test": test_args})
+    train_param = ConfigSection()
+    model_param = ConfigSection()
+    ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param})
+    print("config loaded")

     # Data Loader
     loader = PeopleDailyCorpusLoader()
-    train_data, _ = loader.load()
-
-    # TODO: define processors
-
-    # define pipeline
-    pp = Pipeline()
-    # TODO: pp.add_processor()
-
-    # run the pipeline, get data_set
-    train_data = pp(train_data)
+    train_data, _ = loader.load(os.path.join(datadir, data_name))
+    print("data loaded")
+
+    dataset = DataSet()
+    for data in train_data:
+        instance = Instance()
+        instance["words"] = data[0]
+        instance["tag"] = data[1]
+        dataset.append(instance)
+    print("dataset transformed")
+
+    # processor_1 = FullSpaceToHalfSpaceProcessor('words')
+    # processor_1(dataset)
+
+    word_vocab_proc = VocabProcessor('words')
+    tag_vocab_proc = VocabProcessor("tag")
+    word_vocab_proc(dataset)
+    tag_vocab_proc(dataset)
+
+    word_indexer = IndexerProcessor(word_vocab_proc.get_vocab(), 'words', 'word_seq', delete_old_field=True)
+    word_indexer(dataset)
+    tag_indexer = IndexerProcessor(tag_vocab_proc.get_vocab(), 'tag', 'truth', delete_old_field=True)
+    tag_indexer(dataset)
+    seq_len_proc = SeqLenProcessor("word_seq", "word_seq_origin_len")
+    seq_len_proc(dataset)
+    print("processors defined")
+
+    # dataset.set_is_target(tag_ids=True)
+    model_param["vocab_size"] = len(word_vocab_proc.get_vocab())
+    model_param["num_classes"] = len(tag_vocab_proc.get_vocab())
+    print("vocab_size={} num_classes={}".format(len(word_vocab_proc.get_vocab()), len(tag_vocab_proc.get_vocab())))

     # define a model
-    model = AdvSeqLabel(train_args)
+    model = AdvSeqLabel(model_param)

     # call trainer to train
-    trainer = SeqLabelTrainer(train_args)
-    trainer.train(model, data_train, data_dev)
-
-    # save model
-    ModelSaver("./saved_model.pkl").save_pytorch(model, param_only=False)
-
-    # TODO: save pipeline
+    trainer = Trainer(**train_param.data)
+    trainer.train(model, dataset)

+    # save model & pipeline
+    pp = Pipeline([word_vocab_proc, word_indexer, seq_len_proc])
+    save_dict = {"pipeline": pp, "model": model}
+    torch.save(save_dict, "model_pp.pkl")


 def test():
-    # Config Loader
-    test_args = ConfigSection()
-    ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args})
-
-    # fetch dictionary size and number of labels from pickle files
-    word2index = load_pickle(pickle_path, "word2id.pkl")
-    test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "class2id.pkl")
-    test_args["num_classes"] = len(index2label)
-
-    # load dev data
-    dev_data = load_pickle(pickle_path, "data_dev.pkl")
-
-    # Define the same model
-    model = AdvSeqLabel(test_args)
-
-    # Dump trained parameters into the model
-    ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
-    print("model loaded!")
-
-    # Tester
-    tester = SeqLabelTester(**test_args.data)
-
-    # Start testing
-    tester.test(model, dev_data)
-
-    # print test results
-    print(tester.show_metrics())
-    print("model tested!")
+    pass
+
+
+def infer():
+    pass


 if __name__ == "__main__":
+    train()
+
+    """
     import argparse

     parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
@@ -139,3 +97,5 @@ if __name__ == "__main__":
     else:
         print('no mode specified for model!')
         parser.print_help()
+    """