- 在api中创建converter.py - Pipeline添加初始化方法,方便一次性添加processors - 删除pos_tagger.py - 优化整体code styletags/v0.2.0
@@ -0,0 +1,182 @@ | |||
import re | |||
class SpanConverter: | |||
def __init__(self, replace_tag, pattern): | |||
super(SpanConverter, self).__init__() | |||
self.replace_tag = replace_tag | |||
self.pattern = pattern | |||
def find_certain_span_and_replace(self, sentence): | |||
replaced_sentence = '' | |||
prev_end = 0 | |||
for match in re.finditer(self.pattern, sentence): | |||
start, end = match.span() | |||
span = sentence[start:end] | |||
replaced_sentence += sentence[prev_end:start] + \ | |||
self.span_to_special_tag(span) | |||
prev_end = end | |||
replaced_sentence += sentence[prev_end:] | |||
return replaced_sentence | |||
def span_to_special_tag(self, span): | |||
return self.replace_tag | |||
def find_certain_span(self, sentence): | |||
spans = [] | |||
for match in re.finditer(self.pattern, sentence): | |||
spans.append(match.span()) | |||
return spans | |||
class AlphaSpanConverter(SpanConverter): | |||
def __init__(self): | |||
replace_tag = '<ALPHA>' | |||
# 理想状态下仅处理纯为字母的情况, 但不处理<[a-zA-Z]+>(因为这应该是特殊的tag). | |||
pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%.!<\\-"])' | |||
super(AlphaSpanConverter, self).__init__(replace_tag, pattern) | |||
class DigitSpanConverter(SpanConverter): | |||
def __init__(self): | |||
replace_tag = '<NUM>' | |||
pattern = '\d[\d\\.]*(?=[\u4e00-\u9fff ,%.!<-])' | |||
super(DigitSpanConverter, self).__init__(replace_tag, pattern) | |||
def span_to_special_tag(self, span): | |||
# return self.special_tag | |||
if span[0] == '0' and len(span) > 2: | |||
return '<NUM>' | |||
decimal_point_count = 0 # one might have more than one decimal pointers | |||
for idx, char in enumerate(span): | |||
if char == '.' or char == '﹒' or char == '·': | |||
decimal_point_count += 1 | |||
if span[-1] == '.' or span[-1] == '﹒' or span[ | |||
-1] == '·': # last digit being decimal point means this is not a number | |||
if decimal_point_count == 1: | |||
return span | |||
else: | |||
return '<UNKDGT>' | |||
if decimal_point_count == 1: | |||
return '<DEC>' | |||
elif decimal_point_count > 1: | |||
return '<UNKDGT>' | |||
else: | |||
return '<NUM>' | |||
class TimeConverter(SpanConverter): | |||
def __init__(self): | |||
replace_tag = '<TOC>' | |||
pattern = '\d+[::∶][\d::∶]+(?=[\u4e00-\u9fff ,%.!<-])' | |||
super().__init__(replace_tag, pattern) | |||
class MixNumAlphaConverter(SpanConverter): | |||
def __init__(self): | |||
replace_tag = '<MIX>' | |||
pattern = None | |||
super().__init__(replace_tag, pattern) | |||
def find_certain_span_and_replace(self, sentence): | |||
replaced_sentence = '' | |||
start = 0 | |||
matching_flag = False | |||
number_flag = False | |||
alpha_flag = False | |||
link_flag = False | |||
slash_flag = False | |||
bracket_flag = False | |||
for idx in range(len(sentence)): | |||
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): | |||
if not matching_flag: | |||
replaced_sentence += sentence[start:idx] | |||
start = idx | |||
if re.match('[0-9]', sentence[idx]): | |||
number_flag = True | |||
elif re.match('[\'′&\\-]', sentence[idx]): | |||
link_flag = True | |||
elif re.match('/', sentence[idx]): | |||
slash_flag = True | |||
elif re.match('[\\(\\)]', sentence[idx]): | |||
bracket_flag = True | |||
else: | |||
alpha_flag = True | |||
matching_flag = True | |||
elif re.match('[\\.]', sentence[idx]): | |||
pass | |||
else: | |||
if matching_flag: | |||
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ | |||
or (slash_flag and alpha_flag) or (link_flag and number_flag) \ | |||
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): | |||
span = sentence[start:idx] | |||
start = idx | |||
replaced_sentence += self.span_to_special_tag(span) | |||
matching_flag = False | |||
number_flag = False | |||
alpha_flag = False | |||
link_flag = False | |||
slash_flag = False | |||
bracket_flag = False | |||
replaced_sentence += sentence[start:] | |||
return replaced_sentence | |||
def find_certain_span(self, sentence): | |||
spans = [] | |||
start = 0 | |||
matching_flag = False | |||
number_flag = False | |||
alpha_flag = False | |||
link_flag = False | |||
slash_flag = False | |||
bracket_flag = False | |||
for idx in range(len(sentence)): | |||
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): | |||
if not matching_flag: | |||
start = idx | |||
if re.match('[0-9]', sentence[idx]): | |||
number_flag = True | |||
elif re.match('[\'′&\\-]', sentence[idx]): | |||
link_flag = True | |||
elif re.match('/', sentence[idx]): | |||
slash_flag = True | |||
elif re.match('[\\(\\)]', sentence[idx]): | |||
bracket_flag = True | |||
else: | |||
alpha_flag = True | |||
matching_flag = True | |||
elif re.match('[\\.]', sentence[idx]): | |||
pass | |||
else: | |||
if matching_flag: | |||
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ | |||
or (slash_flag and alpha_flag) or (link_flag and number_flag) \ | |||
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): | |||
spans.append((start, idx)) | |||
start = idx | |||
matching_flag = False | |||
number_flag = False | |||
alpha_flag = False | |||
link_flag = False | |||
slash_flag = False | |||
bracket_flag = False | |||
return spans | |||
class EmailConverter(SpanConverter): | |||
def __init__(self): | |||
replaced_tag = "<EML>" | |||
pattern = '[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\\-"$])' | |||
super(EmailConverter, self).__init__(replaced_tag, pattern) |
@@ -1,17 +1,25 @@ | |||
from fastNLP.api.processor import Processor | |||
class Pipeline: | |||
def __init__(self): | |||
""" | |||
Pipeline takes a DataSet object as input, runs multiple processors sequentially, and | |||
outputs a DataSet object. | |||
""" | |||
def __init__(self, processors=None): | |||
self.pipeline = [] | |||
if isinstance(processors, list): | |||
for proc in processors: | |||
assert isinstance(proc, Processor), "Must be a Processor, not {}.".format(type(processor)) | |||
self.pipeline = processors | |||
def add_processor(self, processor): | |||
assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor)) | |||
self.pipeline.append(processor) | |||
def process(self, dataset): | |||
assert len(self.pipeline)!=0, "You need to add some processor first." | |||
assert len(self.pipeline) != 0, "You need to add some processor first." | |||
for proc_name, proc in self.pipeline: | |||
dataset = proc(dataset) | |||
@@ -19,4 +27,4 @@ class Pipeline: | |||
return dataset | |||
def __call__(self, *args, **kwargs): | |||
return self.process(*args, **kwargs) | |||
return self.process(*args, **kwargs) |
@@ -1,44 +0,0 @@ | |||
import pickle | |||
import numpy as np | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.loader.model_loader import ModelLoader | |||
from fastNLP.core.predictor import Predictor | |||
class POS_tagger: | |||
def __init__(self): | |||
pass | |||
def predict(self, query): | |||
""" | |||
:param query: List[str] | |||
:return answer: List[str] | |||
""" | |||
# TODO: 根据query 构建DataSet | |||
pos_dataset = DataSet() | |||
pos_dataset["text_field"] = np.array(query) | |||
# 加载pipeline和model | |||
pipeline = self.load_pipeline("./xxxx") | |||
# 将DataSet作为参数运行 pipeline | |||
pos_dataset = pipeline(pos_dataset) | |||
# 加载模型 | |||
model = ModelLoader().load_pytorch("./xxx") | |||
# 调 predictor | |||
predictor = Predictor() | |||
output = predictor.predict(model, pos_dataset) | |||
# TODO: 转成最终输出 | |||
return None | |||
@staticmethod | |||
def load_pipeline(path): | |||
with open(path, "r") as fp: | |||
pipeline = pickle.load(fp) | |||
return pipeline |
@@ -1,7 +1,7 @@ | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.vocabulary import Vocabulary | |||
class Processor: | |||
def __init__(self, field_name, new_added_field_name): | |||
self.field_name = field_name | |||
@@ -10,15 +10,18 @@ class Processor: | |||
else: | |||
self.new_added_field_name = new_added_field_name | |||
def process(self): | |||
def process(self, *args, **kwargs): | |||
pass | |||
def __call__(self, *args, **kwargs): | |||
return self.process(*args, **kwargs) | |||
class FullSpaceToHalfSpaceProcessor(Processor): | |||
"""全角转半角,以字符为处理单元 | |||
""" | |||
def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True, | |||
change_space=True): | |||
super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None) | |||
@@ -64,11 +67,12 @@ class FullSpaceToHalfSpaceProcessor(Processor): | |||
if self.change_space: | |||
FHs += FH_SPACE | |||
self.convert_map = {k: v for k, v in FHs} | |||
def process(self, dataset): | |||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
for ins in dataset: | |||
sentence = ins[self.field_name] | |||
new_sentence = [None]*len(sentence) | |||
new_sentence = [None] * len(sentence) | |||
for idx, char in enumerate(sentence): | |||
if char in self.convert_map: | |||
char = self.convert_map[char] | |||
@@ -98,7 +102,7 @@ class IndexerProcessor(Processor): | |||
index = [self.vocab.to_index(token) for token in tokens] | |||
ins[self.new_added_field_name] = index | |||
dataset.set_need_tensor(**{self.new_added_field_name:True}) | |||
dataset.set_need_tensor(**{self.new_added_field_name: True}) | |||
if self.delete_old_field: | |||
dataset.delete_field(self.field_name) | |||
@@ -122,3 +126,16 @@ class VocabProcessor(Processor): | |||
def get_vocab(self): | |||
self.vocab.build_vocab() | |||
return self.vocab | |||
class SeqLenProcessor(Processor): | |||
def __init__(self, field_name, new_added_field_name='seq_lens'): | |||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | |||
def process(self, dataset): | |||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
for ins in dataset: | |||
length = len(ins[self.field_name]) | |||
ins[self.new_added_field_name] = length | |||
dataset.set_need_tensor(**{self.new_added_field_name: True}) | |||
return dataset |
@@ -1,5 +1,3 @@ | |||
from collections import defaultdict | |||
import torch | |||
@@ -68,4 +66,3 @@ class Batch(object): | |||
self.curidx = endidx | |||
return batch_x, batch_y | |||
@@ -1,23 +1,27 @@ | |||
import random | |||
import sys, os | |||
sys.path.append('../..') | |||
sys.path = [os.path.join(os.path.dirname(__file__), '../..')] + sys.path | |||
from collections import defaultdict | |||
from copy import deepcopy | |||
import numpy as np | |||
from fastNLP.core.field import TextField, LabelField | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.core.fieldarray import FieldArray | |||
_READERS = {} | |||
def construct_dataset(sentences): | |||
"""Construct a data set from a list of sentences. | |||
:param sentences: list of str | |||
:return dataset: a DataSet object | |||
""" | |||
dataset = DataSet() | |||
for sentence in sentences: | |||
instance = Instance() | |||
instance['raw_sentence'] = sentence | |||
dataset.append(instance) | |||
return dataset | |||
class DataSet(object): | |||
"""A DataSet object is a list of Instance objects. | |||
""" | |||
class DataSetIter(object): | |||
def __init__(self, dataset): | |||
self.dataset = dataset | |||
@@ -34,13 +38,12 @@ class DataSet(object): | |||
def __setitem__(self, name, val): | |||
if name not in self.dataset: | |||
new_fields = [None]*len(self.dataset) | |||
new_fields = [None] * len(self.dataset) | |||
self.dataset.add_field(name, new_fields) | |||
self.dataset[name][self.idx] = val | |||
def __repr__(self): | |||
# TODO | |||
pass | |||
return " ".join([repr(self.dataset[name][self.idx]) for name in self.dataset]) | |||
def __init__(self, instance=None): | |||
self.field_arrays = {} | |||
@@ -72,7 +75,7 @@ class DataSet(object): | |||
self.field_arrays[name].append(field) | |||
def add_field(self, name, fields): | |||
if len(self.field_arrays)!=0: | |||
if len(self.field_arrays) != 0: | |||
assert len(self) == len(fields) | |||
self.field_arrays[name] = FieldArray(name, fields) | |||
@@ -90,27 +93,10 @@ class DataSet(object): | |||
return len(field) | |||
def get_length(self): | |||
"""Fetch lengths of all fields in all instances in a dataset. | |||
:return lengths: dict of (str: list). The str is the field name. | |||
The list contains lengths of this field in all instances. | |||
""" | |||
pass | |||
def shuffle(self): | |||
pass | |||
def split(self, ratio, shuffle=True): | |||
"""Train/dev splitting | |||
:param ratio: float, between 0 and 1. The ratio of development set in origin data set. | |||
:param shuffle: bool, whether shuffle the data set before splitting. Default: True. | |||
:return train_set: a DataSet object, representing the training set | |||
dev_set: a DataSet object, representing the validation set | |||
"""The same as __len__ | |||
""" | |||
pass | |||
return len(self) | |||
def rename_field(self, old_name, new_name): | |||
"""rename a field | |||
@@ -118,7 +104,7 @@ class DataSet(object): | |||
if old_name in self.field_arrays: | |||
self.field_arrays[new_name] = self.field_arrays.pop(old_name) | |||
else: | |||
raise KeyError | |||
raise KeyError("{} is not a valid name. ".format(old_name)) | |||
return self | |||
def set_is_target(self, **fields): | |||
@@ -150,6 +136,7 @@ class DataSet(object): | |||
data = _READERS[name]().load(*args, **kwargs) | |||
self.extend(data) | |||
return self | |||
return _read | |||
else: | |||
return object.__getattribute__(self, name) | |||
@@ -159,18 +146,21 @@ class DataSet(object): | |||
"""decorator to add dataloader support | |||
""" | |||
assert isinstance(method_name, str) | |||
def wrapper(read_cls): | |||
_READERS[method_name] = read_cls | |||
return read_cls | |||
return wrapper | |||
if __name__ == '__main__': | |||
from fastNLP.core.instance import Instance | |||
ins = Instance(test='test0') | |||
dataset = DataSet([ins]) | |||
for _iter in dataset: | |||
print(_iter['test']) | |||
_iter['test'] = 'abc' | |||
print(_iter['test']) | |||
print(dataset.field_arrays) | |||
print(dataset.field_arrays) |
@@ -1,4 +1,4 @@ | |||
import torch | |||
class Instance(object): | |||
"""An instance which consists of Fields is an example in the DataSet. | |||
@@ -35,4 +35,4 @@ class Instance(object): | |||
return self.add_field(name, field) | |||
def __repr__(self): | |||
return self.fields.__repr__() | |||
return self.fields.__repr__() |
@@ -1,9 +1,9 @@ | |||
import os | |||
from fastNLP.loader.base_loader import BaseLoader | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.field import * | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.loader.base_loader import BaseLoader | |||
def convert_seq_dataset(data): | |||
@@ -393,6 +393,7 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||
sent_words.append(token) | |||
pos_tag_examples.append([sent_words, sent_pos_tag]) | |||
ner_examples.append([sent_words, sent_ner]) | |||
# List[List[List[str], List[str]]] | |||
return pos_tag_examples, ner_examples | |||
def convert(self, data): | |||
@@ -44,6 +44,9 @@ class SeqLabeling(BaseModel): | |||
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | |||
If truth is not None, return loss, a scalar. Used in training. | |||
""" | |||
assert word_seq.shape[0] == word_seq_origin_len.shape[0] | |||
if truth is not None: | |||
assert truth.shape == word_seq.shape | |||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | |||
x = self.Embedding(word_seq) | |||
@@ -80,7 +83,7 @@ class SeqLabeling(BaseModel): | |||
batch_size, max_len = x.size(0), x.size(1) | |||
mask = seq_mask(seq_len, max_len) | |||
mask = mask.byte().view(batch_size, max_len) | |||
mask = mask.to(x) | |||
mask = mask.to(x).float() | |||
return mask | |||
def decode(self, x, pad=True): | |||
@@ -130,6 +133,9 @@ class AdvSeqLabel(SeqLabeling): | |||
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | |||
If truth is not None, return loss, a scalar. Used in training. | |||
""" | |||
word_seq = word_seq.long() | |||
word_seq_origin_len = word_seq_origin_len.long() | |||
truth = truth.long() | |||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | |||
batch_size = word_seq.size(0) | |||
@@ -3,6 +3,7 @@ from torch import nn | |||
from fastNLP.modules.utils import initial_parameter | |||
def log_sum_exp(x, dim=-1): | |||
max_value, _ = x.max(dim=dim, keepdim=True) | |||
res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value | |||
@@ -20,7 +21,7 @@ def seq_len_to_byte_mask(seq_lens): | |||
class ConditionalRandomField(nn.Module): | |||
def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None): | |||
def __init__(self, tag_size, include_start_end_trans=False, initial_method=None): | |||
""" | |||
:param tag_size: int, num of tags | |||
:param include_start_end_trans: bool, whether to include start/end tag | |||
@@ -38,6 +39,7 @@ class ConditionalRandomField(nn.Module): | |||
# self.reset_parameter() | |||
initial_parameter(self, initial_method) | |||
def reset_parameter(self): | |||
nn.init.xavier_normal_(self.trans_m) | |||
if self.include_start_end_trans: | |||
@@ -81,15 +83,15 @@ class ConditionalRandomField(nn.Module): | |||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | |||
# trans_socre [L-1, B] | |||
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :] | |||
trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]] * mask[1:, :] | |||
# emit_score [L, B] | |||
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask | |||
emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags] * mask | |||
# score [L-1, B] | |||
score = trans_score + emit_score[:seq_len-1, :] | |||
score = trans_score + emit_score[:seq_len - 1, :] | |||
score = score.sum(0) + emit_score[-1] | |||
if self.include_start_end_trans: | |||
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | |||
last_idx = masks.long().sum(0) | |||
last_idx = mask.long().sum(0) | |||
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | |||
score += st_scores + ed_scores | |||
# return [B,] | |||
@@ -120,14 +122,14 @@ class ConditionalRandomField(nn.Module): | |||
:return: scores, paths | |||
""" | |||
batch_size, seq_len, n_tags = data.size() | |||
data = data.transpose(0, 1).data # L, B, H | |||
mask = mask.transpose(0, 1).data.float() # L, B | |||
data = data.transpose(0, 1).data # L, B, H | |||
mask = mask.transpose(0, 1).data.float() # L, B | |||
# dp | |||
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
vscore = data[0] | |||
if self.include_start_end_trans: | |||
vscore += self.start_scores.view(1. -1) | |||
vscore += self.start_scores.view(1. - 1) | |||
for i in range(1, seq_len): | |||
prev_score = vscore.view(batch_size, n_tags, 1) | |||
cur_score = data[i].view(batch_size, 1, n_tags) | |||
@@ -145,15 +147,15 @@ class ConditionalRandomField(nn.Module): | |||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device) | |||
lens = (mask.long().sum(0) - 1) | |||
# idxes [L, B], batched idx from seq_len-1 to 0 | |||
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len | |||
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len | |||
ans = data.new_empty((seq_len, batch_size), dtype=torch.long) | |||
ans_score, last_tags = vscore.max(1) | |||
ans[idxes[0], batch_idx] = last_tags | |||
for i in range(seq_len - 1): | |||
last_tags = vpath[idxes[i], batch_idx, last_tags] | |||
ans[idxes[i+1], batch_idx] = last_tags | |||
ans[idxes[i + 1], batch_idx] = last_tags | |||
if get_score: | |||
return ans_score, ans.transpose(0, 1) | |||
return ans.transpose(0, 1) | |||
return ans.transpose(0, 1) |
@@ -1,10 +1,12 @@ | |||
[train] | |||
epochs = 30 | |||
batch_size = 64 | |||
epochs = 5 | |||
batch_size = 2 | |||
pickle_path = "./save/" | |||
validate = true | |||
validate = false | |||
save_best_dev = true | |||
model_saved_path = "./save/" | |||
[model] | |||
rnn_hidden_units = 100 | |||
word_emb_dim = 100 | |||
use_crf = true | |||
@@ -1,130 +1,88 @@ | |||
import os | |||
import sys | |||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||
import torch | |||
from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.api.processor import VocabProcessor, IndexerProcessor, SeqLenProcessor | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.trainer import Trainer | |||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||
from fastNLP.core.trainer import SeqLabelTrainer | |||
from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader, BaseLoader | |||
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | |||
from fastNLP.saver.model_saver import ModelSaver | |||
from fastNLP.loader.model_loader import ModelLoader | |||
from fastNLP.core.tester import SeqLabelTester | |||
from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader | |||
from fastNLP.models.sequence_modeling import AdvSeqLabel | |||
from fastNLP.core.predictor import SeqLabelInfer | |||
# not in the file's dir | |||
if len(os.path.dirname(__file__)) != 0: | |||
os.chdir(os.path.dirname(__file__)) | |||
datadir = "/home/zyfeng/data/" | |||
cfgfile = './pos_tag.cfg' | |||
data_name = "CWS_POS_TAG_NER_people_daily.txt" | |||
datadir = "/home/zyfeng/fastnlp_0.2.0/test/data_for_tests/" | |||
data_name = "people_daily_raw.txt" | |||
pos_tag_data_path = os.path.join(datadir, data_name) | |||
pickle_path = "save" | |||
data_infer_path = os.path.join(datadir, "infer.utf8") | |||
def infer(): | |||
# Config Loader | |||
test_args = ConfigSection() | |||
ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args}) | |||
# fetch dictionary size and number of labels from pickle files | |||
word2index = load_pickle(pickle_path, "word2id.pkl") | |||
test_args["vocab_size"] = len(word2index) | |||
index2label = load_pickle(pickle_path, "class2id.pkl") | |||
test_args["num_classes"] = len(index2label) | |||
# Define the same model | |||
model = AdvSeqLabel(test_args) | |||
try: | |||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||
print('model loaded!') | |||
except Exception as e: | |||
print('cannot load model!') | |||
raise | |||
# Data Loader | |||
raw_data_loader = BaseLoader(data_infer_path) | |||
infer_data = raw_data_loader.load_lines() | |||
print('data loaded') | |||
# Inference interface | |||
infer = SeqLabelInfer(pickle_path) | |||
results = infer.predict(model, infer_data) | |||
print(results) | |||
print("Inference finished!") | |||
def train(): | |||
def train(): | |||
# load config | |||
trainer_args = ConfigSection() | |||
model_args = ConfigSection() | |||
ConfigLoader().load_config(cfgfile, {"train": train_args, "test": test_args}) | |||
train_param = ConfigSection() | |||
model_param = ConfigSection() | |||
ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param}) | |||
print("config loaded") | |||
# Data Loader | |||
loader = PeopleDailyCorpusLoader() | |||
train_data, _ = loader.load() | |||
# TODO: define processors | |||
# define pipeline | |||
pp = Pipeline() | |||
# TODO: pp.add_processor() | |||
# run the pipeline, get data_set | |||
train_data = pp(train_data) | |||
train_data, _ = loader.load(os.path.join(datadir, data_name)) | |||
print("data loaded") | |||
dataset = DataSet() | |||
for data in train_data: | |||
instance = Instance() | |||
instance["words"] = data[0] | |||
instance["tag"] = data[1] | |||
dataset.append(instance) | |||
print("dataset transformed") | |||
# processor_1 = FullSpaceToHalfSpaceProcessor('words') | |||
# processor_1(dataset) | |||
word_vocab_proc = VocabProcessor('words') | |||
tag_vocab_proc = VocabProcessor("tag") | |||
word_vocab_proc(dataset) | |||
tag_vocab_proc(dataset) | |||
word_indexer = IndexerProcessor(word_vocab_proc.get_vocab(), 'words', 'word_seq', delete_old_field=True) | |||
word_indexer(dataset) | |||
tag_indexer = IndexerProcessor(tag_vocab_proc.get_vocab(), 'tag', 'truth', delete_old_field=True) | |||
tag_indexer(dataset) | |||
seq_len_proc = SeqLenProcessor("word_seq", "word_seq_origin_len") | |||
seq_len_proc(dataset) | |||
print("processors defined") | |||
# dataset.set_is_target(tag_ids=True) | |||
model_param["vocab_size"] = len(word_vocab_proc.get_vocab()) | |||
model_param["num_classes"] = len(tag_vocab_proc.get_vocab()) | |||
print("vocab_size={} num_classes={}".format(len(word_vocab_proc.get_vocab()), len(tag_vocab_proc.get_vocab()))) | |||
# define a model | |||
model = AdvSeqLabel(train_args) | |||
model = AdvSeqLabel(model_param) | |||
# call trainer to train | |||
trainer = SeqLabelTrainer(train_args) | |||
trainer.train(model, data_train, data_dev) | |||
# save model | |||
ModelSaver("./saved_model.pkl").save_pytorch(model, param_only=False) | |||
# TODO:save pipeline | |||
trainer = Trainer(**train_param.data) | |||
trainer.train(model, dataset) | |||
# save model & pipeline | |||
pp = Pipeline([word_vocab_proc, word_indexer, seq_len_proc]) | |||
save_dict = {"pipeline": pp, "model": model} | |||
torch.save(save_dict, "model_pp.pkl") | |||
def test(): | |||
# Config Loader | |||
test_args = ConfigSection() | |||
ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args}) | |||
# fetch dictionary size and number of labels from pickle files | |||
word2index = load_pickle(pickle_path, "word2id.pkl") | |||
test_args["vocab_size"] = len(word2index) | |||
index2label = load_pickle(pickle_path, "class2id.pkl") | |||
test_args["num_classes"] = len(index2label) | |||
# load dev data | |||
dev_data = load_pickle(pickle_path, "data_dev.pkl") | |||
# Define the same model | |||
model = AdvSeqLabel(test_args) | |||
pass | |||
# Dump trained parameters into the model | |||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||
print("model loaded!") | |||
# Tester | |||
tester = SeqLabelTester(**test_args.data) | |||
# Start testing | |||
tester.test(model, dev_data) | |||
# print test results | |||
print(tester.show_metrics()) | |||
print("model tested!") | |||
def infer(): | |||
pass | |||
if __name__ == "__main__": | |||
train() | |||
""" | |||
import argparse | |||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | |||
@@ -139,3 +97,5 @@ if __name__ == "__main__": | |||
else: | |||
print('no mode specified for model!') | |||
parser.print_help() | |||
""" |