A brand new version update (0.1.1)tags/v0.2.0
@@ -1,75 +0,0 @@ | |||||
from fastNLP.core.loss import Loss | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.predictor import ClassificationInfer | |||||
from fastNLP.core.preprocess import ClassPreprocess | |||||
from fastNLP.core.trainer import ClassificationTrainer | |||||
from fastNLP.loader.dataset_loader import ClassDataSetLoader | |||||
from fastNLP.models.base_model import BaseModel | |||||
from fastNLP.modules import aggregator | |||||
from fastNLP.modules import decoder | |||||
from fastNLP.modules import encoder | |||||
class ClassificationModel(BaseModel): | |||||
""" | |||||
Simple text classification model based on CNN. | |||||
""" | |||||
def __init__(self, num_classes, vocab_size): | |||||
super(ClassificationModel, self).__init__() | |||||
self.emb = encoder.Embedding(nums=vocab_size, dims=300) | |||||
self.enc = encoder.Conv( | |||||
in_channels=300, out_channels=100, kernel_size=3) | |||||
self.agg = aggregator.MaxPool() | |||||
self.dec = decoder.MLP(size_layer=[100, num_classes]) | |||||
def forward(self, x): | |||||
x = self.emb(x) # [N,L] -> [N,L,C] | |||||
x = self.enc(x) # [N,L,C_in] -> [N,L,C_out] | |||||
x = self.agg(x) # [N,L,C] -> [N,C] | |||||
x = self.dec(x) # [N,C] -> [N, N_class] | |||||
return x | |||||
data_dir = 'save/' # directory to save data and model | |||||
train_path = './data_for_tests/text_classify.txt' # training set file | |||||
# load dataset | |||||
ds_loader = ClassDataSetLoader() | |||||
data = ds_loader.load() | |||||
# pre-process dataset | |||||
pre = ClassPreprocess() | |||||
train_set, dev_set = pre.run(data, train_dev_split=0.3, pickle_path=data_dir) | |||||
n_classes, vocab_size = pre.num_classes, pre.vocab_size | |||||
# construct model | |||||
model_args = { | |||||
'num_classes': n_classes, | |||||
'vocab_size': vocab_size | |||||
} | |||||
model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size) | |||||
# construct trainer | |||||
train_args = { | |||||
"epochs": 3, | |||||
"batch_size": 16, | |||||
"pickle_path": data_dir, | |||||
"validate": False, | |||||
"save_best_dev": False, | |||||
"model_saved_path": None, | |||||
"use_cuda": True, | |||||
"loss": Loss("cross_entropy"), | |||||
"optimizer": Optimizer("Adam", lr=0.001) | |||||
} | |||||
trainer = ClassificationTrainer(**train_args) | |||||
# start training | |||||
trainer.train(model, train_data=train_set, dev_data=dev_set) | |||||
# predict using model | |||||
data_infer = [x[0] for x in data] | |||||
infer = ClassificationInfer(data_dir) | |||||
labels_pred = infer.predict(model.cpu(), data_infer) | |||||
print(labels_pred) |
@@ -0,0 +1,3 @@ | |||||
from .core import * | |||||
from . import models | |||||
from . import modules |
@@ -0,0 +1,314 @@ | |||||
import warnings | |||||
import torch | |||||
warnings.filterwarnings('ignore') | |||||
import os | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.api.model_zoo import load_url | |||||
from fastNLP.api.processor import ModelProcessor | |||||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader | |||||
from reproduction.pos_tag_model.pos_io.pos_reader import ConlluPOSReader | |||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.batch import Batch | |||||
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 | |||||
from fastNLP.api.pipeline import Pipeline | |||||
from fastNLP.core.metrics import SeqLabelEvaluator2 | |||||
from fastNLP.core.tester import Tester | |||||
# TODO add pretrain urls | |||||
model_urls = { | |||||
} | |||||
class API: | |||||
def __init__(self): | |||||
self.pipeline = None | |||||
def predict(self, *args, **kwargs): | |||||
raise NotImplementedError | |||||
def load(self, path, device): | |||||
if os.path.exists(os.path.expanduser(path)): | |||||
_dict = torch.load(path, map_location='cpu') | |||||
else: | |||||
_dict = load_url(path, map_location='cpu') | |||||
self.pipeline = _dict['pipeline'] | |||||
self._dict = _dict | |||||
for processor in self.pipeline.pipeline: | |||||
if isinstance(processor, ModelProcessor): | |||||
processor.set_model_device(device) | |||||
class POS(API): | |||||
"""FastNLP API for Part-Of-Speech tagging. | |||||
""" | |||||
def __init__(self, model_path=None, device='cpu'): | |||||
super(POS, self).__init__() | |||||
if model_path is None: | |||||
model_path = model_urls['pos'] | |||||
self.load(model_path, device) | |||||
def predict(self, content): | |||||
""" | |||||
:param content: list of list of str. Each string is a token(word). | |||||
:return answer: list of list of str. Each string is a tag. | |||||
""" | |||||
if not hasattr(self, 'pipeline'): | |||||
raise ValueError("You have to load model first.") | |||||
sentence_list = [] | |||||
# 1. 检查sentence的类型 | |||||
if isinstance(content, str): | |||||
sentence_list.append(content) | |||||
elif isinstance(content, list): | |||||
sentence_list = content | |||||
# 2. 组建dataset | |||||
dataset = DataSet() | |||||
dataset.add_field('words', sentence_list) | |||||
# 3. 使用pipeline | |||||
self.pipeline(dataset) | |||||
output = dataset['word_pos_output'].content | |||||
if isinstance(content, str): | |||||
return output[0] | |||||
elif isinstance(content, list): | |||||
return output | |||||
def test(self, filepath): | |||||
tag_proc = self._dict['tag_indexer'] | |||||
model = self.pipeline.pipeline[2].model | |||||
pipeline = self.pipeline.pipeline[0:2] | |||||
pipeline.append(tag_proc) | |||||
pp = Pipeline(pipeline) | |||||
reader = ConlluPOSReader() | |||||
te_dataset = reader.load(filepath) | |||||
evaluator = SeqLabelEvaluator2('word_seq_origin_len') | |||||
end_tagidx_set = set() | |||||
tag_proc.vocab.build_vocab() | |||||
for key, value in tag_proc.vocab.word2idx.items(): | |||||
if key.startswith('E-'): | |||||
end_tagidx_set.add(value) | |||||
if key.startswith('S-'): | |||||
end_tagidx_set.add(value) | |||||
evaluator.end_tagidx_set = end_tagidx_set | |||||
default_valid_args = {"batch_size": 64, | |||||
"use_cuda": True, "evaluator": evaluator} | |||||
pp(te_dataset) | |||||
te_dataset.set_target(truth=True) | |||||
tester = Tester(**default_valid_args) | |||||
test_result = tester.test(model, te_dataset) | |||||
f1 = round(test_result['F'] * 100, 2) | |||||
pre = round(test_result['P'] * 100, 2) | |||||
rec = round(test_result['R'] * 100, 2) | |||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | |||||
return f1, pre, rec | |||||
class CWS(API): | |||||
def __init__(self, model_path=None, device='cpu'): | |||||
super(CWS, self).__init__() | |||||
if model_path is None: | |||||
model_path = model_urls['cws'] | |||||
self.load(model_path, device) | |||||
def predict(self, content): | |||||
if not hasattr(self, 'pipeline'): | |||||
raise ValueError("You have to load model first.") | |||||
sentence_list = [] | |||||
# 1. 检查sentence的类型 | |||||
if isinstance(content, str): | |||||
sentence_list.append(content) | |||||
elif isinstance(content, list): | |||||
sentence_list = content | |||||
# 2. 组建dataset | |||||
dataset = DataSet() | |||||
dataset.add_field('raw_sentence', sentence_list) | |||||
# 3. 使用pipeline | |||||
self.pipeline(dataset) | |||||
output = dataset['output'].content | |||||
if isinstance(content, str): | |||||
return output[0] | |||||
elif isinstance(content, list): | |||||
return output | |||||
def test(self, filepath): | |||||
tag_proc = self._dict['tag_indexer'] | |||||
cws_model = self.pipeline.pipeline[-2].model | |||||
pipeline = self.pipeline.pipeline[:5] | |||||
pipeline.insert(1, tag_proc) | |||||
pp = Pipeline(pipeline) | |||||
reader = ConlluCWSReader() | |||||
# te_filename = '/home/hyan/ctb3/test.conllx' | |||||
te_dataset = reader.load(filepath) | |||||
pp(te_dataset) | |||||
batch_size = 64 | |||||
te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False) | |||||
pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher, type='bmes') | |||||
f1 = round(f1 * 100, 2) | |||||
pre = round(pre * 100, 2) | |||||
rec = round(rec * 100, 2) | |||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | |||||
return f1, pre, rec | |||||
class Parser(API): | |||||
def __init__(self, model_path=None, device='cpu'): | |||||
super(Parser, self).__init__() | |||||
if model_path is None: | |||||
model_path = model_urls['parser'] | |||||
self.load(model_path, device) | |||||
def predict(self, content): | |||||
if not hasattr(self, 'pipeline'): | |||||
raise ValueError("You have to load model first.") | |||||
sentence_list = [] | |||||
# 1. 检查sentence的类型 | |||||
if isinstance(content, str): | |||||
sentence_list.append(content) | |||||
elif isinstance(content, list): | |||||
sentence_list = content | |||||
# 2. 组建dataset | |||||
dataset = DataSet() | |||||
dataset.add_field('words', sentence_list) | |||||
# dataset.add_field('tag', sentence_list) | |||||
# 3. 使用pipeline | |||||
self.pipeline(dataset) | |||||
for ins in dataset: | |||||
ins['heads'] = ins['heads'].tolist() | |||||
return dataset['heads'], dataset['labels'] | |||||
def test(self, filepath): | |||||
data = ConllxDataLoader().load(filepath) | |||||
ds = DataSet() | |||||
for ins1, ins2 in zip(add_seg_tag(data), data): | |||||
ds.append(Instance(words=ins1[0], tag=ins1[1], | |||||
gold_words=ins2[0], gold_pos=ins2[1], | |||||
gold_heads=ins2[2], gold_head_tags=ins2[3])) | |||||
pp = self.pipeline | |||||
for p in pp: | |||||
if p.field_name == 'word_list': | |||||
p.field_name = 'gold_words' | |||||
elif p.field_name == 'pos_list': | |||||
p.field_name = 'gold_pos' | |||||
pp(ds) | |||||
head_cor, label_cor, total = 0, 0, 0 | |||||
for ins in ds: | |||||
head_gold = ins['gold_heads'] | |||||
head_pred = ins['heads'] | |||||
length = len(head_gold) | |||||
total += length | |||||
for i in range(length): | |||||
head_cor += 1 if head_pred[i] == head_gold[i] else 0 | |||||
uas = head_cor / total | |||||
print('uas:{:.2f}'.format(uas)) | |||||
for p in pp: | |||||
if p.field_name == 'gold_words': | |||||
p.field_name = 'word_list' | |||||
elif p.field_name == 'gold_pos': | |||||
p.field_name = 'pos_list' | |||||
return uas | |||||
class Analyzer: | |||||
def __init__(self, device='cpu'): | |||||
self.cws = CWS(device=device) | |||||
self.pos = POS(device=device) | |||||
self.parser = Parser(device=device) | |||||
def predict(self, content, seg=False, pos=False, parser=False): | |||||
if seg is False and pos is False and parser is False: | |||||
seg = True | |||||
output_dict = {} | |||||
if seg: | |||||
seg_output = self.cws.predict(content) | |||||
output_dict['seg'] = seg_output | |||||
if pos: | |||||
pos_output = self.pos.predict(content) | |||||
output_dict['pos'] = pos_output | |||||
if parser: | |||||
parser_output = self.parser.predict(content) | |||||
output_dict['parser'] = parser_output | |||||
return output_dict | |||||
def test(self, filepath): | |||||
output_dict = {} | |||||
if self.seg: | |||||
seg_output = self.cws.test(filepath) | |||||
output_dict['seg'] = seg_output | |||||
if self.pos: | |||||
pos_output = self.pos.test(filepath) | |||||
output_dict['pos'] = pos_output | |||||
if self.parser: | |||||
parser_output = self.parser.test(filepath) | |||||
output_dict['parser'] = parser_output | |||||
return output_dict | |||||
if __name__ == "__main__": | |||||
# pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl' | |||||
# pos = POS(device='cpu') | |||||
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' , | |||||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
# '那么这款无人机到底有多厉害?'] | |||||
# print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll')) | |||||
# print(pos.predict(s)) | |||||
# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' | |||||
# cws = CWS(device='cpu') | |||||
# s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' , | |||||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
# '那么这款无人机到底有多厉害?'] | |||||
# print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll')) | |||||
# print(cws.predict(s)) | |||||
parser = Parser(device='cpu') | |||||
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) | |||||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
print(parser.predict(s)) |
@@ -0,0 +1,181 @@ | |||||
import re | |||||
class SpanConverter: | |||||
def __init__(self, replace_tag, pattern): | |||||
super(SpanConverter, self).__init__() | |||||
self.replace_tag = replace_tag | |||||
self.pattern = pattern | |||||
def find_certain_span_and_replace(self, sentence): | |||||
replaced_sentence = '' | |||||
prev_end = 0 | |||||
for match in re.finditer(self.pattern, sentence): | |||||
start, end = match.span() | |||||
span = sentence[start:end] | |||||
replaced_sentence += sentence[prev_end:start] + self.span_to_special_tag(span) | |||||
prev_end = end | |||||
replaced_sentence += sentence[prev_end:] | |||||
return replaced_sentence | |||||
def span_to_special_tag(self, span): | |||||
return self.replace_tag | |||||
def find_certain_span(self, sentence): | |||||
spans = [] | |||||
for match in re.finditer(self.pattern, sentence): | |||||
spans.append(match.span()) | |||||
return spans | |||||
class AlphaSpanConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<ALPHA>' | |||||
# 理想状态下仅处理纯为字母的情况, 但不处理<[a-zA-Z]+>(因为这应该是特殊的tag). | |||||
pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%.!<\\-"])' | |||||
super(AlphaSpanConverter, self).__init__(replace_tag, pattern) | |||||
class DigitSpanConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<NUM>' | |||||
pattern = '\d[\d\\.]*(?=[\u4e00-\u9fff ,%.!<-])' | |||||
super(DigitSpanConverter, self).__init__(replace_tag, pattern) | |||||
def span_to_special_tag(self, span): | |||||
# return self.special_tag | |||||
if span[0] == '0' and len(span) > 2: | |||||
return '<NUM>' | |||||
decimal_point_count = 0 # one might have more than one decimal pointers | |||||
for idx, char in enumerate(span): | |||||
if char == '.' or char == '﹒' or char == '·': | |||||
decimal_point_count += 1 | |||||
if span[-1] == '.' or span[-1] == '﹒' or span[-1] == '·': | |||||
# last digit being decimal point means this is not a number | |||||
if decimal_point_count == 1: | |||||
return span | |||||
else: | |||||
return '<UNKDGT>' | |||||
if decimal_point_count == 1: | |||||
return '<DEC>' | |||||
elif decimal_point_count > 1: | |||||
return '<UNKDGT>' | |||||
else: | |||||
return '<NUM>' | |||||
class TimeConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<TOC>' | |||||
pattern = '\d+[::∶][\d::∶]+(?=[\u4e00-\u9fff ,%.!<-])' | |||||
super().__init__(replace_tag, pattern) | |||||
class MixNumAlphaConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<MIX>' | |||||
pattern = None | |||||
super().__init__(replace_tag, pattern) | |||||
def find_certain_span_and_replace(self, sentence): | |||||
replaced_sentence = '' | |||||
start = 0 | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
for idx in range(len(sentence)): | |||||
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): | |||||
if not matching_flag: | |||||
replaced_sentence += sentence[start:idx] | |||||
start = idx | |||||
if re.match('[0-9]', sentence[idx]): | |||||
number_flag = True | |||||
elif re.match('[\'′&\\-]', sentence[idx]): | |||||
link_flag = True | |||||
elif re.match('/', sentence[idx]): | |||||
slash_flag = True | |||||
elif re.match('[\\(\\)]', sentence[idx]): | |||||
bracket_flag = True | |||||
else: | |||||
alpha_flag = True | |||||
matching_flag = True | |||||
elif re.match('[\\.]', sentence[idx]): | |||||
pass | |||||
else: | |||||
if matching_flag: | |||||
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ | |||||
or (slash_flag and alpha_flag) or (link_flag and number_flag) \ | |||||
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): | |||||
span = sentence[start:idx] | |||||
start = idx | |||||
replaced_sentence += self.span_to_special_tag(span) | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
replaced_sentence += sentence[start:] | |||||
return replaced_sentence | |||||
def find_certain_span(self, sentence): | |||||
spans = [] | |||||
start = 0 | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
for idx in range(len(sentence)): | |||||
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): | |||||
if not matching_flag: | |||||
start = idx | |||||
if re.match('[0-9]', sentence[idx]): | |||||
number_flag = True | |||||
elif re.match('[\'′&\\-]', sentence[idx]): | |||||
link_flag = True | |||||
elif re.match('/', sentence[idx]): | |||||
slash_flag = True | |||||
elif re.match('[\\(\\)]', sentence[idx]): | |||||
bracket_flag = True | |||||
else: | |||||
alpha_flag = True | |||||
matching_flag = True | |||||
elif re.match('[\\.]', sentence[idx]): | |||||
pass | |||||
else: | |||||
if matching_flag: | |||||
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ | |||||
or (slash_flag and alpha_flag) or (link_flag and number_flag) \ | |||||
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): | |||||
spans.append((start, idx)) | |||||
start = idx | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
return spans | |||||
class EmailConverter(SpanConverter): | |||||
def __init__(self): | |||||
replaced_tag = "<EML>" | |||||
pattern = '[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\\-"$])' | |||||
super(EmailConverter, self).__init__(replaced_tag, pattern) |
@@ -0,0 +1,138 @@ | |||||
import torch | |||||
import hashlib | |||||
import os | |||||
import re | |||||
import shutil | |||||
import sys | |||||
import tempfile | |||||
try: | |||||
from requests.utils import urlparse | |||||
from requests import get as urlopen | |||||
requests_available = True | |||||
except ImportError: | |||||
requests_available = False | |||||
if sys.version_info[0] == 2: | |||||
from urlparse import urlparse # noqa f811 | |||||
from urllib2 import urlopen # noqa f811 | |||||
else: | |||||
from urllib.request import urlopen | |||||
from urllib.parse import urlparse | |||||
try: | |||||
from tqdm import tqdm | |||||
except ImportError: | |||||
tqdm = None # defined below | |||||
# matches bfd8deac from resnet18-bfd8deac.pth | |||||
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') | |||||
def load_url(url, model_dir=None, map_location=None, progress=True): | |||||
r"""Loads the Torch serialized object at the given URL. | |||||
If the object is already present in `model_dir`, it's deserialized and | |||||
returned. The filename part of the URL should follow the naming convention | |||||
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more | |||||
digits of the SHA256 hash of the contents of the file. The hash is used to | |||||
ensure unique names and to verify the contents of the file. | |||||
The default value of `model_dir` is ``$TORCH_HOME/models`` where | |||||
``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be | |||||
overridden with the ``$TORCH_MODEL_ZOO`` environment variable. | |||||
Args: | |||||
url (string): URL of the object to download | |||||
model_dir (string, optional): directory in which to save the object | |||||
map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) | |||||
progress (bool, optional): whether or not to display a progress bar to stderr | |||||
Example: | |||||
# >>> state_dict = model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') | |||||
""" | |||||
if model_dir is None: | |||||
torch_home = os.path.expanduser(os.getenv('fastNLP_HOME', '~/.fastNLP')) | |||||
model_dir = os.getenv('fastNLP_MODEL_ZOO', os.path.join(torch_home, 'models')) | |||||
if not os.path.exists(model_dir): | |||||
os.makedirs(model_dir) | |||||
parts = urlparse(url) | |||||
filename = os.path.basename(parts.path) | |||||
cached_file = os.path.join(model_dir, filename) | |||||
if not os.path.exists(cached_file): | |||||
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) | |||||
# hash_prefix = HASH_REGEX.search(filename).group(1) | |||||
_download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) | |||||
return torch.load(cached_file, map_location=map_location) | |||||
def _download_url_to_file(url, dst, hash_prefix, progress): | |||||
if requests_available: | |||||
u = urlopen(url, stream=True) | |||||
file_size = int(u.headers["Content-Length"]) | |||||
u = u.raw | |||||
else: | |||||
u = urlopen(url) | |||||
meta = u.info() | |||||
if hasattr(meta, 'getheaders'): | |||||
file_size = int(meta.getheaders("Content-Length")[0]) | |||||
else: | |||||
file_size = int(meta.get_all("Content-Length")[0]) | |||||
f = tempfile.NamedTemporaryFile(delete=False) | |||||
try: | |||||
if hash_prefix is not None: | |||||
sha256 = hashlib.sha256() | |||||
with tqdm(total=file_size, disable=not progress) as pbar: | |||||
while True: | |||||
buffer = u.read(8192) | |||||
if len(buffer) == 0: | |||||
break | |||||
f.write(buffer) | |||||
if hash_prefix is not None: | |||||
sha256.update(buffer) | |||||
pbar.update(len(buffer)) | |||||
f.close() | |||||
if hash_prefix is not None: | |||||
digest = sha256.hexdigest() | |||||
if digest[:len(hash_prefix)] != hash_prefix: | |||||
raise RuntimeError('invalid hash value (expected "{}", got "{}")' | |||||
.format(hash_prefix, digest)) | |||||
shutil.move(f.name, dst) | |||||
finally: | |||||
f.close() | |||||
if os.path.exists(f.name): | |||||
os.remove(f.name) | |||||
if tqdm is None: | |||||
# fake tqdm if it's not installed | |||||
class tqdm(object): | |||||
def __init__(self, total, disable=False): | |||||
self.total = total | |||||
self.disable = disable | |||||
self.n = 0 | |||||
def update(self, n): | |||||
if self.disable: | |||||
return | |||||
self.n += n | |||||
sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total))) | |||||
sys.stderr.flush() | |||||
def __enter__(self): | |||||
return self | |||||
def __exit__(self, exc_type, exc_val, exc_tb): | |||||
if self.disable: | |||||
return | |||||
sys.stderr.write('\n') | |||||
if __name__ == '__main__': | |||||
pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context-4e86fd93.pkl', model_dir='.') | |||||
print(type(pipeline)) |
@@ -0,0 +1,33 @@ | |||||
from fastNLP.api.processor import Processor | |||||
class Pipeline: | |||||
""" | |||||
Pipeline takes a DataSet object as input, runs multiple processors sequentially, and | |||||
outputs a DataSet object. | |||||
""" | |||||
def __init__(self, processors=None): | |||||
self.pipeline = [] | |||||
if isinstance(processors, list): | |||||
for proc in processors: | |||||
assert isinstance(proc, Processor), "Must be a Processor, not {}.".format(type(proc)) | |||||
self.pipeline = processors | |||||
def add_processor(self, processor): | |||||
assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor)) | |||||
self.pipeline.append(processor) | |||||
def process(self, dataset): | |||||
assert len(self.pipeline) != 0, "You need to add some processor first." | |||||
for proc in self.pipeline: | |||||
dataset = proc(dataset) | |||||
return dataset | |||||
def __call__(self, *args, **kwargs): | |||||
return self.process(*args, **kwargs) | |||||
def __getitem__(self, item): | |||||
return self.pipeline[item] |
@@ -0,0 +1,276 @@ | |||||
import torch | |||||
from collections import defaultdict | |||||
import re | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
class Processor: | |||||
def __init__(self, field_name, new_added_field_name): | |||||
self.field_name = field_name | |||||
if new_added_field_name is None: | |||||
self.new_added_field_name = field_name | |||||
else: | |||||
self.new_added_field_name = new_added_field_name | |||||
def process(self, *args, **kwargs): | |||||
pass | |||||
def __call__(self, *args, **kwargs): | |||||
return self.process(*args, **kwargs) | |||||
class FullSpaceToHalfSpaceProcessor(Processor): | |||||
"""全角转半角,以字符为处理单元 | |||||
""" | |||||
def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True, | |||||
change_space=True): | |||||
super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None) | |||||
self.change_alpha = change_alpha | |||||
self.change_digit = change_digit | |||||
self.change_punctuation = change_punctuation | |||||
self.change_space = change_space | |||||
FH_SPACE = [(u" ", u" ")] | |||||
FH_NUM = [ | |||||
(u"0", u"0"), (u"1", u"1"), (u"2", u"2"), (u"3", u"3"), (u"4", u"4"), | |||||
(u"5", u"5"), (u"6", u"6"), (u"7", u"7"), (u"8", u"8"), (u"9", u"9")] | |||||
FH_ALPHA = [ | |||||
(u"a", u"a"), (u"b", u"b"), (u"c", u"c"), (u"d", u"d"), (u"e", u"e"), | |||||
(u"f", u"f"), (u"g", u"g"), (u"h", u"h"), (u"i", u"i"), (u"j", u"j"), | |||||
(u"k", u"k"), (u"l", u"l"), (u"m", u"m"), (u"n", u"n"), (u"o", u"o"), | |||||
(u"p", u"p"), (u"q", u"q"), (u"r", u"r"), (u"s", u"s"), (u"t", u"t"), | |||||
(u"u", u"u"), (u"v", u"v"), (u"w", u"w"), (u"x", u"x"), (u"y", u"y"), | |||||
(u"z", u"z"), | |||||
(u"A", u"A"), (u"B", u"B"), (u"C", u"C"), (u"D", u"D"), (u"E", u"E"), | |||||
(u"F", u"F"), (u"G", u"G"), (u"H", u"H"), (u"I", u"I"), (u"J", u"J"), | |||||
(u"K", u"K"), (u"L", u"L"), (u"M", u"M"), (u"N", u"N"), (u"O", u"O"), | |||||
(u"P", u"P"), (u"Q", u"Q"), (u"R", u"R"), (u"S", u"S"), (u"T", u"T"), | |||||
(u"U", u"U"), (u"V", u"V"), (u"W", u"W"), (u"X", u"X"), (u"Y", u"Y"), | |||||
(u"Z", u"Z")] | |||||
# 谨慎使用标点符号转换, 因为"5.12特大地震"转换后可能就成了"5.12特大地震" | |||||
FH_PUNCTUATION = [ | |||||
(u'%', u'%'), (u'!', u'!'), (u'"', u'\"'), (u''', u'\''), (u'#', u'#'), | |||||
(u'¥', u'$'), (u'&', u'&'), (u'(', u'('), (u')', u')'), (u'*', u'*'), | |||||
(u'+', u'+'), (u',', u','), (u'-', u'-'), (u'.', u'.'), (u'/', u'/'), | |||||
(u':', u':'), (u';', u';'), (u'<', u'<'), (u'=', u'='), (u'>', u'>'), | |||||
(u'?', u'?'), (u'@', u'@'), (u'[', u'['), (u']', u']'), (u'\', u'\\'), | |||||
(u'^', u'^'), (u'_', u'_'), (u'`', u'`'), (u'~', u'~'), (u'{', u'{'), | |||||
(u'}', u'}'), (u'|', u'|')] | |||||
FHs = [] | |||||
if self.change_alpha: | |||||
FHs = FH_ALPHA | |||||
if self.change_digit: | |||||
FHs += FH_NUM | |||||
if self.change_punctuation: | |||||
FHs += FH_PUNCTUATION | |||||
if self.change_space: | |||||
FHs += FH_SPACE | |||||
self.convert_map = {k: v for k, v in FHs} | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
sentence = ins[self.field_name] | |||||
new_sentence = [None] * len(sentence) | |||||
for idx, char in enumerate(sentence): | |||||
if char in self.convert_map: | |||||
char = self.convert_map[char] | |||||
new_sentence[idx] = char | |||||
ins[self.field_name] = ''.join(new_sentence) | |||||
return dataset | |||||
class PreAppendProcessor(Processor): | |||||
def __init__(self, data, field_name, new_added_field_name=None): | |||||
super(PreAppendProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.data = data | |||||
def process(self, dataset): | |||||
for ins in dataset: | |||||
sent = ins[self.field_name] | |||||
ins[self.new_added_field_name] = [self.data] + sent | |||||
return dataset | |||||
class SliceProcessor(Processor): | |||||
def __init__(self, start, end, step, field_name, new_added_field_name=None): | |||||
super(SliceProcessor, self).__init__(field_name, new_added_field_name) | |||||
for o in (start, end, step): | |||||
assert isinstance(o, int) or o is None | |||||
self.slice = slice(start, end, step) | |||||
def process(self, dataset): | |||||
for ins in dataset: | |||||
sent = ins[self.field_name] | |||||
ins[self.new_added_field_name] = sent[self.slice] | |||||
return dataset | |||||
class Num2TagProcessor(Processor): | |||||
def __init__(self, tag, field_name, new_added_field_name=None): | |||||
super(Num2TagProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.tag = tag | |||||
self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)' | |||||
def process(self, dataset): | |||||
for ins in dataset: | |||||
s = ins[self.field_name] | |||||
new_s = [None] * len(s) | |||||
for i, w in enumerate(s): | |||||
if re.search(self.pattern, w) is not None: | |||||
w = self.tag | |||||
new_s[i] = w | |||||
ins[self.new_added_field_name] = new_s | |||||
return dataset | |||||
class IndexerProcessor(Processor): | |||||
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False): | |||||
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | |||||
super(IndexerProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.vocab = vocab | |||||
self.delete_old_field = delete_old_field | |||||
def set_vocab(self, vocab): | |||||
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | |||||
self.vocab = vocab | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
tokens = ins[self.field_name] | |||||
index = [self.vocab.to_index(token) for token in tokens] | |||||
ins[self.new_added_field_name] = index | |||||
dataset._set_need_tensor(**{self.new_added_field_name: True}) | |||||
if self.delete_old_field: | |||||
dataset.delete_field(self.field_name) | |||||
return dataset | |||||
class VocabProcessor(Processor): | |||||
def __init__(self, field_name): | |||||
super(VocabProcessor, self).__init__(field_name, None) | |||||
self.vocab = Vocabulary() | |||||
def process(self, *datasets): | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
tokens = ins[self.field_name] | |||||
self.vocab.update(tokens) | |||||
def get_vocab(self): | |||||
self.vocab.build_vocab() | |||||
return self.vocab | |||||
class SeqLenProcessor(Processor): | |||||
def __init__(self, field_name, new_added_field_name='seq_lens'): | |||||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
length = len(ins[self.field_name]) | |||||
ins[self.new_added_field_name] = length | |||||
dataset._set_need_tensor(**{self.new_added_field_name: True}) | |||||
return dataset | |||||
class ModelProcessor(Processor): | |||||
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | |||||
""" | |||||
迭代模型并将结果的padding drop掉 | |||||
:param seq_len_field_name: | |||||
:param batch_size: | |||||
""" | |||||
super(ModelProcessor, self).__init__(None, None) | |||||
self.batch_size = batch_size | |||||
self.seq_len_field_name = seq_len_field_name | |||||
self.model = model | |||||
def process(self, dataset): | |||||
self.model.eval() | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False) | |||||
batch_output = defaultdict(list) | |||||
with torch.no_grad(): | |||||
for batch_x, _ in data_iterator: | |||||
prediction = self.model.predict(**batch_x) | |||||
seq_lens = batch_x[self.seq_len_field_name].cpu().numpy().tolist() | |||||
for key, value in prediction.items(): | |||||
tmp_batch = [] | |||||
value = value.cpu().numpy() | |||||
if len(value.shape) == 1 or (len(value.shape)==2 and value.shape[1]==1): | |||||
batch_output[key].extend(value.tolist()) | |||||
else: | |||||
for idx, seq_len in enumerate(seq_lens): | |||||
tmp_batch.append(value[idx, :seq_len]) | |||||
batch_output[key].extend(tmp_batch) | |||||
batch_output[self.seq_len_field_name].extend(seq_lens) | |||||
# TODO 当前的实现会导致之后的processor需要知道model输出的output的key是什么 | |||||
for field_name, fields in batch_output.items(): | |||||
dataset.add_field(field_name, fields, need_tensor=False, is_target=False) | |||||
return dataset | |||||
def set_model(self, model): | |||||
self.model = model | |||||
def set_model_device(self, device): | |||||
device = torch.device(device) | |||||
self.model.to(device) | |||||
class Index2WordProcessor(Processor): | |||||
def __init__(self, vocab, field_name, new_added_field_name): | |||||
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.vocab = vocab | |||||
def process(self, dataset): | |||||
for ins in dataset: | |||||
new_sent = [self.vocab.to_word(w) for w in ins[self.field_name]] | |||||
ins[self.new_added_field_name] = new_sent | |||||
return dataset | |||||
class SetTensorProcessor(Processor): | |||||
def __init__(self, field_dict, default=False): | |||||
super(SetTensorProcessor, self).__init__(None, None) | |||||
self.field_dict = field_dict | |||||
self.default = default | |||||
def process(self, dataset): | |||||
set_dict = {name: self.default for name in dataset.get_fields().keys()} | |||||
set_dict.update(self.field_dict) | |||||
dataset._set_need_tensor(**set_dict) | |||||
return dataset | |||||
class SetIsTargetProcessor(Processor): | |||||
def __init__(self, field_dict, default=False): | |||||
super(SetIsTargetProcessor, self).__init__(None, None) | |||||
self.field_dict = field_dict | |||||
self.default = default | |||||
def process(self, dataset): | |||||
set_dict = {name: self.default for name in dataset.get_fields().keys()} | |||||
set_dict.update(self.field_dict) | |||||
dataset.set_target(**set_dict) | |||||
return dataset |
@@ -0,0 +1,11 @@ | |||||
from .batch import Batch | |||||
from .dataset import DataSet | |||||
from .fieldarray import FieldArray | |||||
from .instance import Instance | |||||
from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator | |||||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler | |||||
from .tester import Tester | |||||
from .trainer import Trainer | |||||
from .vocabulary import Vocabulary | |||||
from .optimizer import Optimizer | |||||
from .loss import Loss |
@@ -1,5 +1,3 @@ | |||||
from collections import defaultdict | |||||
import torch | import torch | ||||
@@ -7,25 +5,24 @@ class Batch(object): | |||||
"""Batch is an iterable object which iterates over mini-batches. | """Batch is an iterable object which iterates over mini-batches. | ||||
:: | :: | ||||
for batch_x, batch_y in Batch(data_set): | |||||
for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()): | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size, sampler, use_cuda, sort_in_batch=False, sort_key=None): | |||||
def __init__(self, dataset, batch_size, sampler, as_numpy=False): | |||||
""" | """ | ||||
:param dataset: a DataSet object | :param dataset: a DataSet object | ||||
:param batch_size: int, the size of the batch | :param batch_size: int, the size of the batch | ||||
:param sampler: a Sampler object | :param sampler: a Sampler object | ||||
:param use_cuda: bool, whether to use GPU | |||||
:param as_numpy: bool. If True, return Numpy array. Otherwise, return torch tensors. | |||||
""" | """ | ||||
self.dataset = dataset | self.dataset = dataset | ||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.sampler = sampler | self.sampler = sampler | ||||
self.use_cuda = use_cuda | |||||
self.sort_in_batch = sort_in_batch | |||||
self.sort_key = sort_key if sort_key is not None else 'word_seq' | |||||
self.as_numpy = as_numpy | |||||
self.idx_list = None | self.idx_list = None | ||||
self.curidx = 0 | self.curidx = 0 | ||||
@@ -36,49 +33,24 @@ class Batch(object): | |||||
return self | return self | ||||
def __next__(self): | def __next__(self): | ||||
""" | |||||
:return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||||
E.g. | |||||
:: | |||||
{'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]}) | |||||
batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||||
All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True. | |||||
""" | |||||
if self.curidx >= len(self.idx_list): | if self.curidx >= len(self.idx_list): | ||||
raise StopIteration | raise StopIteration | ||||
else: | else: | ||||
endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | ||||
batch_idxes = self.idx_list[self.curidx: endidx] | |||||
padding_length = {field_name: max([field_length[idx] for idx in batch_idxes]) | |||||
for field_name, field_length in self.lengths.items()} | |||||
batch_x, batch_y = defaultdict(list), defaultdict(list) | |||||
# transform index to tensor and do padding for sequences | |||||
batch = [] | |||||
for idx in batch_idxes: | |||||
x, y = self.dataset.to_tensor(idx, padding_length) | |||||
batch.append((self.lengths[self.sort_key][idx] if self.sort_in_batch else None, x, y)) | |||||
if self.sort_in_batch: | |||||
batch = sorted(batch, key=lambda x: x[0], reverse=True) | |||||
batch_x, batch_y = {}, {} | |||||
for _, x, y in batch: | |||||
for name, tensor in x.items(): | |||||
batch_x[name].append(tensor) | |||||
for name, tensor in y.items(): | |||||
batch_y[name].append(tensor) | |||||
indices = self.idx_list[self.curidx:endidx] | |||||
# combine instances to form a batch | |||||
for batch in (batch_x, batch_y): | |||||
for name, tensor_list in batch.items(): | |||||
if self.use_cuda: | |||||
batch[name] = torch.stack(tensor_list, dim=0).cuda() | |||||
else: | |||||
batch[name] = torch.stack(tensor_list, dim=0) | |||||
for field_name, field in self.dataset.get_fields().items(): | |||||
if field.is_target or field.is_input: | |||||
batch = field.get(indices) | |||||
if not self.as_numpy: | |||||
batch = torch.from_numpy(batch) | |||||
if field.is_target: | |||||
batch_y[field_name] = batch | |||||
if field.is_input: | |||||
batch_x[field_name] = batch | |||||
self.curidx = endidx | self.curidx = endidx | ||||
return batch_x, batch_y | |||||
return batch_x, batch_y |
@@ -1,160 +1,326 @@ | |||||
import random | |||||
import sys | |||||
from collections import defaultdict | |||||
from copy import deepcopy | |||||
import numpy as np | |||||
from copy import copy | |||||
from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.fieldarray import FieldArray | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.vocabulary import Vocabulary | |||||
_READERS = {} | _READERS = {} | ||||
class DataSet(list): | |||||
"""A DataSet object is a list of Instance objects. | |||||
def construct_dataset(sentences): | |||||
"""Construct a data set from a list of sentences. | |||||
:param sentences: list of list of str | |||||
:return dataset: a DataSet object | |||||
""" | """ | ||||
dataset = DataSet() | |||||
for sentence in sentences: | |||||
instance = Instance() | |||||
instance['raw_sentence'] = sentence | |||||
dataset.append(instance) | |||||
return dataset | |||||
class DataSet(object): | |||||
"""DataSet is the collection of examples. | |||||
DataSet provides instance-level interface. You can append and access an instance of the DataSet. | |||||
However, it stores data in a different way: Field-first, Instance-second. | |||||
""" | |||||
class Instance(object): | |||||
def __init__(self, dataset, idx=-1, **fields): | |||||
self.dataset = dataset | |||||
self.idx = idx | |||||
self.fields = fields | |||||
def __next__(self): | |||||
self.idx += 1 | |||||
if self.idx >= len(self.dataset): | |||||
raise StopIteration | |||||
return copy(self) | |||||
def __init__(self, name="", instances=None): | |||||
def add_field(self, field_name, field): | |||||
"""Add a new field to the instance. | |||||
:param field_name: str, the name of the field. | |||||
:param field: | |||||
""" | |||||
self.fields[field_name] = field | |||||
def __getitem__(self, name): | |||||
return self.dataset[name][self.idx] | |||||
def __setitem__(self, name, val): | |||||
if name not in self.dataset: | |||||
new_fields = [None] * len(self.dataset) | |||||
self.dataset.add_field(name, new_fields) | |||||
self.dataset[name][self.idx] = val | |||||
def __repr__(self): | |||||
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name | |||||
in self.dataset.get_fields().keys()]) | |||||
def __init__(self, data=None): | |||||
""" | """ | ||||
:param name: str, the name of the dataset. (default: "") | |||||
:param instances: list of Instance objects. (default: None) | |||||
:param data: a dict or a list. If it is a dict, the key is the name of a field and the value is the field. | |||||
All values must be of the same length. | |||||
If it is a list, it must be a list of Instance objects. | |||||
""" | """ | ||||
list.__init__([]) | |||||
self.name = name | |||||
self.origin_len = None | |||||
if instances is not None: | |||||
self.extend(instances) | |||||
def index_all(self, vocab): | |||||
for ins in self: | |||||
ins.index_all(vocab) | |||||
return self | |||||
self.field_arrays = {} | |||||
if data is not None: | |||||
if isinstance(data, dict): | |||||
length_set = set() | |||||
for key, value in data.items(): | |||||
length_set.add(len(value)) | |||||
assert len(length_set) == 1, "Arrays must all be same length." | |||||
for key, value in data.items(): | |||||
self.add_field(name=key, fields=value) | |||||
elif isinstance(data, list): | |||||
for ins in data: | |||||
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) | |||||
self.append(ins) | |||||
else: | |||||
raise ValueError("data only be dict or list type.") | |||||
def index_field(self, field_name, vocab): | |||||
if isinstance(field_name, str): | |||||
field_list = [field_name] | |||||
vocab_list = [vocab] | |||||
def __contains__(self, item): | |||||
return item in self.field_arrays | |||||
def __iter__(self): | |||||
return self.Instance(self) | |||||
def _convert_ins(self, ins_list): | |||||
if isinstance(ins_list, list): | |||||
for ins in ins_list: | |||||
self.append(ins) | |||||
else: | else: | ||||
classes = (list, tuple) | |||||
assert isinstance(field_name, classes) and isinstance(vocab, classes) and len(field_name) == len(vocab) | |||||
field_list = field_name | |||||
vocab_list = vocab | |||||
for name, vocabs in zip(field_list, vocab_list): | |||||
for ins in self: | |||||
ins.index_field(name, vocabs) | |||||
return self | |||||
self.append(ins_list) | |||||
def to_tensor(self, idx: int, padding_length: dict): | |||||
"""Convert an instance in a dataset to tensor. | |||||
def append(self, ins): | |||||
"""Add an instance to the DataSet. | |||||
If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet. | |||||
:param idx: int, the index of the instance in the dataset. | |||||
:param padding_length: int | |||||
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||||
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||||
:param ins: an Instance object | |||||
""" | """ | ||||
ins = self[idx] | |||||
return ins.to_tensor(padding_length, self.origin_len) | |||||
if len(self.field_arrays) == 0: | |||||
# DataSet has no field yet | |||||
for name, field in ins.fields.items(): | |||||
self.field_arrays[name] = FieldArray(name, [field]) | |||||
else: | |||||
assert len(self.field_arrays) == len(ins.fields) | |||||
for name, field in ins.fields.items(): | |||||
assert name in self.field_arrays | |||||
self.field_arrays[name].append(field) | |||||
def get_length(self): | |||||
"""Fetch lengths of all fields in all instances in a dataset. | |||||
def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False): | |||||
"""Add a new field to the DataSet. | |||||
:param str name: the name of the field. | |||||
:param fields: a list of int, float, or other objects. | |||||
:param int padding_val: integer for padding. | |||||
:param bool is_input: whether this field is model input. | |||||
:param bool is_target: whether this field is label or target. | |||||
""" | |||||
if len(self.field_arrays) != 0: | |||||
assert len(self) == len(fields) | |||||
self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target, | |||||
is_input=is_input) | |||||
:return lengths: dict of (str: list). The str is the field name. | |||||
The list contains lengths of this field in all instances. | |||||
def delete_field(self, name): | |||||
"""Delete a field based on the field name. | |||||
:param str name: the name of the field to be deleted. | |||||
""" | """ | ||||
lengths = defaultdict(list) | |||||
for ins in self: | |||||
for field_name, field_length in ins.get_length().items(): | |||||
lengths[field_name].append(field_length) | |||||
return lengths | |||||
def shuffle(self): | |||||
random.shuffle(self) | |||||
return self | |||||
self.field_arrays.pop(name) | |||||
def split(self, ratio, shuffle=True): | |||||
"""Train/dev splitting | |||||
def get_fields(self): | |||||
"""Return all the fields with their names. | |||||
:param ratio: float, between 0 and 1. The ratio of development set in origin data set. | |||||
:param shuffle: bool, whether shuffle the data set before splitting. Default: True. | |||||
:return train_set: a DataSet object, representing the training set | |||||
dev_set: a DataSet object, representing the validation set | |||||
:return dict field_arrays: the internal data structure of DataSet. | |||||
""" | |||||
return self.field_arrays | |||||
def __getitem__(self, idx): | |||||
""" | """ | ||||
assert 0 < ratio < 1 | |||||
if shuffle: | |||||
self.shuffle() | |||||
split_idx = int(len(self) * ratio) | |||||
dev_set = deepcopy(self) | |||||
train_set = deepcopy(self) | |||||
del train_set[:split_idx] | |||||
del dev_set[split_idx:] | |||||
return train_set, dev_set | |||||
:param idx: can be int, slice, or str. | |||||
:return: If `idx` is int, return an Instance object. | |||||
If `idx` is slice, return a DataSet object. | |||||
If `idx` is str, it must be a field name, return the field. | |||||
""" | |||||
if isinstance(idx, int): | |||||
return self.Instance(self, idx, **{name: self.field_arrays[name][idx] for name in self.field_arrays}) | |||||
elif isinstance(idx, slice): | |||||
data_set = DataSet() | |||||
for field in self.field_arrays.values(): | |||||
data_set.add_field(name=field.name, | |||||
fields=field.content[idx], | |||||
padding_val=field.padding_val, | |||||
is_input=field.is_input, | |||||
is_target=field.is_target) | |||||
return data_set | |||||
elif isinstance(idx, str): | |||||
return self.field_arrays[idx] | |||||
else: | |||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||||
def __len__(self): | |||||
if len(self.field_arrays) == 0: | |||||
return 0 | |||||
field = iter(self.field_arrays.values()).__next__() | |||||
return len(field) | |||||
def get_length(self): | |||||
"""The same as __len__ | |||||
""" | |||||
return len(self) | |||||
def rename_field(self, old_name, new_name): | def rename_field(self, old_name, new_name): | ||||
"""rename a field | """rename a field | ||||
""" | """ | ||||
for ins in self: | |||||
ins.rename_field(old_name, new_name) | |||||
return self | |||||
if old_name in self.field_arrays: | |||||
self.field_arrays[new_name] = self.field_arrays.pop(old_name) | |||||
else: | |||||
raise KeyError("{} is not a valid name. ".format(old_name)) | |||||
def set_target(self, **fields): | def set_target(self, **fields): | ||||
"""Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged. | """Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged. | ||||
:param key-value pairs for field-name and `is_target` value(True, False or None). | |||||
:param key-value pairs for field-name and `is_target` value(True, False). | |||||
""" | """ | ||||
for ins in self: | |||||
ins.set_target(**fields) | |||||
for name, val in fields.items(): | |||||
if name in self.field_arrays: | |||||
assert isinstance(val, bool) | |||||
self.field_arrays[name].is_target = val | |||||
else: | |||||
raise KeyError("{} is not a valid field name.".format(name)) | |||||
return self | return self | ||||
def update_vocab(self, **name_vocab): | |||||
"""using certain field data to update vocabulary. | |||||
e.g. :: | |||||
# update word vocab and label vocab seperately | |||||
dataset.update_vocab(word_seq=word_vocab, label_seq=label_vocab) | |||||
""" | |||||
for field_name, vocab in name_vocab.items(): | |||||
for ins in self: | |||||
vocab.update(ins[field_name].contents()) | |||||
def set_input(self, **fields): | |||||
for name, val in fields.items(): | |||||
if name in self.field_arrays: | |||||
assert isinstance(val, bool) | |||||
self.field_arrays[name].is_input = val | |||||
else: | |||||
raise KeyError("{} is not a valid field name.".format(name)) | |||||
return self | return self | ||||
def set_origin_len(self, origin_field, origin_len_name=None): | |||||
"""make dataset tensor output contain origin_len field. | |||||
def get_input_name(self): | |||||
return [name for name, field in self.field_arrays.items() if field.is_input] | |||||
e.g. :: | |||||
def get_target_name(self): | |||||
return [name for name, field in self.field_arrays.items() if field.is_target] | |||||
# output "word_seq_origin_len", lengths based on "word_seq" field | |||||
dataset.set_origin_len("word_seq") | |||||
""" | |||||
if origin_field is None: | |||||
self.origin_len = None | |||||
else: | |||||
self.origin_len = (origin_field + "_origin_len", origin_field) \ | |||||
if origin_len_name is None else (origin_len_name, origin_field) | |||||
return self | |||||
def __getattr__(self, item): | |||||
# block infinite recursion for copy, pickle | |||||
if item == '__setstate__': | |||||
raise AttributeError(item) | |||||
try: | |||||
return self.field_arrays.__getitem__(item) | |||||
except KeyError: | |||||
pass | |||||
try: | |||||
reader_cls = _READERS[item] | |||||
def __getattribute__(self, name): | |||||
if name in _READERS: | |||||
# add read_*data() support | # add read_*data() support | ||||
def _read(*args, **kwargs): | def _read(*args, **kwargs): | ||||
data = _READERS[name]().load(*args, **kwargs) | |||||
data = reader_cls().load(*args, **kwargs) | |||||
self.extend(data) | self.extend(data) | ||||
return self | return self | ||||
return _read | return _read | ||||
else: | |||||
return object.__getattribute__(self, name) | |||||
except KeyError: | |||||
raise AttributeError('{} does not exist.'.format(item)) | |||||
@classmethod | @classmethod | ||||
def set_reader(cls, method_name): | def set_reader(cls, method_name): | ||||
"""decorator to add dataloader support | """decorator to add dataloader support | ||||
""" | """ | ||||
assert isinstance(method_name, str) | assert isinstance(method_name, str) | ||||
def wrapper(read_cls): | def wrapper(read_cls): | ||||
_READERS[method_name] = read_cls | _READERS[method_name] = read_cls | ||||
return read_cls | return read_cls | ||||
return wrapper | return wrapper | ||||
def apply(self, func, new_field_name=None): | |||||
"""Apply a function to every instance of the DataSet. | |||||
:param func: a function that takes an instance as input. | |||||
:param str new_field_name: If not None, results of the function will be stored as a new field. | |||||
:return results: returned values of the function over all instances. | |||||
""" | |||||
results = [func(ins) for ins in self] | |||||
if new_field_name is not None: | |||||
if new_field_name in self.field_arrays: | |||||
# overwrite the field, keep same attributes | |||||
old_field = self.field_arrays[new_field_name] | |||||
self.add_field(name=new_field_name, | |||||
fields=results, | |||||
padding_val=old_field.padding_val, | |||||
is_input=old_field.is_input, | |||||
is_target=old_field.is_target) | |||||
else: | |||||
self.add_field(name=new_field_name, fields=results) | |||||
else: | |||||
return results | |||||
def drop(self, func): | |||||
results = [ins for ins in self if not func(ins)] | |||||
for name, old_field in self.field_arrays.items(): | |||||
self.field_arrays[name].content = [ins[name] for ins in results] | |||||
# print(self.field_arrays[name]) | |||||
def split(self, dev_ratio): | |||||
"""Split the dataset into training and development(validation) set. | |||||
:param float dev_ratio: the ratio of test set in all data. | |||||
:return DataSet train_set: the training set | |||||
DataSet dev_set: the development set | |||||
""" | |||||
assert isinstance(dev_ratio, float) | |||||
assert 0 < dev_ratio < 1 | |||||
all_indices = [_ for _ in range(len(self))] | |||||
np.random.shuffle(all_indices) | |||||
split = int(dev_ratio * len(self)) | |||||
dev_indices = all_indices[:split] | |||||
train_indices = all_indices[split:] | |||||
dev_set = DataSet() | |||||
train_set = DataSet() | |||||
for idx in dev_indices: | |||||
dev_set.append(self[idx]) | |||||
for idx in train_indices: | |||||
train_set.append(self[idx]) | |||||
return train_set, dev_set | |||||
@classmethod | |||||
def read_csv(cls, csv_path, headers=None, sep='\t', dropna=True): | |||||
with open(csv_path, 'r') as f: | |||||
start_idx = 0 | |||||
if headers is None: | |||||
headers = f.readline().rstrip('\r\n') | |||||
headers = headers.split(sep) | |||||
start_idx += 1 | |||||
else: | |||||
assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format(type(headers)) | |||||
_dict = {} | |||||
for col in headers: | |||||
_dict[col] = [] | |||||
for line_idx, line in enumerate(f, start_idx): | |||||
contents = line.split(sep) | |||||
if len(contents)!=len(headers): | |||||
if dropna: | |||||
continue | |||||
else: | |||||
#TODO change error type | |||||
raise ValueError("Line {} has {} parts, while header has {} parts."\ | |||||
.format(line_idx, len(contents), len(headers))) | |||||
for header, content in zip(headers, contents): | |||||
_dict[header].append(content) | |||||
return cls(_dict) |
@@ -1,167 +0,0 @@ | |||||
import torch | |||||
class Field(object): | |||||
"""A field defines a data type. | |||||
""" | |||||
def __init__(self, is_target: bool): | |||||
self.is_target = is_target | |||||
def index(self, vocab): | |||||
raise NotImplementedError | |||||
def get_length(self): | |||||
raise NotImplementedError | |||||
def to_tensor(self, padding_length): | |||||
raise NotImplementedError | |||||
def contents(self): | |||||
raise NotImplementedError | |||||
class TextField(Field): | |||||
def __init__(self, text, is_target): | |||||
""" | |||||
:param text: list of strings | |||||
:param is_target: bool | |||||
""" | |||||
super(TextField, self).__init__(is_target) | |||||
self.text = text | |||||
self._index = None | |||||
def index(self, vocab): | |||||
if self._index is None: | |||||
self._index = [vocab[c] for c in self.text] | |||||
else: | |||||
raise RuntimeError("Replicate indexing of this field.") | |||||
return self._index | |||||
def get_length(self): | |||||
"""Fetch the length of the text field. | |||||
:return length: int, the length of the text. | |||||
""" | |||||
return len(self.text) | |||||
def to_tensor(self, padding_length: int): | |||||
"""Convert text field to tensor. | |||||
:param padding_length: int | |||||
:return tensor: torch.LongTensor, of shape [padding_length, ] | |||||
""" | |||||
pads = [] | |||||
if self._index is None: | |||||
raise RuntimeError("Indexing not done before to_tensor in TextField.") | |||||
if padding_length > self.get_length(): | |||||
pads = [0] * (padding_length - self.get_length()) | |||||
return torch.LongTensor(self._index + pads) | |||||
def contents(self): | |||||
return self.text.copy() | |||||
class LabelField(Field): | |||||
"""The Field representing a single label. Can be a string or integer. | |||||
""" | |||||
def __init__(self, label, is_target=True): | |||||
super(LabelField, self).__init__(is_target) | |||||
self.label = label | |||||
self._index = None | |||||
def get_length(self): | |||||
"""Fetch the length of the label field. | |||||
:return length: int, the length of the label, always 1. | |||||
""" | |||||
return 1 | |||||
def index(self, vocab): | |||||
if self._index is None: | |||||
if isinstance(self.label, str): | |||||
self._index = vocab[self.label] | |||||
return self._index | |||||
def to_tensor(self, padding_length): | |||||
if self._index is None: | |||||
if isinstance(self.label, int): | |||||
return torch.tensor(self.label) | |||||
elif isinstance(self.label, str): | |||||
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label)) | |||||
else: | |||||
raise RuntimeError( | |||||
"Not support type for LabelField. Expect str or int, got {}.".format(type(self.label))) | |||||
else: | |||||
return torch.LongTensor([self._index]) | |||||
def contents(self): | |||||
return [self.label] | |||||
class SeqLabelField(Field): | |||||
def __init__(self, label_seq, is_target=True): | |||||
super(SeqLabelField, self).__init__(is_target) | |||||
self.label_seq = label_seq | |||||
self._index = None | |||||
def get_length(self): | |||||
return len(self.label_seq) | |||||
def index(self, vocab): | |||||
if self._index is None: | |||||
self._index = [vocab[c] for c in self.label_seq] | |||||
return self._index | |||||
def to_tensor(self, padding_length): | |||||
pads = [0] * (padding_length - self.get_length()) | |||||
if self._index is None: | |||||
if self.get_length() == 0: | |||||
return torch.LongTensor(pads) | |||||
elif isinstance(self.label_seq[0], int): | |||||
return torch.LongTensor(self.label_seq + pads) | |||||
elif isinstance(self.label_seq[0], str): | |||||
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label)) | |||||
else: | |||||
raise RuntimeError( | |||||
"Not support type for SeqLabelField. Expect str or int, got {}.".format(type(self.label))) | |||||
else: | |||||
return torch.LongTensor(self._index + pads) | |||||
def contents(self): | |||||
return self.label_seq.copy() | |||||
class CharTextField(Field): | |||||
def __init__(self, text, max_word_len, is_target=False): | |||||
super(CharTextField, self).__init__(is_target) | |||||
self.text = text | |||||
self.max_word_len = max_word_len | |||||
self._index = [] | |||||
def get_length(self): | |||||
return len(self.text) | |||||
def contents(self): | |||||
return self.text.copy() | |||||
def index(self, char_vocab): | |||||
if len(self._index) == 0: | |||||
for word in self.text: | |||||
char_index = [char_vocab[ch] for ch in word] | |||||
if self.max_word_len >= len(char_index): | |||||
char_index += [0] * (self.max_word_len - len(char_index)) | |||||
else: | |||||
self._index.clear() | |||||
raise RuntimeError("Word {} has more than {} characters. ".format(word, self.max_word_len)) | |||||
self._index.append(char_index) | |||||
return self._index | |||||
def to_tensor(self, padding_length): | |||||
""" | |||||
:param padding_length: int, the padding length of the word sequence. | |||||
:return : tensor of shape (padding_length, max_word_len) | |||||
""" | |||||
pads = [[0] * self.max_word_len] * (padding_length - self.get_length()) | |||||
return torch.LongTensor(self._index + pads) |
@@ -0,0 +1,72 @@ | |||||
import numpy as np | |||||
class FieldArray(object): | |||||
"""FieldArray is the collection of Instances of the same Field. | |||||
It is the basic element of DataSet class. | |||||
""" | |||||
def __init__(self, name, content, padding_val=0, is_target=False, is_input=False): | |||||
""" | |||||
:param str name: the name of the FieldArray | |||||
:param list content: a list of int, float, or other objects. | |||||
:param int padding_val: the integer for padding. Default: 0. | |||||
:param bool is_target: If True, this FieldArray is used to compute loss. | |||||
:param bool is_input: If True, this FieldArray is used to the model input. | |||||
""" | |||||
self.name = name | |||||
self.content = content | |||||
self.padding_val = padding_val | |||||
self.is_target = is_target | |||||
self.is_input = is_input | |||||
# TODO: auto detect dtype | |||||
self.dtype = None | |||||
def __repr__(self): | |||||
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | |||||
def append(self, val): | |||||
self.content.append(val) | |||||
def __getitem__(self, name): | |||||
return self.get(name) | |||||
def __setitem__(self, name, val): | |||||
assert isinstance(name, int) | |||||
self.content[name] = val | |||||
def get(self, indices): | |||||
"""Fetch instances based on indices. | |||||
:param indices: an int, or a list of int. | |||||
:return: | |||||
""" | |||||
if isinstance(indices, int): | |||||
return self.content[indices] | |||||
assert self.is_input is True or self.is_target is True | |||||
batch_size = len(indices) | |||||
# TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 | |||||
if not isiterable(self.content[0]): | |||||
if self.dtype is None: | |||||
self.dtype = np.int64 if isinstance(self.content[0], int) else np.double | |||||
array = np.array([self.content[i] for i in indices], dtype=self.dtype) | |||||
else: | |||||
if self.dtype is None: | |||||
self.dtype = np.int64 | |||||
max_len = max([len(self.content[i]) for i in indices]) | |||||
array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype) | |||||
for i, idx in enumerate(indices): | |||||
array[i][:len(self.content[idx])] = self.content[idx] | |||||
return array | |||||
def __len__(self): | |||||
return len(self.content) | |||||
def isiterable(content): | |||||
try: | |||||
_ = (e for e in content) | |||||
except TypeError: | |||||
return False | |||||
return True |
@@ -1,33 +1,27 @@ | |||||
import torch | |||||
class Instance(object): | class Instance(object): | ||||
"""An instance which consists of Fields is an example in the DataSet. | |||||
"""An Instance is an example of data. It is the collection of Fields. | |||||
:: | |||||
Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) | |||||
""" | """ | ||||
def __init__(self, **fields): | def __init__(self, **fields): | ||||
""" | |||||
:param fields: a dict of (str: list). | |||||
""" | |||||
self.fields = fields | self.fields = fields | ||||
self.has_index = False | |||||
self.indexes = {} | |||||
def add_field(self, field_name, field): | def add_field(self, field_name, field): | ||||
self.fields[field_name] = field | |||||
return self | |||||
def rename_field(self, old_name, new_name): | |||||
if old_name in self.fields: | |||||
self.fields[new_name] = self.fields.pop(old_name) | |||||
if old_name in self.indexes: | |||||
self.indexes[new_name] = self.indexes.pop(old_name) | |||||
else: | |||||
raise KeyError("error, no such field: {}".format(old_name)) | |||||
return self | |||||
"""Add a new field to the instance. | |||||
def set_target(self, **fields): | |||||
for name, val in fields.items(): | |||||
if name in self.fields: | |||||
self.fields[name].is_target = val | |||||
return self | |||||
:param field_name: str, the name of the field. | |||||
:param field: | |||||
""" | |||||
self.fields[field_name] = field | |||||
def __getitem__(self, name): | def __getitem__(self, name): | ||||
if name in self.fields: | if name in self.fields: | ||||
@@ -35,50 +29,8 @@ class Instance(object): | |||||
else: | else: | ||||
raise KeyError("{} not found".format(name)) | raise KeyError("{} not found".format(name)) | ||||
def get_length(self): | |||||
"""Fetch the length of all fields in the instance. | |||||
def __setitem__(self, name, field): | |||||
return self.add_field(name, field) | |||||
:return length: dict of (str: int), which means (field name: field length). | |||||
""" | |||||
length = {name: field.get_length() for name, field in self.fields.items()} | |||||
return length | |||||
def index_field(self, field_name, vocab): | |||||
"""use `vocab` to index certain field | |||||
""" | |||||
self.indexes[field_name] = self.fields[field_name].index(vocab) | |||||
return self | |||||
def index_all(self, vocab): | |||||
"""use `vocab` to index all fields | |||||
""" | |||||
if self.has_index: | |||||
print("error") | |||||
return self.indexes | |||||
indexes = {name: field.index(vocab) for name, field in self.fields.items()} | |||||
self.indexes = indexes | |||||
return indexes | |||||
def to_tensor(self, padding_length: dict, origin_len=None): | |||||
"""Convert instance to tensor. | |||||
:param padding_length: dict of (str: int), which means (field name: padding_length of this field) | |||||
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||||
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||||
If is_target is False for all fields, tensor_y would be an empty dict. | |||||
""" | |||||
tensor_x = {} | |||||
tensor_y = {} | |||||
for name, field in self.fields.items(): | |||||
if field.is_target is True: | |||||
tensor_y[name] = field.to_tensor(padding_length[name]) | |||||
elif field.is_target is False: | |||||
tensor_x[name] = field.to_tensor(padding_length[name]) | |||||
else: | |||||
# is_target is None | |||||
continue | |||||
if origin_len is not None: | |||||
name, field_name = origin_len | |||||
tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()]) | |||||
return tensor_x, tensor_y | |||||
def __repr__(self): | |||||
return self.fields.__repr__() |
@@ -35,27 +35,72 @@ class SeqLabelEvaluator(Evaluator): | |||||
def __init__(self): | def __init__(self): | ||||
super(SeqLabelEvaluator, self).__init__() | super(SeqLabelEvaluator, self).__init__() | ||||
def __call__(self, predict, truth): | |||||
def __call__(self, predict, truth, **_): | |||||
""" | """ | ||||
:param predict: list of List, the network outputs from all batches. | :param predict: list of List, the network outputs from all batches. | ||||
:param truth: list of dict, the ground truths from all batch_y. | :param truth: list of dict, the ground truths from all batch_y. | ||||
:return accuracy: | :return accuracy: | ||||
""" | """ | ||||
truth = [item["truth"] for item in truth] | |||||
total_correct, total_count= 0., 0. | |||||
total_correct, total_count = 0., 0. | |||||
for x, y in zip(predict, truth): | for x, y in zip(predict, truth): | ||||
x = torch.Tensor(x) | |||||
x = torch.tensor(x) | |||||
y = y.to(x) # make sure they are in the same device | y = y.to(x) # make sure they are in the same device | ||||
mask = x.ge(1).float() | |||||
# correct = torch.sum(x * mask.float() == (y * mask.long()).float()) | |||||
correct = torch.sum(x * mask == y * mask) | |||||
correct -= torch.sum(x.le(0)) | |||||
mask = (y > 0) | |||||
correct = torch.sum(((x == y) * mask).long()) | |||||
total_correct += float(correct) | total_correct += float(correct) | ||||
total_count += float(torch.sum(mask)) | |||||
total_count += float(torch.sum(mask.long())) | |||||
accuracy = total_correct / total_count | accuracy = total_correct / total_count | ||||
return {"accuracy": float(accuracy)} | return {"accuracy": float(accuracy)} | ||||
class SeqLabelEvaluator2(Evaluator): | |||||
# 上面的evaluator应该是错误的 | |||||
def __init__(self, seq_lens_field_name='word_seq_origin_len'): | |||||
super(SeqLabelEvaluator2, self).__init__() | |||||
self.end_tagidx_set = set() | |||||
self.seq_lens_field_name = seq_lens_field_name | |||||
def __call__(self, predict, truth, **_): | |||||
""" | |||||
:param predict: list of batch, the network outputs from all batches. | |||||
:param truth: list of dict, the ground truths from all batch_y. | |||||
:return accuracy: | |||||
""" | |||||
seq_lens = _[self.seq_lens_field_name] | |||||
corr_count = 0 | |||||
pred_count = 0 | |||||
truth_count = 0 | |||||
for x, y, seq_len in zip(predict, truth, seq_lens): | |||||
x = x.cpu().numpy() | |||||
y = y.cpu().numpy() | |||||
for idx, s_l in enumerate(seq_len): | |||||
x_ = x[idx] | |||||
y_ = y[idx] | |||||
x_ = x_[:s_l] | |||||
y_ = y_[:s_l] | |||||
flag = True | |||||
start = 0 | |||||
for idx_i, (x_i, y_i) in enumerate(zip(x_, y_)): | |||||
if x_i in self.end_tagidx_set: | |||||
truth_count += 1 | |||||
for j in range(start, idx_i + 1): | |||||
if y_[j]!=x_[j]: | |||||
flag = False | |||||
break | |||||
if flag: | |||||
corr_count += 1 | |||||
flag = True | |||||
start = idx_i + 1 | |||||
if y_i in self.end_tagidx_set: | |||||
pred_count += 1 | |||||
P = corr_count / (float(pred_count) + 1e-6) | |||||
R = corr_count / (float(truth_count) + 1e-6) | |||||
F = 2 * P * R / (P + R + 1e-6) | |||||
return {"P": P, 'R':R, 'F': F} | |||||
class SNLIEvaluator(Evaluator): | class SNLIEvaluator(Evaluator): | ||||
def __init__(self): | def __init__(self): | ||||
@@ -275,8 +320,3 @@ def pred_topk(y_prob, k=1): | |||||
(1, k)) | (1, k)) | ||||
y_prob_topk = y_prob[x_axis_index, y_pred_topk] | y_prob_topk = y_prob[x_axis_index, y_pred_topk] | ||||
return y_pred_topk, y_prob_topk | return y_pred_topk, y_prob_topk | ||||
if __name__ == '__main__': | |||||
y = np.array([1, 0, 1, 0, 1, 1]) | |||||
print(_label_types(y)) |
@@ -2,9 +2,7 @@ import numpy as np | |||||
import torch | import torch | ||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.preprocess import load_pickle | |||||
from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq2tag_dataset, convert_seq_dataset | |||||
class Predictor(object): | class Predictor(object): | ||||
@@ -16,19 +14,9 @@ class Predictor(object): | |||||
Currently, Predictor does not support GPU. | Currently, Predictor does not support GPU. | ||||
""" | """ | ||||
def __init__(self, pickle_path, post_processor): | |||||
""" | |||||
:param pickle_path: str, the path to the pickle files. | |||||
:param post_processor: a function or callable object, that takes list of batch outputs as input | |||||
""" | |||||
def __init__(self): | |||||
self.batch_size = 1 | self.batch_size = 1 | ||||
self.batch_output = [] | self.batch_output = [] | ||||
self.pickle_path = pickle_path | |||||
self._post_processor = post_processor | |||||
self.label_vocab = load_pickle(self.pickle_path, "label2id.pkl") | |||||
self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl") | |||||
def predict(self, network, data): | def predict(self, network, data): | ||||
"""Perform inference using the trained model. | """Perform inference using the trained model. | ||||
@@ -37,9 +25,6 @@ class Predictor(object): | |||||
:param data: a DataSet object. | :param data: a DataSet object. | ||||
:return: list of list of strings, [num_examples, tag_seq_length] | :return: list of list of strings, [num_examples, tag_seq_length] | ||||
""" | """ | ||||
# transform strings into DataSet object | |||||
# data = self.prepare_input(data) | |||||
# turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
self.mode(network, test=True) | self.mode(network, test=True) | ||||
batch_output = [] | batch_output = [] | ||||
@@ -51,7 +36,7 @@ class Predictor(object): | |||||
prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
batch_output.append(prediction) | batch_output.append(prediction) | ||||
return self._post_processor(batch_output, self.label_vocab) | |||||
return batch_output | |||||
def mode(self, network, test=True): | def mode(self, network, test=True): | ||||
if test: | if test: | ||||
@@ -64,38 +49,6 @@ class Predictor(object): | |||||
y = network(**x) | y = network(**x) | ||||
return y | return y | ||||
def prepare_input(self, data): | |||||
"""Transform two-level list of strings into an DataSet object. | |||||
In the training pipeline, this is done by Preprocessor. But in inference time, we do not call Preprocessor. | |||||
:param data: list of list of strings. | |||||
:: | |||||
[ | |||||
[word_11, word_12, ...], | |||||
[word_21, word_22, ...], | |||||
... | |||||
] | |||||
:return data_set: a DataSet instance. | |||||
""" | |||||
assert isinstance(data, list) | |||||
data = convert_seq_dataset(data) | |||||
data.index_field("word_seq", self.word_vocab) | |||||
class SeqLabelInfer(Predictor): | |||||
def __init__(self, pickle_path): | |||||
print( | |||||
"[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor directly.") | |||||
super(SeqLabelInfer, self).__init__(pickle_path, seq_label_post_processor) | |||||
class ClassificationInfer(Predictor): | |||||
def __init__(self, pickle_path): | |||||
print( | |||||
"[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor directly.") | |||||
super(ClassificationInfer, self).__init__(pickle_path, text_classify_post_processor) | |||||
def seq_label_post_processor(batch_outputs, label_vocab): | def seq_label_post_processor(batch_outputs, label_vocab): | ||||
results = [] | results = [] | ||||
@@ -1,48 +0,0 @@ | |||||
import _pickle | |||||
import os | |||||
# the first vocab in dict with the index = 5 | |||||
def save_pickle(obj, pickle_path, file_name): | |||||
"""Save an object into a pickle file. | |||||
:param obj: an object | |||||
:param pickle_path: str, the directory where the pickle file is to be saved | |||||
:param file_name: str, the name of the pickle file. In general, it should be ended by "pkl". | |||||
""" | |||||
if not os.path.exists(pickle_path): | |||||
os.mkdir(pickle_path) | |||||
print("make dir {} before saving pickle file".format(pickle_path)) | |||||
with open(os.path.join(pickle_path, file_name), "wb") as f: | |||||
_pickle.dump(obj, f) | |||||
print("{} saved in {}".format(file_name, pickle_path)) | |||||
def load_pickle(pickle_path, file_name): | |||||
"""Load an object from a given pickle file. | |||||
:param pickle_path: str, the directory where the pickle file is. | |||||
:param file_name: str, the name of the pickle file. | |||||
:return obj: an object stored in the pickle | |||||
""" | |||||
with open(os.path.join(pickle_path, file_name), "rb") as f: | |||||
obj = _pickle.load(f) | |||||
print("{} loaded from {}".format(file_name, pickle_path)) | |||||
return obj | |||||
def pickle_exist(pickle_path, pickle_name): | |||||
"""Check if a given pickle file exists in the directory. | |||||
:param pickle_path: the directory of target pickle file | |||||
:param pickle_name: the filename of target pickle file | |||||
:return: True if file exists else False | |||||
""" | |||||
if not os.path.exists(pickle_path): | |||||
os.makedirs(pickle_path) | |||||
file_name = os.path.join(pickle_path, pickle_name) | |||||
if os.path.exists(file_name): | |||||
return True | |||||
else: | |||||
return False |
@@ -1,3 +1,5 @@ | |||||
from itertools import chain | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
@@ -44,6 +46,48 @@ class RandomSampler(BaseSampler): | |||||
return list(np.random.permutation(len(data_set))) | return list(np.random.permutation(len(data_set))) | ||||
class BucketSampler(BaseSampler): | |||||
def __init__(self, num_buckets=10, batch_size=32, seq_lens_field_name='seq_lens'): | |||||
self.num_buckets = num_buckets | |||||
self.batch_size = batch_size | |||||
self.seq_lens_field_name = seq_lens_field_name | |||||
def __call__(self, data_set): | |||||
seq_lens = data_set[self.seq_lens_field_name].content | |||||
total_sample_num = len(seq_lens) | |||||
bucket_indexes = [] | |||||
num_sample_per_bucket = total_sample_num // self.num_buckets | |||||
for i in range(self.num_buckets): | |||||
bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)]) | |||||
bucket_indexes[-1][1] = total_sample_num | |||||
sorted_seq_lens = list(sorted([(idx, seq_len) for | |||||
idx, seq_len in zip(range(total_sample_num), seq_lens)], | |||||
key=lambda x: x[1])) | |||||
batchs = [] | |||||
left_init_indexes = [] | |||||
for b_idx in range(self.num_buckets): | |||||
start_idx = bucket_indexes[b_idx][0] | |||||
end_idx = bucket_indexes[b_idx][1] | |||||
sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx] | |||||
left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens]) | |||||
num_batch_per_bucket = len(left_init_indexes) // self.batch_size | |||||
np.random.shuffle(left_init_indexes) | |||||
for i in range(num_batch_per_bucket): | |||||
batchs.append(left_init_indexes[i * self.batch_size:(i + 1) * self.batch_size]) | |||||
left_init_indexes = left_init_indexes[num_batch_per_bucket * self.batch_size:] | |||||
if (left_init_indexes) != 0: | |||||
batchs.append(left_init_indexes) | |||||
np.random.shuffle(batchs) | |||||
return list(chain(*batchs)) | |||||
def simple_sort_bucketing(lengths): | def simple_sort_bucketing(lengths): | ||||
""" | """ | ||||
@@ -63,6 +107,7 @@ def simple_sort_bucketing(lengths): | |||||
# TODO: need to return buckets | # TODO: need to return buckets | ||||
return [idx for idx, _ in sorted_lengths] | return [idx for idx, _ in sorted_lengths] | ||||
def k_means_1d(x, k, max_iter=100): | def k_means_1d(x, k, max_iter=100): | ||||
"""Perform k-means on 1-D data. | """Perform k-means on 1-D data. | ||||
@@ -117,4 +162,3 @@ def k_means_bucketing(lengths, buckets): | |||||
if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]: | if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]: | ||||
bucket_data[bucket_id].append(idx) | bucket_data[bucket_id].append(idx) | ||||
return bucket_data | return bucket_data | ||||
@@ -1,89 +1,58 @@ | |||||
import itertools | |||||
from collections import defaultdict | |||||
import torch | import torch | ||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.metrics import Evaluator | |||||
from fastNLP.core.sampler import RandomSampler | from fastNLP.core.sampler import RandomSampler | ||||
from fastNLP.saver.logger import create_logger | |||||
logger = create_logger(__name__, "./train_test.log") | |||||
from fastNLP.core.utils import _build_args | |||||
class Tester(object): | class Tester(object): | ||||
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | ||||
def __init__(self, **kwargs): | |||||
""" | |||||
:param kwargs: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | |||||
""" | |||||
def __init__(self, data, model, batch_size=16, use_cuda=False): | |||||
super(Tester, self).__init__() | super(Tester, self).__init__() | ||||
""" | |||||
"default_args" provides default value for important settings. | |||||
The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
"kwargs" must have the same type as "default_args" on corresponding keys. | |||||
Otherwise, error will raise. | |||||
""" | |||||
default_args = {"batch_size": 8, | |||||
"use_cuda": False, | |||||
"pickle_path": "./save/", | |||||
"model_name": "dev_best_model.pkl", | |||||
"evaluator": Evaluator() | |||||
} | |||||
""" | |||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
This is used to warn users of essential settings in the training. | |||||
Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | |||||
""" | |||||
required_args = {} | |||||
for req_key in required_args: | |||||
if req_key not in kwargs: | |||||
logger.error("Tester lacks argument {}".format(req_key)) | |||||
raise ValueError("Tester lacks argument {}".format(req_key)) | |||||
for key in default_args: | |||||
if key in kwargs: | |||||
if isinstance(kwargs[key], type(default_args[key])): | |||||
default_args[key] = kwargs[key] | |||||
else: | |||||
msg = "Argument %s type mismatch: expected %s while get %s" % ( | |||||
key, type(default_args[key]), type(kwargs[key])) | |||||
logger.error(msg) | |||||
raise ValueError(msg) | |||||
else: | |||||
# Tester doesn't care about extra arguments | |||||
pass | |||||
print(default_args) | |||||
self.batch_size = default_args["batch_size"] | |||||
self.pickle_path = default_args["pickle_path"] | |||||
self.use_cuda = default_args["use_cuda"] | |||||
self._evaluator = default_args["evaluator"] | |||||
self._model = None | |||||
self.eval_history = [] # evaluation results of all batches | |||||
def test(self, network, dev_data): | |||||
self.use_cuda = use_cuda | |||||
self.data = data | |||||
self.batch_size = batch_size | |||||
if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
self._model = network.cuda() | |||||
self._model = model.cuda() | |||||
else: | |||||
self._model = model | |||||
if hasattr(self._model, 'predict'): | |||||
assert callable(self._model.predict) | |||||
self._predict_func = self._model.predict | |||||
else: | else: | ||||
self._model = network | |||||
self._predict_func = self._model | |||||
assert hasattr(model, 'evaluate') | |||||
self._evaluator = model.evaluate | |||||
self.eval_history = [] # evaluation results of all batches | |||||
def test(self): | |||||
# turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
network = self._model | |||||
self.mode(network, is_test=True) | self.mode(network, is_test=True) | ||||
self.eval_history.clear() | self.eval_history.clear() | ||||
output_list = [] | |||||
truth_list = [] | |||||
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) | |||||
output, truths = defaultdict(list), defaultdict(list) | |||||
data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False) | |||||
for batch_x, batch_y in data_iterator: | |||||
with torch.no_grad(): | |||||
with torch.no_grad(): | |||||
for batch_x, batch_y in data_iterator: | |||||
prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
output_list.append(prediction) | |||||
truth_list.append(batch_y) | |||||
eval_results = self.evaluate(output_list, truth_list) | |||||
assert isinstance(prediction, dict) | |||||
for k, v in prediction.items(): | |||||
output[k].append(v) | |||||
for k, v in batch_y.items(): | |||||
truths[k].append(v) | |||||
for k, v in output.items(): | |||||
output[k] = itertools.chain(*v) | |||||
for k, v in truths.items(): | |||||
truths[k] = itertools.chain(*v) | |||||
args = _build_args(self._evaluator, **output, **truths) | |||||
eval_results = self._evaluator(**args) | |||||
print("[tester] {}".format(self.print_eval_results(eval_results))) | print("[tester] {}".format(self.print_eval_results(eval_results))) | ||||
logger.info("[tester] {}".format(self.print_eval_results(eval_results))) | |||||
self.mode(network, is_test=False) | |||||
return eval_results | |||||
def mode(self, model, is_test=False): | def mode(self, model, is_test=False): | ||||
"""Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
@@ -99,18 +68,10 @@ class Tester(object): | |||||
def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
"""A forward pass of the model. """ | """A forward pass of the model. """ | ||||
y = network(**x) | |||||
x = _build_args(network.forward, **x) | |||||
y = self._predict_func(**x) | |||||
return y | return y | ||||
def evaluate(self, predict, truth): | |||||
"""Compute evaluation metrics. | |||||
:param predict: list of Tensor | |||||
:param truth: list of dict | |||||
:return eval_results: can be anything. It will be stored in self.eval_history | |||||
""" | |||||
return self._evaluator(predict, truth) | |||||
def print_eval_results(self, results): | def print_eval_results(self, results): | ||||
"""Override this method to support more print formats. | """Override this method to support more print formats. | ||||
@@ -118,24 +79,3 @@ class Tester(object): | |||||
""" | """ | ||||
return ", ".join([str(key) + "=" + str(value) for key, value in results.items()]) | return ", ".join([str(key) + "=" + str(value) for key, value in results.items()]) | ||||
class SeqLabelTester(Tester): | |||||
def __init__(self, **test_args): | |||||
print( | |||||
"[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester directly.") | |||||
super(SeqLabelTester, self).__init__(**test_args) | |||||
class ClassificationTester(Tester): | |||||
def __init__(self, **test_args): | |||||
print( | |||||
"[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.") | |||||
super(ClassificationTester, self).__init__(**test_args) | |||||
class SNLITester(Tester): | |||||
def __init__(self, **test_args): | |||||
print( | |||||
"[FastNLP Warning] SNLITester will be deprecated. Please use Tester directly.") | |||||
super(SNLITester, self).__init__(**test_args) |
@@ -1,147 +1,117 @@ | |||||
import os | |||||
import time | import time | ||||
from datetime import timedelta | from datetime import timedelta | ||||
from datetime import datetime | |||||
import warnings | |||||
from collections import defaultdict | |||||
import os | |||||
import itertools | |||||
import shutil | |||||
import torch | |||||
from tensorboardX import SummaryWriter | from tensorboardX import SummaryWriter | ||||
import torch | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.loss import Loss | from fastNLP.core.loss import Loss | ||||
from fastNLP.core.metrics import Evaluator | from fastNLP.core.metrics import Evaluator | ||||
from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
from fastNLP.core.sampler import RandomSampler | from fastNLP.core.sampler import RandomSampler | ||||
from fastNLP.core.tester import SeqLabelTester, ClassificationTester, SNLITester | |||||
from fastNLP.saver.logger import create_logger | |||||
from fastNLP.saver.model_saver import ModelSaver | |||||
logger = create_logger(__name__, "./train_test.log") | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.tester import Tester | |||||
from fastNLP.core.utils import _check_arg_dict_list | |||||
from fastNLP.core.utils import _build_args | |||||
from fastNLP.core.utils import _syn_model_data | |||||
from fastNLP.core.utils import get_func_signature | |||||
class Trainer(object): | class Trainer(object): | ||||
"""Operations of training a model, including data loading, gradient descent, and validation. | |||||
"""Main Training Loop | |||||
""" | """ | ||||
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1, | |||||
dev_data=None, use_cuda=False, save_path="./save", | |||||
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True, | |||||
**kwargs): | |||||
super(Trainer, self).__init__() | |||||
def __init__(self, **kwargs): | |||||
""" | |||||
:param kwargs: dict of (key, value), or dict-like object. key is str. | |||||
self.train_data = train_data | |||||
self.dev_data = dev_data # If None, No validation. | |||||
self.model = model | |||||
self.n_epochs = int(n_epochs) | |||||
self.batch_size = int(batch_size) | |||||
self.use_cuda = bool(use_cuda) | |||||
self.save_path = save_path | |||||
self.print_every = int(print_every) | |||||
self.validate_every = int(validate_every) | |||||
self._best_accuracy = 0 | |||||
if need_check_code: | |||||
_check_code(dataset=train_data, model=model, dev_data=dev_data) | |||||
model_name = model.__class__.__name__ | |||||
assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name) | |||||
self.loss_func = self.model.get_loss | |||||
if isinstance(optimizer, torch.optim.Optimizer): | |||||
self.optimizer = optimizer | |||||
else: | |||||
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) | |||||
The base trainer requires the following keys: | |||||
- epochs: int, the number of epochs in training | |||||
- validate: bool, whether or not to validate on dev set | |||||
- batch_size: int | |||||
- pickle_path: str, the path to pickle files for pre-processing | |||||
""" | |||||
super(Trainer, self).__init__() | |||||
assert hasattr(self.model, 'evaluate'), "model {} has to have a 'evaluate' function.".format(model_name) | |||||
self.evaluator = self.model.evaluate | |||||
""" | |||||
"default_args" provides default value for important settings. | |||||
The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
"kwargs" must have the same type as "default_args" on corresponding keys. | |||||
Otherwise, error will raise. | |||||
""" | |||||
default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | |||||
"save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, | |||||
"loss": Loss(None), # used to pass type check | |||||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | |||||
"evaluator": Evaluator() | |||||
} | |||||
""" | |||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
This is used to warn users of essential settings in the training. | |||||
Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | |||||
""" | |||||
required_args = {} | |||||
if self.dev_data is not None: | |||||
self.tester = Tester(model=self.model, | |||||
data=self.dev_data, | |||||
batch_size=self.batch_size, | |||||
use_cuda=self.use_cuda) | |||||
for req_key in required_args: | |||||
if req_key not in kwargs: | |||||
logger.error("Trainer lacks argument {}".format(req_key)) | |||||
raise ValueError("Trainer lacks argument {}".format(req_key)) | |||||
for k, v in kwargs.items(): | |||||
setattr(self, k, v) | |||||
for key in default_args: | |||||
if key in kwargs: | |||||
if isinstance(kwargs[key], type(default_args[key])): | |||||
default_args[key] = kwargs[key] | |||||
else: | |||||
msg = "Argument %s type mismatch: expected %s while get %s" % ( | |||||
key, type(default_args[key]), type(kwargs[key])) | |||||
logger.error(msg) | |||||
raise ValueError(msg) | |||||
else: | |||||
# Trainer doesn't care about extra arguments | |||||
pass | |||||
print(default_args) | |||||
self.n_epochs = default_args["epochs"] | |||||
self.batch_size = default_args["batch_size"] | |||||
self.pickle_path = default_args["pickle_path"] | |||||
self.validate = default_args["validate"] | |||||
self.save_best_dev = default_args["save_best_dev"] | |||||
self.use_cuda = default_args["use_cuda"] | |||||
self.model_name = default_args["model_name"] | |||||
self.print_every_step = default_args["print_every_step"] | |||||
self._model = None | |||||
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None | |||||
self._optimizer = None | |||||
self._optimizer_proto = default_args["optimizer"] | |||||
self._evaluator = default_args["evaluator"] | |||||
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') | |||||
self._graph_summaried = False | |||||
self._best_accuracy = 0.0 | |||||
def train(self, network, train_data, dev_data=None): | |||||
"""General Training Procedure | |||||
:param network: a model | |||||
:param train_data: a DataSet instance, the training data | |||||
:param dev_data: a DataSet instance, the validation data (optional) | |||||
self.step = 0 | |||||
self.start_time = None # start timestamp | |||||
# print(self.__dict__) | |||||
def train(self): | |||||
"""Start Training. | |||||
:return: | |||||
""" | """ | ||||
# transfer model to gpu if available | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
self._model = network.cuda() | |||||
# self._model is used to access model-specific loss | |||||
else: | |||||
self._model = network | |||||
# define Tester over dev data | |||||
if self.validate: | |||||
default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path, | |||||
"use_cuda": self.use_cuda, "evaluator": self._evaluator} | |||||
validator = self._create_validator(default_valid_args) | |||||
logger.info("validator defined as {}".format(str(validator))) | |||||
# optimizer and loss | |||||
self.define_optimizer() | |||||
logger.info("optimizer defined as {}".format(str(self._optimizer))) | |||||
self.define_loss() | |||||
logger.info("loss function defined as {}".format(str(self._loss_func))) | |||||
# main training procedure | |||||
start = time.time() | |||||
logger.info("training epochs started") | |||||
for epoch in range(1, self.n_epochs + 1): | |||||
logger.info("training epoch {}".format(epoch)) | |||||
# turn on network training mode | |||||
self.mode(network, is_test=False) | |||||
# prepare mini-batch iterator | |||||
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(), | |||||
use_cuda=self.use_cuda) | |||||
logger.info("prepared data iterator") | |||||
# one forward and backward pass | |||||
self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch) | |||||
# validation | |||||
if self.validate: | |||||
if dev_data is None: | |||||
raise RuntimeError( | |||||
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.") | |||||
logger.info("validation started") | |||||
validator.test(network, dev_data) | |||||
def _train_step(self, data_iterator, network, **kwargs): | |||||
try: | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
self.model = self.model.cuda() | |||||
self.mode(self.model, is_test=False) | |||||
start = time.time() | |||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||||
print("training epochs started " + self.start_time) | |||||
if self.save_path is None: | |||||
class psudoSW: | |||||
def __getattr__(self, item): | |||||
def pass_func(*args, **kwargs): | |||||
pass | |||||
return pass_func | |||||
self._summary_writer = psudoSW() | |||||
else: | |||||
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) | |||||
self._summary_writer = SummaryWriter(path) | |||||
epoch = 1 | |||||
while epoch <= self.n_epochs: | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False) | |||||
self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) | |||||
# validate_every override validation at end of epochs | |||||
if self.dev_data and self.validate_every <= 0: | |||||
self.do_validation() | |||||
epoch += 1 | |||||
finally: | |||||
self._summary_writer.close() | |||||
del self._summary_writer | |||||
def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs): | |||||
"""Training process in one epoch. | """Training process in one epoch. | ||||
kwargs should contain: | kwargs should contain: | ||||
@@ -149,24 +119,36 @@ class Trainer(object): | |||||
- start: time.time(), the starting time of this step. | - start: time.time(), the starting time of this step. | ||||
- epoch: int, | - epoch: int, | ||||
""" | """ | ||||
step = 0 | |||||
for batch_x, batch_y in data_iterator: | for batch_x, batch_y in data_iterator: | ||||
prediction = self.data_forward(network, batch_x) | |||||
prediction = self.data_forward(model, batch_x) | |||||
loss = self.get_loss(prediction, batch_y) | loss = self.get_loss(prediction, batch_y) | ||||
self.grad_backward(loss) | self.grad_backward(loss) | ||||
self.update() | self.update() | ||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=step) | |||||
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: | |||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | |||||
for name, param in self.model.named_parameters(): | |||||
if param.requires_grad: | |||||
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) | |||||
if self.print_every > 0 and self.step % self.print_every == 0: | |||||
end = time.time() | end = time.time() | ||||
diff = timedelta(seconds=round(end - kwargs["start"])) | |||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||||
kwargs["epoch"], step, loss.data, diff) | |||||
diff = timedelta(seconds=round(end - start)) | |||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||||
epoch, self.step, loss.data, diff) | |||||
print(print_output) | print(print_output) | ||||
logger.info(print_output) | |||||
step += 1 | |||||
if self.validate_every > 0 and self.step % self.validate_every == 0: | |||||
self.do_validation() | |||||
self.step += 1 | |||||
def do_validation(self): | |||||
res = self.tester.test() | |||||
for name, num in res.items(): | |||||
self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step) | |||||
if self.save_path is not None and self.best_eval_result(res): | |||||
self.save_model(self.model, 'best_model_' + self.start_time) | |||||
def mode(self, model, is_test=False): | def mode(self, model, is_test=False): | ||||
"""Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
@@ -180,24 +162,15 @@ class Trainer(object): | |||||
else: | else: | ||||
model.train() | model.train() | ||||
def define_optimizer(self): | |||||
"""Define framework-specific optimizer specified by the models. | |||||
""" | |||||
self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters()) | |||||
def update(self): | def update(self): | ||||
"""Perform weight update on a model. | """Perform weight update on a model. | ||||
For PyTorch, just call optimizer to update. | |||||
""" | """ | ||||
self._optimizer.step() | |||||
self.optimizer.step() | |||||
def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
x = _build_args(network.forward, **x) | |||||
y = network(**x) | y = network(**x) | ||||
if not self._graph_summaried: | |||||
# self._summary_writer.add_graph(network, x, verbose=False) | |||||
self._graph_summaried = True | |||||
return y | return y | ||||
def grad_backward(self, loss): | def grad_backward(self, loss): | ||||
@@ -207,7 +180,7 @@ class Trainer(object): | |||||
For PyTorch, just do "loss.backward()" | For PyTorch, just do "loss.backward()" | ||||
""" | """ | ||||
self._model.zero_grad() | |||||
self.model.zero_grad() | |||||
loss.backward() | loss.backward() | ||||
def get_loss(self, predict, truth): | def get_loss(self, predict, truth): | ||||
@@ -217,90 +190,216 @@ class Trainer(object): | |||||
:param truth: ground truth label vector | :param truth: ground truth label vector | ||||
:return: a scalar | :return: a scalar | ||||
""" | """ | ||||
if len(truth) > 1: | |||||
raise NotImplementedError("Not ready to handle multi-labels.") | |||||
truth = list(truth.values())[0] if len(truth) > 0 else None | |||||
return self._loss_func(predict, truth) | |||||
def define_loss(self): | |||||
"""Define a loss for the trainer. | |||||
If the model defines a loss, use model's loss. | |||||
Otherwise, Trainer must has a loss argument, use it as loss. | |||||
These two losses cannot be defined at the same time. | |||||
Trainer does not handle loss definition or choose default losses. | |||||
""" | |||||
# if hasattr(self._model, "loss") and self._loss_func is not None: | |||||
# raise ValueError("Both the model and Trainer define loss. Please take out your loss.") | |||||
if hasattr(self._model, "loss"): | |||||
self._loss_func = self._model.loss | |||||
logger.info("The model has a loss function, use it.") | |||||
assert isinstance(predict, dict) and isinstance(truth, dict) | |||||
args = _build_args(self.loss_func, **predict, **truth) | |||||
return self.loss_func(**args) | |||||
def save_model(self, model, model_name, only_param=False): | |||||
model_name = os.path.join(self.save_path, model_name) | |||||
if only_param: | |||||
torch.save(model.state_dict(), model_name) | |||||
else: | else: | ||||
if self._loss_func is None: | |||||
raise ValueError("Please specify a loss function.") | |||||
logger.info("The model didn't define loss, use Trainer's loss.") | |||||
torch.save(model, model_name) | |||||
def best_eval_result(self, validator): | |||||
def best_eval_result(self, metrics): | |||||
"""Check if the current epoch yields better validation results. | """Check if the current epoch yields better validation results. | ||||
:param validator: a Tester instance | |||||
:return: bool, True means current results on dev set is the best. | :return: bool, True means current results on dev set is the best. | ||||
""" | """ | ||||
loss, accuracy = validator.metrics | |||||
if isinstance(metrics, tuple): | |||||
loss, metrics = metrics | |||||
if isinstance(metrics, dict): | |||||
if len(metrics) == 1: | |||||
accuracy = list(metrics.values())[0] | |||||
else: | |||||
accuracy = metrics[self.eval_sort_key] | |||||
else: | |||||
accuracy = metrics | |||||
if accuracy > self._best_accuracy: | if accuracy > self._best_accuracy: | ||||
self._best_accuracy = accuracy | self._best_accuracy = accuracy | ||||
return True | return True | ||||
else: | else: | ||||
return False | return False | ||||
def save_model(self, network, model_name): | |||||
"""Save this model with such a name. | |||||
This method may be called multiple times by Trainer to overwritten a better model. | |||||
:param network: the PyTorch model | |||||
:param model_name: str | |||||
""" | |||||
if model_name[-4:] != ".pkl": | |||||
model_name += ".pkl" | |||||
ModelSaver(os.path.join(self.pickle_path, model_name)).save_pytorch(network) | |||||
def _create_validator(self, valid_args): | |||||
raise NotImplementedError | |||||
class SeqLabelTrainer(Trainer): | |||||
"""Trainer for Sequence Labeling | |||||
""" | |||||
def __init__(self, **kwargs): | |||||
print( | |||||
"[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer directly.") | |||||
super(SeqLabelTrainer, self).__init__(**kwargs) | |||||
def _create_validator(self, valid_args): | |||||
return SeqLabelTester(**valid_args) | |||||
class ClassificationTrainer(Trainer): | |||||
"""Trainer for text classification.""" | |||||
def __init__(self, **train_args): | |||||
print( | |||||
"[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer directly.") | |||||
super(ClassificationTrainer, self).__init__(**train_args) | |||||
def _create_validator(self, valid_args): | |||||
return ClassificationTester(**valid_args) | |||||
class SNLITrainer(Trainer): | |||||
"""Trainer for text SNLI.""" | |||||
def __init__(self, **train_args): | |||||
print( | |||||
"[FastNLP Warning] SNLITrainer will be deprecated. Please use Trainer directly.") | |||||
super(SNLITrainer, self).__init__(**train_args) | |||||
DEFAULT_CHECK_BATCH_SIZE = 2 | |||||
DEFAULT_CHECK_NUM_BATCH = 2 | |||||
IGNORE_CHECK_LEVEL = 0 | |||||
WARNING_CHECK_LEVEL = 1 | |||||
STRICT_CHECK_LEVEL = 2 | |||||
def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL): | |||||
# check get_loss 方法 | |||||
model_name = model.__class__.__name__ | |||||
if not hasattr(model, 'get_loss'): | |||||
raise AttributeError("{} has to have a 'get_loss' function.".format(model_name)) | |||||
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||||
for batch_count, (batch_x, batch_y) in enumerate(batch): | |||||
_syn_model_data(model, batch_x, batch_y) | |||||
# forward check | |||||
if batch_count==0: | |||||
_check_forward_error(model_func=model.forward, check_level=check_level, | |||||
batch_x=batch_x) | |||||
refined_batch_x = _build_args(model.forward, **batch_x) | |||||
output = model(**refined_batch_x) | |||||
func_signature = get_func_signature(model.forward) | |||||
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature) | |||||
# loss check | |||||
if batch_count == 0: | |||||
_check_loss_evaluate(prev_func=model.forward, func=model.get_loss, check_level=check_level, | |||||
output=output, batch_y=batch_y) | |||||
loss_input = _build_args(model.get_loss, **output, **batch_y) | |||||
loss = model.get_loss(**loss_input) | |||||
# check loss output | |||||
if batch_count == 0: | |||||
if not isinstance(loss, torch.Tensor): | |||||
raise ValueError("The return value of {}.get_loss() should be torch.Tensor, but {} got.". | |||||
format(model_name, type(loss))) | |||||
if len(loss.size())!=0: | |||||
raise ValueError("The size of return value of {}.get_loss() is {}, should be torch.size([])".format( | |||||
model_name, loss.size() | |||||
)) | |||||
loss.backward() | |||||
model.zero_grad() | |||||
if batch_count+1>=DEFAULT_CHECK_NUM_BATCH: | |||||
break | |||||
if dev_data is not None: | |||||
if not hasattr(model, 'evaluate'): | |||||
raise AttributeError("{} has to have a 'evaluate' function to do evaluation. Or set" | |||||
"dev_data to 'None'." | |||||
.format(model_name)) | |||||
outputs, truths = defaultdict(list), defaultdict(list) | |||||
dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||||
with torch.no_grad(): | |||||
for batch_count, (batch_x, batch_y) in enumerate(dev_batch): | |||||
_syn_model_data(model, batch_x, batch_y) | |||||
if hasattr(model, 'predict'): | |||||
refined_batch_x = _build_args(model.predict, **batch_x) | |||||
prev_func = model.predict | |||||
output = prev_func(**refined_batch_x) | |||||
func_signature = get_func_signature(model.predict) | |||||
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature) | |||||
else: | |||||
refined_batch_x = _build_args(model.forward, **batch_x) | |||||
prev_func = model.forward | |||||
output = prev_func(**refined_batch_x) | |||||
for k, v in output.items(): | |||||
outputs[k].append(v) | |||||
for k, v in batch_y.items(): | |||||
truths[k].append(v) | |||||
if batch_count+1>DEFAULT_CHECK_NUM_BATCH: | |||||
break | |||||
for k, v in outputs.items(): | |||||
outputs[k] = itertools.chain(*v) | |||||
for k, v in truths.items(): | |||||
truths[k] = itertools.chain(*v) | |||||
_check_loss_evaluate(prev_func=prev_func, func=model.evaluate, check_level=check_level, | |||||
output=outputs, batch_y=truths) | |||||
refined_input = _build_args(model.evaluate, **outputs, **truths) | |||||
metrics = model.evaluate(**refined_input) | |||||
func_signature = get_func_signature(model.evaluate) | |||||
assert isinstance(metrics, dict), "The return value of {} should be dict.". \ | |||||
format(func_signature) | |||||
def _check_forward_error(model_func, check_level, batch_x): | |||||
check_res = _check_arg_dict_list(model_func, batch_x) | |||||
_missing = '' | |||||
_unused = '' | |||||
func_signature = get_func_signature(model_func) | |||||
if len(check_res.missing)!=0: | |||||
_missing = "Function {} misses {}, only provided with {}, " \ | |||||
".\n".format(func_signature, check_res.missing, | |||||
list(batch_x.keys())) | |||||
if len(check_res.unused)!=0: | |||||
if len(check_res.unused) > 1: | |||||
_unused = "{} are not used ".format(check_res.unused) | |||||
else: | |||||
_unused = "{} is not used ".format(check_res.unused) | |||||
_unused += "in function {}.\n".format(func_signature) | |||||
if _missing: | |||||
if len(_unused)>0 and STRICT_CHECK_LEVEL: | |||||
_error_str = "(1).{}\n(2).{}".format(_missing, _unused) | |||||
else: | |||||
_error_str = _missing | |||||
# TODO 这里可能需要自定义一些Error类型 | |||||
raise TypeError(_error_str) | |||||
if _unused: | |||||
if check_level == STRICT_CHECK_LEVEL: | |||||
# TODO 这里可能需要自定义一些Error类型 | |||||
raise ValueError(_unused) | |||||
elif check_level == WARNING_CHECK_LEVEL: | |||||
warnings.warn(message=_unused) | |||||
def _check_loss_evaluate(prev_func, func, check_level, output, batch_y): | |||||
check_res = _check_arg_dict_list(func, [output, batch_y]) | |||||
_missing = '' | |||||
_unused = '' | |||||
_duplicated = '' | |||||
func_signature = get_func_signature(func) | |||||
prev_func_signature = get_func_signature(prev_func) | |||||
if len(check_res.missing)>0: | |||||
_missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \ | |||||
"{}(from target in Dataset)." \ | |||||
.format(func_signature, check_res.missing, | |||||
list(output.keys()), prev_func_signature, | |||||
list(batch_y.keys())) | |||||
if len(check_res.unused)>0: | |||||
if len(check_res.unused) > 1: | |||||
_unused = "{} are not used ".format(check_res.unused) | |||||
else: | |||||
_unused = "{} is not used ".format(check_res.unused) | |||||
_unused += "in function {}.\n".format(func_signature) | |||||
if len(check_res.duplicated)>0: | |||||
if len(check_res.duplicated) > 1: | |||||
_duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \ | |||||
"them in {} at the same time.".format(check_res.duplicated, | |||||
func_signature, | |||||
check_res.duplicated, | |||||
prev_func_signature) | |||||
else: | |||||
_duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \ | |||||
"it in {} at the same time.".format(check_res.duplicated, | |||||
func_signature, | |||||
check_res.duplicated, | |||||
prev_func_signature) | |||||
_number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0) | |||||
if _number_errs > 0: | |||||
_error_strs = [] | |||||
if _number_errs > 1: | |||||
count = 0 | |||||
order_words = ['Firstly', 'Secondly', 'Thirdly'] | |||||
if _missing: | |||||
_error_strs.append('{}, {}'.format(order_words[count], _missing)) | |||||
count += 1 | |||||
if _duplicated: | |||||
_error_strs.append('{}, {}'.format(order_words[count], _duplicated)) | |||||
count += 1 | |||||
if _unused and check_level == STRICT_CHECK_LEVEL: | |||||
_error_strs.append('{}, {}'.format(order_words[count], _unused)) | |||||
else: | |||||
if _unused: | |||||
if check_level == STRICT_CHECK_LEVEL: | |||||
# TODO 这里可能需要自定义一些Error类型 | |||||
_error_strs.append(_unused) | |||||
elif check_level == WARNING_CHECK_LEVEL: | |||||
_unused = _unused.strip() | |||||
warnings.warn(_unused) | |||||
else: | |||||
if _missing: | |||||
_error_strs.append(_missing) | |||||
if _duplicated: | |||||
_error_strs.append(_duplicated) | |||||
def _create_validator(self, valid_args): | |||||
return SNLITester(**valid_args) | |||||
if _error_strs: | |||||
raise ValueError('\n' + '\n'.join(_error_strs)) |
@@ -0,0 +1,127 @@ | |||||
import _pickle | |||||
import inspect | |||||
import os | |||||
from collections import Counter | |||||
from collections import namedtuple | |||||
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False) | |||||
def save_pickle(obj, pickle_path, file_name): | |||||
"""Save an object into a pickle file. | |||||
:param obj: an object | |||||
:param pickle_path: str, the directory where the pickle file is to be saved | |||||
:param file_name: str, the name of the pickle file. In general, it should be ended by "pkl". | |||||
""" | |||||
if not os.path.exists(pickle_path): | |||||
os.mkdir(pickle_path) | |||||
print("make dir {} before saving pickle file".format(pickle_path)) | |||||
with open(os.path.join(pickle_path, file_name), "wb") as f: | |||||
_pickle.dump(obj, f) | |||||
print("{} saved in {}".format(file_name, pickle_path)) | |||||
def load_pickle(pickle_path, file_name): | |||||
"""Load an object from a given pickle file. | |||||
:param pickle_path: str, the directory where the pickle file is. | |||||
:param file_name: str, the name of the pickle file. | |||||
:return obj: an object stored in the pickle | |||||
""" | |||||
with open(os.path.join(pickle_path, file_name), "rb") as f: | |||||
obj = _pickle.load(f) | |||||
print("{} loaded from {}".format(file_name, pickle_path)) | |||||
return obj | |||||
def pickle_exist(pickle_path, pickle_name): | |||||
"""Check if a given pickle file exists in the directory. | |||||
:param pickle_path: the directory of target pickle file | |||||
:param pickle_name: the filename of target pickle file | |||||
:return: True if file exists else False | |||||
""" | |||||
if not os.path.exists(pickle_path): | |||||
os.makedirs(pickle_path) | |||||
file_name = os.path.join(pickle_path, pickle_name) | |||||
if os.path.exists(file_name): | |||||
return True | |||||
else: | |||||
return False | |||||
def _build_args(func, **kwargs): | |||||
spect = inspect.getfullargspec(func) | |||||
if spect.varkw is not None: | |||||
return kwargs | |||||
needed_args = set(spect.args) | |||||
defaults = [] | |||||
if spect.defaults is not None: | |||||
defaults = [arg for arg in spect.defaults] | |||||
start_idx = len(spect.args) - len(defaults) | |||||
output = {name: default for name, default in zip(spect.args[start_idx:], defaults)} | |||||
output.update({name: val for name, val in kwargs.items() if name in needed_args}) | |||||
return output | |||||
# check args | |||||
def _check_arg_dict_list(func, args): | |||||
if isinstance(args, dict): | |||||
arg_dict_list = [args] | |||||
else: | |||||
arg_dict_list = args | |||||
assert callable(func) and isinstance(arg_dict_list, (list, tuple)) | |||||
assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict) | |||||
spect = inspect.getfullargspec(func) | |||||
assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs) | |||||
all_args = set([arg for arg in spect.args if arg!='self']) | |||||
defaults = [] | |||||
if spect.defaults is not None: | |||||
defaults = [arg for arg in spect.defaults] | |||||
start_idx = len(spect.args) - len(defaults) | |||||
default_args = set(spect.args[start_idx:]) | |||||
require_args = all_args - default_args | |||||
input_arg_count = Counter() | |||||
for arg_dict in arg_dict_list: | |||||
input_arg_count.update(arg_dict.keys()) | |||||
duplicated = [name for name, val in input_arg_count.items() if val > 1] | |||||
input_args = set(input_arg_count.keys()) | |||||
missing = list(require_args - input_args) | |||||
unused = list(input_args - all_args) | |||||
return CheckRes(missing=missing, | |||||
unused=unused, | |||||
duplicated=duplicated, | |||||
required=list(require_args), | |||||
all_needed=list(all_args)) | |||||
def get_func_signature(func): | |||||
# can only be used in function or class method | |||||
if inspect.ismethod(func): | |||||
class_name = func.__self__.__class__.__name__ | |||||
signature = inspect.signature(func) | |||||
signature_str = str(signature) | |||||
if len(signature_str)>2: | |||||
_self = '(self, ' | |||||
else: | |||||
_self = '(self' | |||||
signature_str = class_name + '.' + func.__name__ + _self + signature_str[1:] | |||||
return signature_str | |||||
elif inspect.isfunction(func): | |||||
signature = inspect.signature(func) | |||||
signature_str = str(signature) | |||||
signature_str = func.__name__ + signature_str | |||||
return signature_str | |||||
# move data to model's device | |||||
import torch | |||||
def _syn_model_data(model, *args): | |||||
assert len(model.state_dict())!=0, "This model has no parameter." | |||||
device = model.parameters().__next__().device | |||||
for arg in args: | |||||
if isinstance(arg, dict): | |||||
for key, value in arg.items(): | |||||
if isinstance(value, torch.Tensor): | |||||
arg[key] = value.to(device) | |||||
else: | |||||
raise ValueError("Only support dict type right now.") |
@@ -1,19 +1,15 @@ | |||||
from collections import Counter | |||||
from copy import deepcopy | from copy import deepcopy | ||||
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | ||||
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | ||||
DEFAULT_RESERVED_LABEL = ['<reserved-2>', | |||||
'<reserved-3>', | |||||
'<reserved-4>'] # dict index = 2~4 | |||||
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, | |||||
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, | |||||
DEFAULT_RESERVED_LABEL[2]: 4} | |||||
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1} | |||||
def isiterable(p_object): | def isiterable(p_object): | ||||
try: | try: | ||||
it = iter(p_object) | |||||
_ = iter(p_object) | |||||
except TypeError: | except TypeError: | ||||
return False | return False | ||||
return True | return True | ||||
@@ -23,10 +19,8 @@ def check_build_vocab(func): | |||||
def _wrapper(self, *args, **kwargs): | def _wrapper(self, *args, **kwargs): | ||||
if self.word2idx is None: | if self.word2idx is None: | ||||
self.build_vocab() | self.build_vocab() | ||||
self.build_reverse_vocab() | |||||
elif self.idx2word is None: | |||||
self.build_reverse_vocab() | |||||
return func(self, *args, **kwargs) | return func(self, *args, **kwargs) | ||||
return _wrapper | return _wrapper | ||||
@@ -41,6 +35,7 @@ class Vocabulary(object): | |||||
vocab["word"] | vocab["word"] | ||||
vocab.to_word(5) | vocab.to_word(5) | ||||
""" | """ | ||||
def __init__(self, need_default=True, max_size=None, min_freq=None): | def __init__(self, need_default=True, max_size=None, min_freq=None): | ||||
""" | """ | ||||
:param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | ||||
@@ -49,61 +44,85 @@ class Vocabulary(object): | |||||
""" | """ | ||||
self.max_size = max_size | self.max_size = max_size | ||||
self.min_freq = min_freq | self.min_freq = min_freq | ||||
self.word_count = {} | |||||
self.word_count = Counter() | |||||
self.has_default = need_default | self.has_default = need_default | ||||
if self.has_default: | |||||
self.padding_label = DEFAULT_PADDING_LABEL | |||||
self.unknown_label = DEFAULT_UNKNOWN_LABEL | |||||
else: | |||||
self.padding_label = None | |||||
self.unknown_label = None | |||||
self.word2idx = None | self.word2idx = None | ||||
self.idx2word = None | self.idx2word = None | ||||
def update(self, word): | |||||
"""add word or list of words into Vocabulary | |||||
def update(self, word_lst): | |||||
"""Add a list of words into the vocabulary. | |||||
:param word: a list of string or a single string | |||||
:param list word_lst: a list of strings | |||||
""" | """ | ||||
if not isinstance(word, str) and isiterable(word): | |||||
# it's a nested list | |||||
for w in word: | |||||
self.update(w) | |||||
else: | |||||
# it's a word to be added | |||||
if word not in self.word_count: | |||||
self.word_count[word] = 1 | |||||
else: | |||||
self.word_count[word] += 1 | |||||
self.word2idx = None | |||||
return self | |||||
self.word_count.update(word_lst) | |||||
def add(self, word): | |||||
"""Add a single word into the vocabulary. | |||||
:param str word: a word or token. | |||||
""" | |||||
self.word_count[word] += 1 | |||||
def add_word(self, word): | |||||
"""Add a single word into the vocabulary. | |||||
:param str word: a word or token. | |||||
""" | |||||
self.add(word) | |||||
def add_word_lst(self, word_lst): | |||||
"""Add a list of words into the vocabulary. | |||||
:param list word_lst: a list of strings | |||||
""" | |||||
self.update(word_lst) | |||||
def build_vocab(self): | def build_vocab(self): | ||||
"""build 'word to index' dict, and filter the word using `max_size` and `min_freq` | |||||
"""Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. | |||||
""" | """ | ||||
if self.has_default: | if self.has_default: | ||||
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | ||||
self.padding_label = DEFAULT_PADDING_LABEL | |||||
self.unknown_label = DEFAULT_UNKNOWN_LABEL | |||||
self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL) | |||||
self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL) | |||||
else: | else: | ||||
self.word2idx = {} | self.word2idx = {} | ||||
self.padding_label = None | |||||
self.unknown_label = None | |||||
words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True) | |||||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | |||||
words = self.word_count.most_common(max_size) | |||||
if self.min_freq is not None: | if self.min_freq is not None: | ||||
words = list(filter(lambda kv: kv[1] >= self.min_freq, words)) | |||||
if self.max_size is not None and len(words) > self.max_size: | |||||
words = words[:self.max_size] | |||||
for w, _ in words: | |||||
self.word2idx[w] = len(self.word2idx) | |||||
words = filter(lambda kv: kv[1] >= self.min_freq, words) | |||||
start_idx = len(self.word2idx) | |||||
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||||
self.build_reverse_vocab() | |||||
def build_reverse_vocab(self): | def build_reverse_vocab(self): | ||||
"""build 'index to word' dict based on 'word to index' dict | |||||
"""Build 'index to word' dict based on 'word to index' dict. | |||||
""" | """ | ||||
self.idx2word = {self.word2idx[w] : w for w in self.word2idx} | |||||
self.idx2word = {i: w for w, i in self.word2idx.items()} | |||||
@check_build_vocab | @check_build_vocab | ||||
def __len__(self): | def __len__(self): | ||||
return len(self.word2idx) | return len(self.word2idx) | ||||
@check_build_vocab | @check_build_vocab | ||||
def __contains__(self, item): | |||||
"""Check if a word in vocabulary. | |||||
:param item: the word | |||||
:return: True or False | |||||
""" | |||||
return item in self.word2idx | |||||
def has_word(self, w): | def has_word(self, w): | ||||
return w in self.word2idx | |||||
return self.__contains__(w) | |||||
@check_build_vocab | @check_build_vocab | ||||
def __getitem__(self, w): | def __getitem__(self, w): | ||||
@@ -114,18 +133,17 @@ class Vocabulary(object): | |||||
if w in self.word2idx: | if w in self.word2idx: | ||||
return self.word2idx[w] | return self.word2idx[w] | ||||
elif self.has_default: | elif self.has_default: | ||||
return self.word2idx[DEFAULT_UNKNOWN_LABEL] | |||||
return self.word2idx[self.unknown_label] | |||||
else: | else: | ||||
raise ValueError("word {} not in vocabulary".format(w)) | raise ValueError("word {} not in vocabulary".format(w)) | ||||
@check_build_vocab | |||||
def to_index(self, w): | def to_index(self, w): | ||||
""" like to_index(w) function, turn a word to the index | |||||
if w is not in Vocabulary, return the unknown label | |||||
""" Turn a word to an index. | |||||
If w is not in Vocabulary, return the unknown label. | |||||
:param str w: | :param str w: | ||||
""" | """ | ||||
return self[w] | |||||
return self.__getitem__(w) | |||||
@property | @property | ||||
@check_build_vocab | @check_build_vocab | ||||
@@ -134,6 +152,11 @@ class Vocabulary(object): | |||||
return None | return None | ||||
return self.word2idx[self.unknown_label] | return self.word2idx[self.unknown_label] | ||||
def __setattr__(self, name, val): | |||||
self.__dict__[name] = val | |||||
if name in ["unknown_label", "padding_label"]: | |||||
self.word2idx = None | |||||
@property | @property | ||||
@check_build_vocab | @check_build_vocab | ||||
def padding_idx(self): | def padding_idx(self): | ||||
@@ -145,14 +168,14 @@ class Vocabulary(object): | |||||
def to_word(self, idx): | def to_word(self, idx): | ||||
"""given a word's index, return the word itself | """given a word's index, return the word itself | ||||
:param int idx: | |||||
:param int idx: the index | |||||
:return str word: the indexed word | |||||
""" | """ | ||||
if self.idx2word is None: | |||||
self.build_reverse_vocab() | |||||
return self.idx2word[idx] | return self.idx2word[idx] | ||||
def __getstate__(self): | def __getstate__(self): | ||||
"""use to prepare data for pickle | |||||
"""Use to prepare data for pickle. | |||||
""" | """ | ||||
state = self.__dict__.copy() | state = self.__dict__.copy() | ||||
# no need to pickle idx2word as it can be constructed from word2idx | # no need to pickle idx2word as it can be constructed from word2idx | ||||
@@ -160,15 +183,9 @@ class Vocabulary(object): | |||||
return state | return state | ||||
def __setstate__(self, state): | def __setstate__(self, state): | ||||
"""use to restore state from pickle | |||||
"""Use to restore state from pickle. | |||||
""" | """ | ||||
self.__dict__.update(state) | self.__dict__.update(state) | ||||
self.idx2word = None | |||||
self.build_reverse_vocab() | |||||
def __contains__(self, item): | |||||
"""Check if a word in vocabulary. | |||||
:param item: the word | |||||
:return: True or False | |||||
""" | |||||
return self.has_word(item) |
@@ -1,343 +0,0 @@ | |||||
import os | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.loader.dataset_loader import convert_seq_dataset | |||||
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer | |||||
from fastNLP.core.preprocess import load_pickle | |||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.loader.model_loader import ModelLoader | |||||
""" | |||||
mapping from model name to [URL, file_name.class_name, model_pickle_name] | |||||
Notice that the class of the model should be in "models" directory. | |||||
Example: | |||||
"seq_label_model": { | |||||
"url": "www.fudan.edu.cn", | |||||
"class": "sequence_modeling.SeqLabeling", # file_name.class_name in models/ | |||||
"pickle": "seq_label_model.pkl", | |||||
"type": "seq_label", | |||||
"config_file_name": "config", # the name of the config file which stores model initialization parameters | |||||
"config_section_name": "text_class_model" # the name of the section in the config file which stores model init params | |||||
}, | |||||
"text_class_model": { | |||||
"url": "www.fudan.edu.cn", | |||||
"class": "cnn_text_classification.CNNText", | |||||
"pickle": "text_class_model.pkl", | |||||
"type": "text_class" | |||||
} | |||||
""" | |||||
FastNLP_MODEL_COLLECTION = { | |||||
"cws_basic_model": { | |||||
"url": "", | |||||
"class": "sequence_modeling.AdvSeqLabel", | |||||
"pickle": "cws_basic_model_v_0.pkl", | |||||
"type": "seq_label", | |||||
"config_file_name": "cws.cfg", | |||||
"config_section_name": "text_class_model" | |||||
}, | |||||
"pos_tag_model": { | |||||
"url": "", | |||||
"class": "sequence_modeling.AdvSeqLabel", | |||||
"pickle": "pos_tag_model_v_0.pkl", | |||||
"type": "seq_label", | |||||
"config_file_name": "pos_tag.cfg", | |||||
"config_section_name": "pos_tag_model" | |||||
}, | |||||
"text_classify_model": { | |||||
"url": "", | |||||
"class": "cnn_text_classification.CNNText", | |||||
"pickle": "text_class_model_v0.pkl", | |||||
"type": "text_class", | |||||
"config_file_name": "text_classify.cfg", | |||||
"config_section_name": "model" | |||||
} | |||||
} | |||||
class FastNLP(object): | |||||
""" | |||||
High-level interface for direct model inference. | |||||
Example Usage | |||||
:: | |||||
fastnlp = FastNLP() | |||||
fastnlp.load("zh_pos_tag_model") | |||||
text = "这是最好的基于深度学习的中文分词系统。" | |||||
result = fastnlp.run(text) | |||||
print(result) # ["这", "是", "最好", "的", "基于", "深度学习", "的", "中文", "分词", "系统", "。"] | |||||
""" | |||||
def __init__(self, model_dir="./"): | |||||
""" | |||||
:param model_dir: this directory should contain the following files: | |||||
1. a trained model | |||||
2. a config file, which is a fastNLP's configuration. | |||||
3. two Vocab files, which are pickle objects of Vocab instances, representing feature and label vocabs. | |||||
""" | |||||
self.model_dir = model_dir | |||||
self.model = None | |||||
self.infer_type = None # "seq_label"/"text_class" | |||||
self.word_vocab = None | |||||
self.label_vocab = None | |||||
def load(self, model_name, config_file="config", section_name="model"): | |||||
""" | |||||
Load a pre-trained FastNLP model together with additional data. | |||||
:param model_name: str, the name of a FastNLP model. | |||||
:param config_file: str, the name of the config file which stores the initialization information of the model. | |||||
(default: "config") | |||||
:param section_name: str, the name of the corresponding section in the config file. (default: model) | |||||
""" | |||||
assert type(model_name) is str | |||||
if model_name not in FastNLP_MODEL_COLLECTION: | |||||
raise ValueError("No FastNLP model named {}.".format(model_name)) | |||||
if not self.model_exist(model_dir=self.model_dir): | |||||
self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"]) | |||||
model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"]) | |||||
print("Restore model class {}".format(str(model_class))) | |||||
model_args = ConfigSection() | |||||
ConfigLoader.load_config(os.path.join(self.model_dir, config_file), {section_name: model_args}) | |||||
print("Restore model hyper-parameters {}".format(str(model_args.data))) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
self.word_vocab = load_pickle(self.model_dir, "word2id.pkl") | |||||
model_args["vocab_size"] = len(self.word_vocab) | |||||
self.label_vocab = load_pickle(self.model_dir, "label2id.pkl") | |||||
model_args["num_classes"] = len(self.label_vocab) | |||||
# Construct the model | |||||
model = model_class(model_args) | |||||
print("Model constructed.") | |||||
# To do: framework independent | |||||
ModelLoader.load_pytorch(model, os.path.join(self.model_dir, FastNLP_MODEL_COLLECTION[model_name]["pickle"])) | |||||
print("Model weights loaded.") | |||||
self.model = model | |||||
self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"] | |||||
print("Inference ready.") | |||||
def run(self, raw_input): | |||||
""" | |||||
Perform inference over given input using the loaded model. | |||||
:param raw_input: list of string. Each list is an input query. | |||||
:return results: | |||||
""" | |||||
infer = self._create_inference(self.model_dir) | |||||
# tokenize: list of string ---> 2-D list of string | |||||
infer_input = self.tokenize(raw_input, language="zh") | |||||
# create DataSet: 2-D list of strings ----> DataSet | |||||
infer_data = self._create_data_set(infer_input) | |||||
# DataSet ---> 2-D list of tags | |||||
results = infer.predict(self.model, infer_data) | |||||
# 2-D list of tags ---> list of final answers | |||||
outputs = self._make_output(results, infer_input) | |||||
return outputs | |||||
@staticmethod | |||||
def _get_model_class(file_class_name): | |||||
""" | |||||
Feature the class specified by <file_class_name> | |||||
:param file_class_name: str, contains the name of the Python module followed by the name of the class. | |||||
Example: "sequence_modeling.SeqLabeling" | |||||
:return module: the model class | |||||
""" | |||||
import_prefix = "fastNLP.models." | |||||
parts = (import_prefix + file_class_name).split(".") | |||||
from_module = ".".join(parts[:-1]) | |||||
module = __import__(from_module) | |||||
for sub in parts[1:]: | |||||
module = getattr(module, sub) | |||||
return module | |||||
def _create_inference(self, model_dir): | |||||
"""Specify which task to perform. | |||||
:param model_dir: | |||||
:return: | |||||
""" | |||||
if self.infer_type == "seq_label": | |||||
return SeqLabelInfer(model_dir) | |||||
elif self.infer_type == "text_class": | |||||
return ClassificationInfer(model_dir) | |||||
else: | |||||
raise ValueError("fail to create inference instance") | |||||
def _create_data_set(self, infer_input): | |||||
"""Create a DataSet object given the raw inputs. | |||||
:param infer_input: 2-D lists of strings | |||||
:return data_set: a DataSet object | |||||
""" | |||||
if self.infer_type in ["seq_label", "text_class"]: | |||||
data_set = convert_seq_dataset(infer_input) | |||||
data_set.index_field("word_seq", self.word_vocab) | |||||
if self.infer_type == "seq_label": | |||||
data_set.set_origin_len("word_seq") | |||||
return data_set | |||||
else: | |||||
raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type)) | |||||
def _load(self, model_dir, model_name): | |||||
return 0 | |||||
def _download(self, model_name, url): | |||||
""" | |||||
Download the model weights from <url> and save in <self.model_dir>. | |||||
:param model_name: | |||||
:param url: | |||||
""" | |||||
print("Downloading {} from {}".format(model_name, url)) | |||||
# TODO: download model via url | |||||
def model_exist(self, model_dir): | |||||
""" | |||||
Check whether the desired model is already in the directory. | |||||
:param model_dir: | |||||
""" | |||||
return True | |||||
def tokenize(self, text, language): | |||||
"""Extract tokens from strings. | |||||
For English, extract words separated by space. | |||||
For Chinese, extract characters. | |||||
TODO: more complex tokenization methods | |||||
:param text: list of string | |||||
:param language: str, one of ('zh', 'en'), Chinese or English. | |||||
:return data: list of list of string, each string is a token. | |||||
""" | |||||
assert language in ("zh", "en") | |||||
data = [] | |||||
for sent in text: | |||||
if language == "en": | |||||
tokens = sent.strip().split() | |||||
elif language == "zh": | |||||
tokens = [char for char in sent] | |||||
else: | |||||
raise RuntimeError("Unknown language {}".format(language)) | |||||
data.append(tokens) | |||||
return data | |||||
def _make_output(self, results, infer_input): | |||||
"""Transform the infer output into user-friendly output. | |||||
:param results: 1 or 2-D list of strings. | |||||
If self.infer_type == "seq_label", it is of shape [num_examples, tag_seq_length] | |||||
If self.infer_type == "text_class", it is of shape [num_examples] | |||||
:param infer_input: 2-D list of string, the input query before inference. | |||||
:return outputs: list. Each entry is a prediction. | |||||
""" | |||||
if self.infer_type == "seq_label": | |||||
outputs = make_seq_label_output(results, infer_input) | |||||
elif self.infer_type == "text_class": | |||||
outputs = make_class_output(results, infer_input) | |||||
else: | |||||
raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type)) | |||||
return outputs | |||||
def make_seq_label_output(result, infer_input): | |||||
"""Transform model output into user-friendly contents. | |||||
:param result: 2-D list of strings. (model output) | |||||
:param infer_input: 2-D list of string (model input) | |||||
:return ret: list of list of tuples | |||||
[ | |||||
[(word_11, label_11), (word_12, label_12), ...], | |||||
[(word_21, label_21), (word_22, label_22), ...], | |||||
... | |||||
] | |||||
""" | |||||
ret = [] | |||||
for example_x, example_y in zip(infer_input, result): | |||||
ret.append([(x, y) for x, y in zip(example_x, example_y)]) | |||||
return ret | |||||
def make_class_output(result, infer_input): | |||||
"""Transform model output into user-friendly contents. | |||||
:param result: 2-D list of strings. (model output) | |||||
:param infer_input: 1-D list of string (model input) | |||||
:return ret: the same as result, [label_1, label_2, ...] | |||||
""" | |||||
return result | |||||
def interpret_word_seg_results(char_seq, label_seq): | |||||
"""Transform model output into user-friendly contents. | |||||
Example: In CWS, convert <BMES> labeling into segmented text. | |||||
:param char_seq: list of string, | |||||
:param label_seq: list of string, the same length as char_seq | |||||
Each entry is one of ('B', 'M', 'E', 'S'). | |||||
:return output: list of words | |||||
""" | |||||
words = [] | |||||
word = "" | |||||
for char, label in zip(char_seq, label_seq): | |||||
if label[0] == "B": | |||||
if word != "": | |||||
words.append(word) | |||||
word = char | |||||
elif label[0] == "M": | |||||
word += char | |||||
elif label[0] == "E": | |||||
word += char | |||||
words.append(word) | |||||
word = "" | |||||
elif label[0] == "S": | |||||
if word != "": | |||||
words.append(word) | |||||
word = "" | |||||
words.append(char) | |||||
else: | |||||
raise ValueError("invalid label {}".format(label[0])) | |||||
return words | |||||
def interpret_cws_pos_results(char_seq, label_seq): | |||||
"""Transform model output into user-friendly contents. | |||||
:param char_seq: list of string | |||||
:param label_seq: list of string, the same length as char_seq. | |||||
:return outputs: list of tuple (words, pos_tag): | |||||
""" | |||||
def pos_tag_check(seq): | |||||
"""check whether all entries are the same """ | |||||
return len(set(seq)) <= 1 | |||||
word = [] | |||||
word_pos = [] | |||||
outputs = [] | |||||
for char, label in zip(char_seq, label_seq): | |||||
tmp = label.split("-") | |||||
cws_label, pos_tag = tmp[0], tmp[1] | |||||
if cws_label == "B" or cws_label == "M": | |||||
word.append(char) | |||||
word_pos.append(pos_tag) | |||||
elif cws_label == "E": | |||||
word.append(char) | |||||
word_pos.append(pos_tag) | |||||
if not pos_tag_check(word_pos): | |||||
raise RuntimeError("character-wise pos tags inconsistent. ") | |||||
outputs.append(("".join(word), word_pos[0])) | |||||
word.clear() | |||||
word_pos.clear() | |||||
elif cws_label == "S": | |||||
outputs.append((char, pos_tag)) | |||||
return outputs |
@@ -1,3 +1,7 @@ | |||||
import _pickle as pickle | |||||
import os | |||||
class BaseLoader(object): | class BaseLoader(object): | ||||
def __init__(self): | def __init__(self): | ||||
@@ -9,12 +13,23 @@ class BaseLoader(object): | |||||
text = f.readlines() | text = f.readlines() | ||||
return [line.strip() for line in text] | return [line.strip() for line in text] | ||||
@staticmethod | |||||
def load(data_path): | |||||
@classmethod | |||||
def load(cls, data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f: | with open(data_path, "r", encoding="utf-8") as f: | ||||
text = f.readlines() | text = f.readlines() | ||||
return [[word for word in sent.strip()] for sent in text] | return [[word for word in sent.strip()] for sent in text] | ||||
@classmethod | |||||
def load_with_cache(cls, data_path, cache_path): | |||||
if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path): | |||||
with open(cache_path, 'rb') as f: | |||||
return pickle.load(f) | |||||
else: | |||||
obj = cls.load(data_path) | |||||
with open(cache_path, 'wb') as f: | |||||
pickle.dump(obj, f) | |||||
return obj | |||||
class ToyLoader0(BaseLoader): | class ToyLoader0(BaseLoader): | ||||
""" | """ |
@@ -2,7 +2,7 @@ import configparser | |||||
import json | import json | ||||
import os | import os | ||||
from fastNLP.loader.base_loader import BaseLoader | |||||
from fastNLP.io.base_loader import BaseLoader | |||||
class ConfigLoader(BaseLoader): | class ConfigLoader(BaseLoader): |
@@ -1,7 +1,6 @@ | |||||
import os | import os | ||||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||||
from fastNLP.saver.logger import create_logger | |||||
from fastNLP.io.config_loader import ConfigSection, ConfigLoader | |||||
class ConfigSaver(object): | class ConfigSaver(object): | ||||
@@ -61,8 +60,8 @@ class ConfigSaver(object): | |||||
continue | continue | ||||
if '=' not in line: | if '=' not in line: | ||||
log = create_logger(__name__, './config_saver.log') | |||||
log.error("can NOT load config file [%s]" % self.file_path) | |||||
# log = create_logger(__name__, './config_saver.log') | |||||
# log.error("can NOT load config file [%s]" % self.file_path) | |||||
raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | ||||
key = line.split('=', maxsplit=1)[0].strip() | key = line.split('=', maxsplit=1)[0].strip() | ||||
@@ -123,10 +122,10 @@ class ConfigSaver(object): | |||||
change_file = True | change_file = True | ||||
break | break | ||||
if section_file[k] != section[k]: | if section_file[k] != section[k]: | ||||
logger = create_logger(__name__, "./config_loader.log") | |||||
logger.warning("section [%s] in config file [%s] has been changed" % ( | |||||
section_name, self.file_path | |||||
)) | |||||
# logger = create_logger(__name__, "./config_loader.log") | |||||
# logger.warning("section [%s] in config file [%s] has been changed" % ( | |||||
# section_name, self.file_path | |||||
#)) | |||||
change_file = True | change_file = True | ||||
break | break | ||||
if not change_file: | if not change_file: |
@@ -1,9 +1,9 @@ | |||||
#TODO: need fix for current DataSet | |||||
import os | import os | ||||
from fastNLP.loader.base_loader import BaseLoader | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.field import * | |||||
from fastNLP.io.base_loader import BaseLoader | |||||
def convert_seq_dataset(data): | def convert_seq_dataset(data): | ||||
@@ -368,6 +368,8 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||||
pos_tag_examples = [] | pos_tag_examples = [] | ||||
ner_examples = [] | ner_examples = [] | ||||
for sent in sents: | for sent in sents: | ||||
if len(sent) <= 2: | |||||
continue | |||||
inside_ne = False | inside_ne = False | ||||
sent_pos_tag = [] | sent_pos_tag = [] | ||||
sent_words = [] | sent_words = [] | ||||
@@ -400,6 +402,7 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||||
sent_words.append(token) | sent_words.append(token) | ||||
pos_tag_examples.append([sent_words, sent_pos_tag]) | pos_tag_examples.append([sent_words, sent_pos_tag]) | ||||
ner_examples.append([sent_words, sent_ner]) | ner_examples.append([sent_words, sent_ner]) | ||||
# List[List[List[str], List[str]]] | |||||
return pos_tag_examples, ner_examples | return pos_tag_examples, ner_examples | ||||
def convert(self, data): | def convert(self, data): |
@@ -1,10 +1,7 @@ | |||||
import _pickle | |||||
import os | |||||
import torch | import torch | ||||
from fastNLP.loader.base_loader import BaseLoader | |||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.io.base_loader import BaseLoader | |||||
class EmbedLoader(BaseLoader): | class EmbedLoader(BaseLoader): | ||||
@@ -17,8 +14,8 @@ class EmbedLoader(BaseLoader): | |||||
def _load_glove(emb_file): | def _load_glove(emb_file): | ||||
"""Read file as a glove embedding | """Read file as a glove embedding | ||||
file format: | |||||
embeddings are split by line, | |||||
file format: | |||||
embeddings are split by line, | |||||
for one embedding, word and numbers split by space | for one embedding, word and numbers split by space | ||||
Example:: | Example:: | ||||
@@ -30,10 +27,10 @@ class EmbedLoader(BaseLoader): | |||||
with open(emb_file, 'r', encoding='utf-8') as f: | with open(emb_file, 'r', encoding='utf-8') as f: | ||||
for line in f: | for line in f: | ||||
line = list(filter(lambda w: len(w)>0, line.strip().split(' '))) | line = list(filter(lambda w: len(w)>0, line.strip().split(' '))) | ||||
if len(line) > 0: | |||||
if len(line) > 2: | |||||
emb[line[0]] = torch.Tensor(list(map(float, line[1:]))) | emb[line[0]] = torch.Tensor(list(map(float, line[1:]))) | ||||
return emb | return emb | ||||
@staticmethod | @staticmethod | ||||
def _load_pretrain(emb_file, emb_type): | def _load_pretrain(emb_file, emb_type): | ||||
"""Read txt data from embedding file and convert to np.array as pre-trained embedding | """Read txt data from embedding file and convert to np.array as pre-trained embedding | ||||
@@ -61,10 +58,10 @@ class EmbedLoader(BaseLoader): | |||||
TODO: fragile code | TODO: fragile code | ||||
""" | """ | ||||
# If the embedding pickle exists, load it and return. | # If the embedding pickle exists, load it and return. | ||||
if os.path.exists(emb_pkl): | |||||
with open(emb_pkl, "rb") as f: | |||||
embedding_tensor, vocab = _pickle.load(f) | |||||
return embedding_tensor, vocab | |||||
# if os.path.exists(emb_pkl): | |||||
# with open(emb_pkl, "rb") as f: | |||||
# embedding_tensor, vocab = _pickle.load(f) | |||||
# return embedding_tensor, vocab | |||||
# Otherwise, load the pre-trained embedding. | # Otherwise, load the pre-trained embedding. | ||||
pretrain = EmbedLoader._load_pretrain(emb_file, emb_type) | pretrain = EmbedLoader._load_pretrain(emb_file, emb_type) | ||||
if vocab is None: | if vocab is None: | ||||
@@ -80,6 +77,6 @@ class EmbedLoader(BaseLoader): | |||||
embedding_tensor[vocab[w]] = v | embedding_tensor[vocab[w]] = v | ||||
# save and return the result | # save and return the result | ||||
with open(emb_pkl, "wb") as f: | |||||
_pickle.dump((embedding_tensor, vocab), f) | |||||
# with open(emb_pkl, "wb") as f: | |||||
# _pickle.dump((embedding_tensor, vocab), f) | |||||
return embedding_tensor, vocab | return embedding_tensor, vocab |
@@ -1,6 +1,6 @@ | |||||
import torch | import torch | ||||
from fastNLP.loader.base_loader import BaseLoader | |||||
from fastNLP.io.base_loader import BaseLoader | |||||
class ModelLoader(BaseLoader): | class ModelLoader(BaseLoader): | ||||
@@ -8,8 +8,8 @@ class ModelLoader(BaseLoader): | |||||
Loader for models. | Loader for models. | ||||
""" | """ | ||||
def __init__(self, data_path): | |||||
super(ModelLoader, self).__init__(data_path) | |||||
def __init__(self): | |||||
super(ModelLoader, self).__init__() | |||||
@staticmethod | @staticmethod | ||||
def load_pytorch(empty_model, model_path): | def load_pytorch(empty_model, model_path): | ||||
@@ -19,3 +19,10 @@ class ModelLoader(BaseLoader): | |||||
:param model_path: str, the path to the saved model. | :param model_path: str, the path to the saved model. | ||||
""" | """ | ||||
empty_model.load_state_dict(torch.load(model_path)) | empty_model.load_state_dict(torch.load(model_path)) | ||||
@staticmethod | |||||
def load_pytorch_model(model_path): | |||||
"""Load the entire model. | |||||
""" | |||||
return torch.load(model_path) |
@@ -15,10 +15,14 @@ class ModelSaver(object): | |||||
""" | """ | ||||
self.save_path = save_path | self.save_path = save_path | ||||
def save_pytorch(self, model): | |||||
def save_pytorch(self, model, param_only=True): | |||||
"""Save a pytorch model into .pkl file. | """Save a pytorch model into .pkl file. | ||||
:param model: a PyTorch model | :param model: a PyTorch model | ||||
:param param_only: bool, whether only to save the model parameters or the entire model. | |||||
""" | """ | ||||
torch.save(model.state_dict(), self.save_path) | |||||
if param_only is True: | |||||
torch.save(model.state_dict(), self.save_path) | |||||
else: | |||||
torch.save(model, self.save_path) |
@@ -0,0 +1,6 @@ | |||||
from .base_model import BaseModel | |||||
from .biaffine_parser import BiaffineParser, GraphParser | |||||
from .char_language_model import CharLM | |||||
from .cnn_text_classification import CNNText | |||||
from .sequence_modeling import SeqLabeling, AdvSeqLabel | |||||
from .snli import SNLI |
@@ -13,3 +13,6 @@ class BaseModel(torch.nn.Module): | |||||
def fit(self, train_data, dev_data=None, **train_args): | def fit(self, train_data, dev_data=None, **train_args): | ||||
trainer = Trainer(**train_args) | trainer = Trainer(**train_args) | ||||
trainer.train(self, train_data, dev_data) | trainer.train(self, train_data, dev_data) | ||||
def predict(self, *args, **kwargs): | |||||
raise NotImplementedError |
@@ -9,6 +9,8 @@ from torch.nn import functional as F | |||||
from fastNLP.modules.utils import initial_parameter | from fastNLP.modules.utils import initial_parameter | ||||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | from fastNLP.modules.encoder.variational_rnn import VarLSTM | ||||
from fastNLP.modules.dropout import TimestepDropout | from fastNLP.modules.dropout import TimestepDropout | ||||
from fastNLP.models.base_model import BaseModel | |||||
from fastNLP.modules.utils import seq_mask | |||||
def mst(scores): | def mst(scores): | ||||
""" | """ | ||||
@@ -16,10 +18,9 @@ def mst(scores): | |||||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | ||||
""" | """ | ||||
length = scores.shape[0] | length = scores.shape[0] | ||||
min_score = -np.inf | |||||
mask = np.zeros((length, length)) | |||||
np.fill_diagonal(mask, -np.inf) | |||||
scores = scores + mask | |||||
min_score = scores.min() - 1 | |||||
eye = np.eye(length) | |||||
scores = scores * (1 - eye) + min_score * eye | |||||
heads = np.argmax(scores, axis=1) | heads = np.argmax(scores, axis=1) | ||||
heads[0] = 0 | heads[0] = 0 | ||||
tokens = np.arange(1, length) | tokens = np.arange(1, length) | ||||
@@ -114,7 +115,7 @@ def _find_cycle(vertices, edges): | |||||
return [SCC for SCC in _SCCs if len(SCC) > 1] | return [SCC for SCC in _SCCs if len(SCC) > 1] | ||||
class GraphParser(nn.Module): | |||||
class GraphParser(BaseModel): | |||||
"""Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding | """Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
@@ -123,22 +124,31 @@ class GraphParser(nn.Module): | |||||
def forward(self, x): | def forward(self, x): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def _greedy_decoder(self, arc_matrix, seq_mask=None): | |||||
def _greedy_decoder(self, arc_matrix, mask=None): | |||||
_, seq_len, _ = arc_matrix.shape | _, seq_len, _ = arc_matrix.shape | ||||
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) | matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) | ||||
flip_mask = (mask == 0).byte() | |||||
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||||
_, heads = torch.max(matrix, dim=2) | _, heads = torch.max(matrix, dim=2) | ||||
if seq_mask is not None: | |||||
heads *= seq_mask.long() | |||||
if mask is not None: | |||||
heads *= mask.long() | |||||
return heads | return heads | ||||
def _mst_decoder(self, arc_matrix, seq_mask=None): | |||||
def _mst_decoder(self, arc_matrix, mask=None): | |||||
batch_size, seq_len, _ = arc_matrix.shape | batch_size, seq_len, _ = arc_matrix.shape | ||||
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) | matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) | ||||
ans = matrix.new_zeros(batch_size, seq_len).long() | ans = matrix.new_zeros(batch_size, seq_len).long() | ||||
lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len | |||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device) | |||||
mask[batch_idx, lens-1] = 0 | |||||
for i, graph in enumerate(matrix): | for i, graph in enumerate(matrix): | ||||
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) | |||||
if seq_mask is not None: | |||||
ans *= seq_mask.long() | |||||
len_i = lens[i] | |||||
if len_i == seq_len: | |||||
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) | |||||
else: | |||||
ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device) | |||||
if mask is not None: | |||||
ans *= mask.long() | |||||
return ans | return ans | ||||
@@ -175,15 +185,13 @@ class LabelBilinear(nn.Module): | |||||
def __init__(self, in1_features, in2_features, num_label, bias=True): | def __init__(self, in1_features, in2_features, num_label, bias=True): | ||||
super(LabelBilinear, self).__init__() | super(LabelBilinear, self).__init__() | ||||
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) | self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) | ||||
self.lin1 = nn.Linear(in1_features, num_label, bias=False) | |||||
self.lin2 = nn.Linear(in2_features, num_label, bias=False) | |||||
self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) | |||||
def forward(self, x1, x2): | def forward(self, x1, x2): | ||||
output = self.bilinear(x1, x2) | output = self.bilinear(x1, x2) | ||||
output += self.lin1(x1) + self.lin2(x2) | |||||
output += self.lin(torch.cat([x1, x2], dim=2)) | |||||
return output | return output | ||||
class BiaffineParser(GraphParser): | class BiaffineParser(GraphParser): | ||||
"""Biaffine Dependency Parser implemantation. | """Biaffine Dependency Parser implemantation. | ||||
refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) | refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) | ||||
@@ -194,6 +202,8 @@ class BiaffineParser(GraphParser): | |||||
word_emb_dim, | word_emb_dim, | ||||
pos_vocab_size, | pos_vocab_size, | ||||
pos_emb_dim, | pos_emb_dim, | ||||
word_hid_dim, | |||||
pos_hid_dim, | |||||
rnn_layers, | rnn_layers, | ||||
rnn_hidden_size, | rnn_hidden_size, | ||||
arc_mlp_size, | arc_mlp_size, | ||||
@@ -204,10 +214,15 @@ class BiaffineParser(GraphParser): | |||||
use_greedy_infer=False): | use_greedy_infer=False): | ||||
super(BiaffineParser, self).__init__() | super(BiaffineParser, self).__init__() | ||||
rnn_out_size = 2 * rnn_hidden_size | |||||
self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) | self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) | ||||
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | ||||
self.word_fc = nn.Linear(word_emb_dim, word_hid_dim) | |||||
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) | |||||
self.word_norm = nn.LayerNorm(word_hid_dim) | |||||
self.pos_norm = nn.LayerNorm(pos_hid_dim) | |||||
if use_var_lstm: | if use_var_lstm: | ||||
self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim, | |||||
self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | hidden_size=rnn_hidden_size, | ||||
num_layers=rnn_layers, | num_layers=rnn_layers, | ||||
bias=True, | bias=True, | ||||
@@ -216,7 +231,7 @@ class BiaffineParser(GraphParser): | |||||
hidden_dropout=dropout, | hidden_dropout=dropout, | ||||
bidirectional=True) | bidirectional=True) | ||||
else: | else: | ||||
self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim, | |||||
self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | hidden_size=rnn_hidden_size, | ||||
num_layers=rnn_layers, | num_layers=rnn_layers, | ||||
bias=True, | bias=True, | ||||
@@ -224,141 +239,153 @@ class BiaffineParser(GraphParser): | |||||
dropout=dropout, | dropout=dropout, | ||||
bidirectional=True) | bidirectional=True) | ||||
rnn_out_size = 2 * rnn_hidden_size | |||||
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), | self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), | ||||
nn.ELU()) | |||||
nn.LayerNorm(arc_mlp_size), | |||||
nn.ELU(), | |||||
TimestepDropout(p=dropout),) | |||||
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) | self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) | ||||
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), | self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), | ||||
nn.ELU()) | |||||
nn.LayerNorm(label_mlp_size), | |||||
nn.ELU(), | |||||
TimestepDropout(p=dropout),) | |||||
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | ||||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | ||||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | ||||
self.normal_dropout = nn.Dropout(p=dropout) | self.normal_dropout = nn.Dropout(p=dropout) | ||||
self.timestep_dropout = TimestepDropout(p=dropout) | |||||
self.use_greedy_infer = use_greedy_infer | self.use_greedy_infer = use_greedy_infer | ||||
initial_parameter(self) | |||||
self.reset_parameters() | |||||
self.explore_p = 0.2 | |||||
def reset_parameters(self): | |||||
for m in self.modules(): | |||||
if isinstance(m, nn.Embedding): | |||||
continue | |||||
elif isinstance(m, nn.LayerNorm): | |||||
nn.init.constant_(m.weight, 0.1) | |||||
nn.init.constant_(m.bias, 0) | |||||
else: | |||||
for p in m.parameters(): | |||||
nn.init.normal_(p, 0, 0.1) | |||||
def forward(self, word_seq, pos_seq, seq_mask, gold_heads=None, **_): | |||||
def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_): | |||||
""" | """ | ||||
:param word_seq: [batch_size, seq_len] sequence of word's indices | :param word_seq: [batch_size, seq_len] sequence of word's indices | ||||
:param pos_seq: [batch_size, seq_len] sequence of word's indices | :param pos_seq: [batch_size, seq_len] sequence of word's indices | ||||
:param seq_mask: [batch_size, seq_len] sequence of length masks | |||||
:param word_seq_origin_len: [batch_size, seq_len] sequence of length masks | |||||
:param gold_heads: [batch_size, seq_len] sequence of golden heads | :param gold_heads: [batch_size, seq_len] sequence of golden heads | ||||
:return dict: parsing results | :return dict: parsing results | ||||
arc_pred: [batch_size, seq_len, seq_len] | arc_pred: [batch_size, seq_len, seq_len] | ||||
label_pred: [batch_size, seq_len, seq_len] | label_pred: [batch_size, seq_len, seq_len] | ||||
seq_mask: [batch_size, seq_len] | |||||
mask: [batch_size, seq_len] | |||||
head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads | head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads | ||||
""" | """ | ||||
# prepare embeddings | # prepare embeddings | ||||
device = self.parameters().__next__().device | |||||
word_seq = word_seq.long().to(device) | |||||
pos_seq = pos_seq.long().to(device) | |||||
word_seq_origin_len = word_seq_origin_len.long().to(device).view(-1) | |||||
batch_size, seq_len = word_seq.shape | batch_size, seq_len = word_seq.shape | ||||
# print('forward {} {}'.format(batch_size, seq_len)) | # print('forward {} {}'.format(batch_size, seq_len)) | ||||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) | |||||
# get sequence mask | # get sequence mask | ||||
seq_mask = seq_mask.long() | |||||
mask = seq_mask(word_seq_origin_len, seq_len).long() | |||||
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] | word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] | ||||
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] | pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] | ||||
word, pos = self.word_fc(word), self.pos_fc(pos) | |||||
word, pos = self.word_norm(word), self.pos_norm(pos) | |||||
x = torch.cat([word, pos], dim=2) # -> [N,L,C] | x = torch.cat([word, pos], dim=2) # -> [N,L,C] | ||||
del word, pos | |||||
# lstm, extract features | # lstm, extract features | ||||
sort_lens, sort_idx = torch.sort(word_seq_origin_len, dim=0, descending=True) | |||||
x = x[sort_idx] | |||||
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) | |||||
feat, _ = self.lstm(x) # -> [N,L,C] | feat, _ = self.lstm(x) # -> [N,L,C] | ||||
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) | |||||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||||
feat = feat[unsort_idx] | |||||
# for arc biaffine | # for arc biaffine | ||||
# mlp, reduce dim | # mlp, reduce dim | ||||
arc_dep = self.timestep_dropout(self.arc_dep_mlp(feat)) | |||||
arc_head = self.timestep_dropout(self.arc_head_mlp(feat)) | |||||
label_dep = self.timestep_dropout(self.label_dep_mlp(feat)) | |||||
label_head = self.timestep_dropout(self.label_head_mlp(feat)) | |||||
arc_dep = self.arc_dep_mlp(feat) | |||||
arc_head = self.arc_head_mlp(feat) | |||||
label_dep = self.label_dep_mlp(feat) | |||||
label_head = self.label_head_mlp(feat) | |||||
del feat | |||||
# biaffine arc classifier | # biaffine arc classifier | ||||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | ||||
flip_mask = (seq_mask == 0) | |||||
arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||||
# use gold or predicted arc to predict label | # use gold or predicted arc to predict label | ||||
if gold_heads is None: | |||||
if gold_heads is None or not self.training: | |||||
# use greedy decoding in training | # use greedy decoding in training | ||||
if self.training or self.use_greedy_infer: | if self.training or self.use_greedy_infer: | ||||
heads = self._greedy_decoder(arc_pred, seq_mask) | |||||
heads = self._greedy_decoder(arc_pred, mask) | |||||
else: | else: | ||||
heads = self._mst_decoder(arc_pred, seq_mask) | |||||
heads = self._mst_decoder(arc_pred, mask) | |||||
head_pred = heads | head_pred = heads | ||||
else: | else: | ||||
head_pred = None | |||||
heads = gold_heads | |||||
assert self.training # must be training mode | |||||
if torch.rand(1).item() < self.explore_p: | |||||
heads = self._greedy_decoder(arc_pred, mask) | |||||
head_pred = heads | |||||
else: | |||||
head_pred = None | |||||
heads = gold_heads | |||||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) | |||||
label_head = label_head[batch_range, heads].contiguous() | label_head = label_head[batch_range, heads].contiguous() | ||||
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] | label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] | ||||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask} | |||||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask} | |||||
if head_pred is not None: | if head_pred is not None: | ||||
res_dict['head_pred'] = head_pred | res_dict['head_pred'] = head_pred | ||||
return res_dict | return res_dict | ||||
def loss(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **_): | |||||
def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_): | |||||
""" | """ | ||||
Compute loss. | Compute loss. | ||||
:param arc_pred: [batch_size, seq_len, seq_len] | :param arc_pred: [batch_size, seq_len, seq_len] | ||||
:param label_pred: [batch_size, seq_len, seq_len] | |||||
:param label_pred: [batch_size, seq_len, n_tags] | |||||
:param head_indices: [batch_size, seq_len] | :param head_indices: [batch_size, seq_len] | ||||
:param head_labels: [batch_size, seq_len] | :param head_labels: [batch_size, seq_len] | ||||
:param seq_mask: [batch_size, seq_len] | |||||
:param mask: [batch_size, seq_len] | |||||
:return: loss value | :return: loss value | ||||
""" | """ | ||||
batch_size, seq_len, _ = arc_pred.shape | batch_size, seq_len, _ = arc_pred.shape | ||||
arc_logits = F.log_softmax(arc_pred, dim=2) | |||||
flip_mask = (mask == 0) | |||||
_arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred) | |||||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||||
arc_logits = F.log_softmax(_arc_pred, dim=2) | |||||
label_logits = F.log_softmax(label_pred, dim=2) | label_logits = F.log_softmax(label_pred, dim=2) | ||||
batch_index = torch.arange(start=0, end=batch_size, device=arc_logits.device).long().unsqueeze(1) | |||||
child_index = torch.arange(start=0, end=seq_len, device=arc_logits.device).long().unsqueeze(0) | |||||
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) | |||||
child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0) | |||||
arc_loss = arc_logits[batch_index, child_index, head_indices] | arc_loss = arc_logits[batch_index, child_index, head_indices] | ||||
label_loss = label_logits[batch_index, child_index, head_labels] | label_loss = label_logits[batch_index, child_index, head_labels] | ||||
arc_loss = arc_loss[:, 1:] | arc_loss = arc_loss[:, 1:] | ||||
label_loss = label_loss[:, 1:] | label_loss = label_loss[:, 1:] | ||||
float_mask = seq_mask[:, 1:].float() | |||||
length = (seq_mask.sum() - batch_size).float() | |||||
arc_nll = -(arc_loss*float_mask).sum() / length | |||||
label_nll = -(label_loss*float_mask).sum() / length | |||||
float_mask = mask[:, 1:].float() | |||||
arc_nll = -(arc_loss*float_mask).mean() | |||||
label_nll = -(label_loss*float_mask).mean() | |||||
return arc_nll + label_nll | return arc_nll + label_nll | ||||
def evaluate(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **kwargs): | |||||
""" | |||||
Evaluate the performance of prediction. | |||||
:return dict: performance results. | |||||
head_pred_corrct: number of correct predicted heads. | |||||
label_pred_correct: number of correct predicted labels. | |||||
total_tokens: number of predicted tokens | |||||
def predict(self, word_seq, pos_seq, word_seq_origin_len): | |||||
""" | """ | ||||
if 'head_pred' in kwargs: | |||||
head_pred = kwargs['head_pred'] | |||||
elif self.use_greedy_infer: | |||||
head_pred = self._greedy_decoder(arc_pred, seq_mask) | |||||
else: | |||||
head_pred = self._mst_decoder(arc_pred, seq_mask) | |||||
head_pred_correct = (head_pred == head_indices).long() * seq_mask | |||||
_, label_preds = torch.max(label_pred, dim=2) | |||||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct | |||||
return {"head_pred_correct": head_pred_correct.sum(dim=1), | |||||
"label_pred_correct": label_pred_correct.sum(dim=1), | |||||
"total_tokens": seq_mask.sum(dim=1)} | |||||
def metrics(self, head_pred_correct, label_pred_correct, total_tokens, **_): | |||||
""" | |||||
Compute the metrics of model | |||||
:param head_pred_corrct: number of correct predicted heads. | |||||
:param label_pred_correct: number of correct predicted labels. | |||||
:param total_tokens: number of predicted tokens | |||||
:return dict: the metrics results | |||||
UAS: the head predicted accuracy | |||||
LAS: the label predicted accuracy | |||||
:param word_seq: | |||||
:param pos_seq: | |||||
:param word_seq_origin_len: | |||||
:return: head_pred: [B, L] | |||||
label_pred: [B, L] | |||||
seq_len: [B,] | |||||
""" | """ | ||||
return {"UAS": head_pred_correct.sum().float() / total_tokens.sum().float() * 100, | |||||
"LAS": label_pred_correct.sum().float() / total_tokens.sum().float() * 100} | |||||
res = self(word_seq, pos_seq, word_seq_origin_len) | |||||
output = {} | |||||
output['head_pred'] = res.pop('head_pred') | |||||
_, label_pred = res.pop('label_pred').max(2) | |||||
output['label_pred'] = label_pred | |||||
return output |
@@ -15,33 +15,67 @@ class CNNText(torch.nn.Module): | |||||
Classification.' | Classification.' | ||||
""" | """ | ||||
def __init__(self, args): | |||||
def __init__(self, embed_num, | |||||
embed_dim, | |||||
num_classes, | |||||
kernel_nums=(3,4,5), | |||||
kernel_sizes=(3,4,5), | |||||
padding=0, | |||||
dropout=0.5): | |||||
super(CNNText, self).__init__() | super(CNNText, self).__init__() | ||||
num_classes = args["num_classes"] | |||||
kernel_nums = [100, 100, 100] | |||||
kernel_sizes = [3, 4, 5] | |||||
vocab_size = args["vocab_size"] | |||||
embed_dim = 300 | |||||
pretrained_embed = None | |||||
drop_prob = 0.5 | |||||
# no support for pre-trained embedding currently | # no support for pre-trained embedding currently | ||||
self.embed = encoder.embedding.Embedding(vocab_size, embed_dim) | |||||
self.conv_pool = encoder.conv_maxpool.ConvMaxpool( | |||||
self.embed = encoder.Embedding(embed_num, embed_dim) | |||||
self.conv_pool = encoder.ConvMaxpool( | |||||
in_channels=embed_dim, | in_channels=embed_dim, | ||||
out_channels=kernel_nums, | out_channels=kernel_nums, | ||||
kernel_sizes=kernel_sizes) | |||||
self.dropout = nn.Dropout(drop_prob) | |||||
self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes) | |||||
kernel_sizes=kernel_sizes, | |||||
padding=padding) | |||||
self.dropout = nn.Dropout(dropout) | |||||
self.fc = encoder.Linear(sum(kernel_nums), num_classes) | |||||
self._loss = nn.CrossEntropyLoss() | |||||
def forward(self, word_seq): | def forward(self, word_seq): | ||||
""" | """ | ||||
:param word_seq: torch.LongTensor, [batch_size, seq_len] | :param word_seq: torch.LongTensor, [batch_size, seq_len] | ||||
:return x: torch.LongTensor, [batch_size, num_classes] | |||||
:return output: dict of torch.LongTensor, [batch_size, num_classes] | |||||
""" | """ | ||||
x = self.embed(word_seq) # [N,L] -> [N,L,C] | x = self.embed(word_seq) # [N,L] -> [N,L,C] | ||||
x = self.conv_pool(x) # [N,L,C] -> [N,C] | x = self.conv_pool(x) # [N,L,C] -> [N,C] | ||||
x = self.dropout(x) | x = self.dropout(x) | ||||
x = self.fc(x) # [N,C] -> [N, N_class] | x = self.fc(x) # [N,C] -> [N, N_class] | ||||
return x | |||||
return {'output':x} | |||||
def predict(self, word_seq): | |||||
""" | |||||
:param word_seq: torch.LongTensor, [batch_size, seq_len] | |||||
:return predict: dict of torch.LongTensor, [batch_size, seq_len] | |||||
""" | |||||
output = self(word_seq) | |||||
_, predict = output['output'].max(dim=1) | |||||
return {'predict': predict} | |||||
def get_loss(self, output, label_seq): | |||||
""" | |||||
:param output: output of forward(), [batch_size, seq_len] | |||||
:param label_seq: true label in DataSet, [batch_size, seq_len] | |||||
:return loss: torch.Tensor | |||||
""" | |||||
return self._loss(output, label_seq) | |||||
def evaluate(self, predict, label_seq): | |||||
""" | |||||
:param predict: iterable predict tensors | |||||
:param label_seq: iterable true label tensors | |||||
:return accuracy: dict of float | |||||
""" | |||||
predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0) | |||||
predict, label_seq = predict.squeeze(), label_seq.squeeze() | |||||
correct = (predict == label_seq).long().sum().item() | |||||
total = label_seq.size(0) | |||||
return {'acc': 1.0 * correct / total} | |||||
@@ -1,21 +1,9 @@ | |||||
import torch | import torch | ||||
import numpy as np | |||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules import decoder, encoder | from fastNLP.modules import decoder, encoder | ||||
def seq_mask(seq_len, max_len): | |||||
"""Create a mask for the sequences. | |||||
:param seq_len: list or torch.LongTensor | |||||
:param max_len: int | |||||
:return mask: torch.LongTensor | |||||
""" | |||||
if isinstance(seq_len, list): | |||||
seq_len = torch.LongTensor(seq_len) | |||||
mask = [torch.ge(seq_len, i + 1) for i in range(max_len)] | |||||
mask = torch.stack(mask, 1) | |||||
return mask | |||||
from fastNLP.modules.utils import seq_mask | |||||
class SeqLabeling(BaseModel): | class SeqLabeling(BaseModel): | ||||
@@ -44,6 +32,9 @@ class SeqLabeling(BaseModel): | |||||
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | ||||
If truth is not None, return loss, a scalar. Used in training. | If truth is not None, return loss, a scalar. Used in training. | ||||
""" | """ | ||||
assert word_seq.shape[0] == word_seq_origin_len.shape[0] | |||||
if truth is not None: | |||||
assert truth.shape == word_seq.shape | |||||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | self.mask = self.make_mask(word_seq, word_seq_origin_len) | ||||
x = self.Embedding(word_seq) | x = self.Embedding(word_seq) | ||||
@@ -52,10 +43,8 @@ class SeqLabeling(BaseModel): | |||||
# [batch_size, max_len, hidden_size * direction] | # [batch_size, max_len, hidden_size * direction] | ||||
x = self.Linear(x) | x = self.Linear(x) | ||||
# [batch_size, max_len, num_classes] | # [batch_size, max_len, num_classes] | ||||
if truth is not None: | |||||
return self._internal_loss(x, truth) | |||||
else: | |||||
return self.decode(x) | |||||
return {"loss": self._internal_loss(x, truth) if truth is not None else None, | |||||
"predict": self.decode(x)} | |||||
def loss(self, x, y): | def loss(self, x, y): | ||||
""" Since the loss has been computed in forward(), this function simply returns x.""" | """ Since the loss has been computed in forward(), this function simply returns x.""" | ||||
@@ -79,8 +68,8 @@ class SeqLabeling(BaseModel): | |||||
def make_mask(self, x, seq_len): | def make_mask(self, x, seq_len): | ||||
batch_size, max_len = x.size(0), x.size(1) | batch_size, max_len = x.size(0), x.size(1) | ||||
mask = seq_mask(seq_len, max_len) | mask = seq_mask(seq_len, max_len) | ||||
mask = mask.byte().view(batch_size, max_len) | |||||
mask = mask.to(x) | |||||
mask = mask.view(batch_size, max_len) | |||||
mask = mask.to(x).float() | |||||
return mask | return mask | ||||
def decode(self, x, pad=True): | def decode(self, x, pad=True): | ||||
@@ -111,42 +100,119 @@ class AdvSeqLabel(SeqLabeling): | |||||
word_emb_dim = args["word_emb_dim"] | word_emb_dim = args["word_emb_dim"] | ||||
hidden_dim = args["rnn_hidden_units"] | hidden_dim = args["rnn_hidden_units"] | ||||
num_classes = args["num_classes"] | num_classes = args["num_classes"] | ||||
dropout = args['dropout'] | |||||
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | ||||
self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=3, dropout=0.3, bidirectional=True) | |||||
self.norm1 = torch.nn.LayerNorm(word_emb_dim) | |||||
# self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True) | |||||
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True) | |||||
self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | ||||
self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | |||||
self.relu = torch.nn.ReLU() | |||||
self.drop = torch.nn.Dropout(0.3) | |||||
self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3) | |||||
# self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | |||||
self.relu = torch.nn.LeakyReLU() | |||||
self.drop = torch.nn.Dropout(dropout) | |||||
self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) | self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) | ||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | |||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) | |||||
def forward(self, word_seq, word_seq_origin_len, truth=None): | def forward(self, word_seq, word_seq_origin_len, truth=None): | ||||
""" | """ | ||||
:param word_seq: LongTensor, [batch_size, mex_len] | :param word_seq: LongTensor, [batch_size, mex_len] | ||||
:param word_seq_origin_len: list of int. | |||||
:param word_seq_origin_len: LongTensor, [batch_size, ] | |||||
:param truth: LongTensor, [batch_size, max_len] | :param truth: LongTensor, [batch_size, max_len] | ||||
:return y: | |||||
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | |||||
If truth is not None, return loss, a scalar. Used in training. | |||||
""" | """ | ||||
word_seq = word_seq.long() | |||||
word_seq_origin_len = word_seq_origin_len.long() | |||||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | self.mask = self.make_mask(word_seq, word_seq_origin_len) | ||||
sent_len, idx_sort = torch.sort(word_seq_origin_len, descending=True) | |||||
_, idx_unsort = torch.sort(idx_sort, descending=False) | |||||
# word_seq_origin_len = word_seq_origin_len.long() | |||||
truth = truth.long() if truth is not None else None | |||||
batch_size = word_seq.size(0) | batch_size = word_seq.size(0) | ||||
max_len = word_seq.size(1) | max_len = word_seq.size(1) | ||||
if next(self.parameters()).is_cuda: | |||||
word_seq = word_seq.cuda() | |||||
idx_sort = idx_sort.cuda() | |||||
idx_unsort = idx_unsort.cuda() | |||||
self.mask = self.mask.cuda() | |||||
x = self.Embedding(word_seq) | x = self.Embedding(word_seq) | ||||
x = self.norm1(x) | |||||
# [batch_size, max_len, word_emb_dim] | # [batch_size, max_len, word_emb_dim] | ||||
x = self.Rnn(x) | |||||
sent_variable = x[idx_sort] | |||||
sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True) | |||||
x, _ = self.Rnn(sent_packed) | |||||
# print(x) | |||||
# [batch_size, max_len, hidden_size * direction] | # [batch_size, max_len, hidden_size * direction] | ||||
sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0] | |||||
x = sent_output[idx_unsort] | |||||
x = x.contiguous() | x = x.contiguous() | ||||
x = x.view(batch_size * max_len, -1) | |||||
# x = x.view(batch_size * max_len, -1) | |||||
x = self.Linear1(x) | x = self.Linear1(x) | ||||
x = self.batch_norm(x) | |||||
# x = self.batch_norm(x) | |||||
x = self.norm2(x) | |||||
x = self.relu(x) | x = self.relu(x) | ||||
x = self.drop(x) | x = self.drop(x) | ||||
x = self.Linear2(x) | x = self.Linear2(x) | ||||
x = x.view(batch_size, max_len, -1) | |||||
# x = x.view(batch_size, max_len, -1) | |||||
# [batch_size, max_len, num_classes] | # [batch_size, max_len, num_classes] | ||||
if truth is not None: | |||||
return self._internal_loss(x, truth) | |||||
else: | |||||
return self.decode(x) | |||||
# TODO seq_lens的key这样做不合理 | |||||
return {"loss": self._internal_loss(x, truth) if truth is not None else None, | |||||
"predict": self.decode(x), | |||||
'word_seq_origin_len': word_seq_origin_len} | |||||
def predict(self, **x): | |||||
out = self.forward(**x) | |||||
return {"predict": out["predict"]} | |||||
def loss(self, **kwargs): | |||||
assert 'loss' in kwargs | |||||
return kwargs['loss'] | |||||
if __name__ == '__main__': | |||||
args = { | |||||
'vocab_size': 20, | |||||
'word_emb_dim': 100, | |||||
'rnn_hidden_units': 100, | |||||
'num_classes': 10, | |||||
} | |||||
model = AdvSeqLabel(args) | |||||
data = [] | |||||
for i in range(20): | |||||
word_seq = torch.randint(20, (15,)).long() | |||||
word_seq_len = torch.LongTensor([15]) | |||||
truth = torch.randint(10, (15,)).long() | |||||
data.append((word_seq, word_seq_len, truth)) | |||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) | |||||
print(model) | |||||
curidx = 0 | |||||
for i in range(1000): | |||||
endidx = min(len(data), curidx + 5) | |||||
b_word, b_len, b_truth = [], [], [] | |||||
for word_seq, word_seq_len, truth in data[curidx: endidx]: | |||||
b_word.append(word_seq) | |||||
b_len.append(word_seq_len) | |||||
b_truth.append(truth) | |||||
word_seq = torch.stack(b_word, dim=0) | |||||
word_seq_len = torch.cat(b_len, dim=0) | |||||
truth = torch.stack(b_truth, dim=0) | |||||
res = model(word_seq, word_seq_len, truth) | |||||
loss = res['loss'] | |||||
pred = res['predict'] | |||||
print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
curidx = endidx | |||||
if curidx == len(data): | |||||
curidx = 0 | |||||
@@ -1,11 +1,14 @@ | |||||
from . import aggregator | from . import aggregator | ||||
from . import decoder | from . import decoder | ||||
from . import encoder | from . import encoder | ||||
from . import interactor | |||||
from .aggregator import * | |||||
from .decoder import * | |||||
from .encoder import * | |||||
from .dropout import TimestepDropout | |||||
__version__ = '0.0.0' | __version__ = '0.0.0' | ||||
__all__ = ['encoder', | __all__ = ['encoder', | ||||
'decoder', | 'decoder', | ||||
'aggregator', | 'aggregator', | ||||
'interactor'] | |||||
'TimestepDropout'] |
@@ -1,5 +1,7 @@ | |||||
from .max_pool import MaxPool | from .max_pool import MaxPool | ||||
from .avg_pool import AvgPool | |||||
from .kmax_pool import KMaxPool | |||||
from .attention import Attention | |||||
from .self_attention import SelfAttention | |||||
__all__ = [ | |||||
'MaxPool' | |||||
] |
@@ -1,5 +1,6 @@ | |||||
import torch | import torch | ||||
from torch import nn | |||||
import math | |||||
from fastNLP.modules.utils import mask_softmax | from fastNLP.modules.utils import mask_softmax | ||||
@@ -17,3 +18,47 @@ class Attention(torch.nn.Module): | |||||
def _atten_forward(self, query, memory): | def _atten_forward(self, query, memory): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
class DotAtte(nn.Module): | |||||
def __init__(self, key_size, value_size): | |||||
# TODO never test | |||||
super(DotAtte, self).__init__() | |||||
self.key_size = key_size | |||||
self.value_size = value_size | |||||
self.scale = math.sqrt(key_size) | |||||
def forward(self, Q, K, V, seq_mask=None): | |||||
""" | |||||
:param Q: [batch, seq_len, key_size] | |||||
:param K: [batch, seq_len, key_size] | |||||
:param V: [batch, seq_len, value_size] | |||||
:param seq_mask: [batch, seq_len] | |||||
""" | |||||
output = torch.matmul(Q, K.transpose(1, 2)) / self.scale | |||||
if seq_mask is not None: | |||||
output.masked_fill_(seq_mask.lt(1), -float('inf')) | |||||
output = nn.functional.softmax(output, dim=2) | |||||
return torch.matmul(output, V) | |||||
class MultiHeadAtte(nn.Module): | |||||
def __init__(self, input_size, output_size, key_size, value_size, num_atte): | |||||
raise NotImplementedError | |||||
# TODO never test | |||||
super(MultiHeadAtte, self).__init__() | |||||
self.in_linear = nn.ModuleList() | |||||
for i in range(num_atte * 3): | |||||
out_feat = key_size if (i % 3) != 2 else value_size | |||||
self.in_linear.append(nn.Linear(input_size, out_feat)) | |||||
self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)]) | |||||
self.out_linear = nn.Linear(value_size * num_atte, output_size) | |||||
def forward(self, Q, K, V, seq_mask=None): | |||||
heads = [] | |||||
for i in range(len(self.attes)): | |||||
j = i * 3 | |||||
qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V) | |||||
headi = self.attes[i](qi, ki, vi, seq_mask) | |||||
heads.append(headi) | |||||
output = torch.cat(heads, dim=2) | |||||
return self.out_linear(output) |
@@ -3,6 +3,7 @@ from torch import nn | |||||
from fastNLP.modules.utils import initial_parameter | from fastNLP.modules.utils import initial_parameter | ||||
def log_sum_exp(x, dim=-1): | def log_sum_exp(x, dim=-1): | ||||
max_value, _ = x.max(dim=dim, keepdim=True) | max_value, _ = x.max(dim=dim, keepdim=True) | ||||
res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value | res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value | ||||
@@ -20,7 +21,7 @@ def seq_len_to_byte_mask(seq_lens): | |||||
class ConditionalRandomField(nn.Module): | class ConditionalRandomField(nn.Module): | ||||
def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None): | |||||
def __init__(self, tag_size, include_start_end_trans=False ,initial_method = None): | |||||
""" | """ | ||||
:param tag_size: int, num of tags | :param tag_size: int, num of tags | ||||
:param include_start_end_trans: bool, whether to include start/end tag | :param include_start_end_trans: bool, whether to include start/end tag | ||||
@@ -31,7 +32,7 @@ class ConditionalRandomField(nn.Module): | |||||
self.tag_size = tag_size | self.tag_size = tag_size | ||||
# the meaning of entry in this matrix is (from_tag_id, to_tag_id) score | # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score | ||||
self.transition_m = nn.Parameter(torch.randn(tag_size, tag_size)) | |||||
self.trans_m = nn.Parameter(torch.randn(tag_size, tag_size)) | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
self.start_scores = nn.Parameter(torch.randn(tag_size)) | self.start_scores = nn.Parameter(torch.randn(tag_size)) | ||||
self.end_scores = nn.Parameter(torch.randn(tag_size)) | self.end_scores = nn.Parameter(torch.randn(tag_size)) | ||||
@@ -39,137 +40,121 @@ class ConditionalRandomField(nn.Module): | |||||
# self.reset_parameter() | # self.reset_parameter() | ||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def reset_parameter(self): | def reset_parameter(self): | ||||
nn.init.xavier_normal_(self.transition_m) | |||||
nn.init.xavier_normal_(self.trans_m) | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
nn.init.normal_(self.start_scores) | nn.init.normal_(self.start_scores) | ||||
nn.init.normal_(self.end_scores) | nn.init.normal_(self.end_scores) | ||||
def _normalizer_likelihood(self, feats, masks): | |||||
def _normalizer_likelihood(self, logits, mask): | |||||
""" | """ | ||||
Computes the (batch_size,) denominator term for the log-likelihood, which is the | Computes the (batch_size,) denominator term for the log-likelihood, which is the | ||||
sum of the likelihoods across all possible state sequences. | sum of the likelihoods across all possible state sequences. | ||||
:param feats:FloatTensor, batch_size x max_len x tag_size | |||||
:param masks:ByteTensor, batch_size x max_len | |||||
:param logits:FloatTensor, max_len x batch_size x tag_size | |||||
:param mask:ByteTensor, max_len x batch_size | |||||
:return:FloatTensor, batch_size | :return:FloatTensor, batch_size | ||||
""" | """ | ||||
batch_size, max_len, _ = feats.size() | |||||
# alpha, batch_size x tag_size | |||||
seq_len, batch_size, n_tags = logits.size() | |||||
alpha = logits[0] | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
alpha = self.start_scores.view(1, -1) + feats[:, 0] | |||||
else: | |||||
alpha = feats[:, 0] | |||||
# broadcast_trans_m, the meaning of entry in this matrix is [batch_idx, to_tag_id, from_tag_id] | |||||
broadcast_trans_m = self.transition_m.permute( | |||||
1, 0).unsqueeze(0).repeat(batch_size, 1, 1) | |||||
# loop | |||||
for i in range(1, max_len): | |||||
emit_score = feats[:, i].unsqueeze(2) | |||||
new_alpha = broadcast_trans_m + alpha.unsqueeze(1) + emit_score | |||||
alpha += self.start_scores.view(1, -1) | |||||
new_alpha = log_sum_exp(new_alpha, dim=2) | |||||
alpha = new_alpha * \ | |||||
masks[:, i:i + 1].float() + alpha * \ | |||||
(1 - masks[:, i:i + 1].float()) | |||||
for i in range(1, seq_len): | |||||
emit_score = logits[i].view(batch_size, 1, n_tags) | |||||
trans_score = self.trans_m.view(1, n_tags, n_tags) | |||||
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score | |||||
alpha = log_sum_exp(tmp, 1) * mask[i].view(batch_size, 1) + alpha * (1 - mask[i]).view(batch_size, 1) | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
alpha = alpha + self.end_scores.view(1, -1) | |||||
alpha += self.end_scores.view(1, -1) | |||||
return log_sum_exp(alpha) | |||||
return log_sum_exp(alpha, 1) | |||||
def _glod_score(self, feats, tags, masks): | |||||
def _glod_score(self, logits, tags, mask): | |||||
""" | """ | ||||
Compute the score for the gold path. | Compute the score for the gold path. | ||||
:param feats: FloatTensor, batch_size x max_len x tag_size | |||||
:param tags: LongTensor, batch_size x max_len | |||||
:param masks: ByteTensor, batch_size x max_len | |||||
:param logits: FloatTensor, max_len x batch_size x tag_size | |||||
:param tags: LongTensor, max_len x batch_size | |||||
:param mask: ByteTensor, max_len x batch_size | |||||
:return:FloatTensor, batch_size | :return:FloatTensor, batch_size | ||||
""" | """ | ||||
batch_size, max_len, _ = feats.size() | |||||
# alpha, B x 1 | |||||
if self.include_start_end_trans: | |||||
alpha = self.start_scores.view(1, -1).repeat(batch_size, 1).gather(dim=1, index=tags[:, :1]) + \ | |||||
feats[:, 0].gather(dim=1, index=tags[:, :1]) | |||||
else: | |||||
alpha = feats[:, 0].gather(dim=1, index=tags[:, :1]) | |||||
for i in range(1, max_len): | |||||
trans_score = self.transition_m[( | |||||
tags[:, i - 1], tags[:, i])].unsqueeze(1) | |||||
emit_score = feats[:, i].gather(dim=1, index=tags[:, i:i + 1]) | |||||
new_alpha = alpha + trans_score + emit_score | |||||
alpha = new_alpha * \ | |||||
masks[:, i:i + 1].float() + alpha * \ | |||||
(1 - masks[:, i:i + 1].float()) | |||||
seq_len, batch_size, _ = logits.size() | |||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | |||||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | |||||
# trans_socre [L-1, B] | |||||
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :] | |||||
# emit_score [L, B] | |||||
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask | |||||
# score [L-1, B] | |||||
score = trans_score + emit_score[:seq_len-1, :] | |||||
score = score.sum(0) + emit_score[-1] * mask[-1] | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
last_tag_index = masks.cumsum(dim=1, dtype=torch.long)[:, -1:] - 1 | |||||
last_from_tag_id = tags.gather(dim=1, index=last_tag_index) | |||||
trans_score = self.end_scores.view( | |||||
1, -1).repeat(batch_size, 1).gather(dim=1, index=last_from_tag_id) | |||||
alpha = alpha + trans_score | |||||
return alpha.squeeze(1) | |||||
def forward(self, feats, tags, masks): | |||||
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | |||||
last_idx = mask.long().sum(0) - 1 | |||||
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | |||||
score += st_scores + ed_scores | |||||
# return [B,] | |||||
return score | |||||
def forward(self, feats, tags, mask): | |||||
""" | """ | ||||
Calculate the neg log likelihood | Calculate the neg log likelihood | ||||
:param feats:FloatTensor, batch_size x max_len x tag_size | :param feats:FloatTensor, batch_size x max_len x tag_size | ||||
:param tags:LongTensor, batch_size x max_len | :param tags:LongTensor, batch_size x max_len | ||||
:param masks:ByteTensor batch_size x max_len | |||||
:param mask:ByteTensor batch_size x max_len | |||||
:return:FloatTensor, batch_size | :return:FloatTensor, batch_size | ||||
""" | """ | ||||
all_path_score = self._normalizer_likelihood(feats, masks) | |||||
gold_path_score = self._glod_score(feats, tags, masks) | |||||
feats = feats.transpose(0, 1) | |||||
tags = tags.transpose(0, 1).long() | |||||
mask = mask.transpose(0, 1).float() | |||||
all_path_score = self._normalizer_likelihood(feats, mask) | |||||
gold_path_score = self._glod_score(feats, tags, mask) | |||||
return all_path_score - gold_path_score | return all_path_score - gold_path_score | ||||
def viterbi_decode(self, feats, masks, get_score=False): | |||||
def viterbi_decode(self, data, mask, get_score=False): | |||||
""" | """ | ||||
Given a feats matrix, return best decode path and best score. | Given a feats matrix, return best decode path and best score. | ||||
:param feats: | |||||
:param masks: | |||||
:param data:FloatTensor, batch_size x max_len x tag_size | |||||
:param mask:ByteTensor batch_size x max_len | |||||
:param get_score: bool, whether to output the decode score. | :param get_score: bool, whether to output the decode score. | ||||
:return:List[Tuple(List, float)], | |||||
:return: scores, paths | |||||
""" | """ | ||||
batch_size, max_len, tag_size = feats.size() | |||||
batch_size, seq_len, n_tags = data.size() | |||||
data = data.transpose(0, 1).data # L, B, H | |||||
mask = mask.transpose(0, 1).data.float() # L, B | |||||
paths = torch.zeros(batch_size, max_len - 1, self.tag_size) | |||||
# dp | |||||
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||||
vscore = data[0] | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
alpha = self.start_scores.repeat(batch_size, 1) + feats[:, 0] | |||||
else: | |||||
alpha = feats[:, 0] | |||||
for i in range(1, max_len): | |||||
new_alpha = alpha.clone() | |||||
for t in range(self.tag_size): | |||||
pre_scores = self.transition_m[:, t].view( | |||||
1, self.tag_size) + alpha | |||||
max_score, indices = pre_scores.max(dim=1) | |||||
new_alpha[:, t] = max_score + feats[:, i, t] | |||||
paths[:, i - 1, t] = indices | |||||
alpha = new_alpha * masks[:, i:i + 1].float() + alpha * (1 - masks[:, i:i + 1].float()) | |||||
vscore += self.start_scores.view(1, -1) | |||||
for i in range(1, seq_len): | |||||
prev_score = vscore.view(batch_size, n_tags, 1) | |||||
cur_score = data[i].view(batch_size, 1, n_tags) | |||||
trans_score = self.trans_m.view(1, n_tags, n_tags).data | |||||
score = prev_score + trans_score + cur_score | |||||
best_score, best_dst = score.max(1) | |||||
vpath[i] = best_dst | |||||
vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1) | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
alpha += self.end_scores.view(1, -1) | |||||
max_scores, indices = alpha.max(dim=1) | |||||
indices = indices.cpu().numpy() | |||||
final_paths = [] | |||||
paths = paths.cpu().numpy().astype(int) | |||||
seq_lens = masks.cumsum(dim=1, dtype=torch.long)[:, -1] | |||||
vscore += self.end_scores.view(1, -1) | |||||
# backtrace | |||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device) | |||||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device) | |||||
lens = (mask.long().sum(0) - 1) | |||||
# idxes [L, B], batched idx from seq_len-1 to 0 | |||||
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len | |||||
ans = data.new_empty((seq_len, batch_size), dtype=torch.long) | |||||
ans_score, last_tags = vscore.max(1) | |||||
ans[idxes[0], batch_idx] = last_tags | |||||
for i in range(seq_len - 1): | |||||
last_tags = vpath[idxes[i], batch_idx, last_tags] | |||||
ans[idxes[i+1], batch_idx] = last_tags | |||||
for b in range(batch_size): | |||||
path = [indices[b]] | |||||
for i in range(seq_lens[b] - 2, -1, -1): | |||||
index = paths[b, i, path[-1]] | |||||
path.append(index) | |||||
final_paths.append(path[::-1]) | |||||
if get_score: | if get_score: | ||||
return list(zip(final_paths, max_scores.detach().cpu().numpy())) | |||||
else: | |||||
return final_paths | |||||
return ans_score, ans.transpose(0, 1) | |||||
return ans.transpose(0, 1) |
@@ -1,13 +1,15 @@ | |||||
import torch | import torch | ||||
class TimestepDropout(torch.nn.Dropout): | class TimestepDropout(torch.nn.Dropout): | ||||
"""This module accepts a `[batch_size, num_timesteps, embedding_dim)]` and use a single | """This module accepts a `[batch_size, num_timesteps, embedding_dim)]` and use a single | ||||
dropout mask of shape `(batch_size, embedding_dim)` to apply on every time step. | dropout mask of shape `(batch_size, embedding_dim)` to apply on every time step. | ||||
""" | """ | ||||
def forward(self, x): | def forward(self, x): | ||||
dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) | dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) | ||||
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) | torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) | ||||
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim] | |||||
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim] | |||||
if self.inplace: | if self.inplace: | ||||
x *= dropout_mask | x *= dropout_mask | ||||
return | return | ||||
@@ -34,8 +34,6 @@ class ConvMaxpool(nn.Module): | |||||
bias=bias) | bias=bias) | ||||
for oc, ks in zip(out_channels, kernel_sizes)]) | for oc, ks in zip(out_channels, kernel_sizes)]) | ||||
for conv in self.convs: | |||||
xavier_uniform_(conv.weight) # weight initialization | |||||
else: | else: | ||||
raise Exception( | raise Exception( | ||||
'Incorrect kernel sizes: should be list, tuple or int') | 'Incorrect kernel sizes: should be list, tuple or int') | ||||
@@ -0,0 +1,32 @@ | |||||
import torch | |||||
from torch import nn | |||||
import torch.nn.functional as F | |||||
from ..aggregator.attention import MultiHeadAtte | |||||
from ..other_modules import LayerNormalization | |||||
class TransformerEncoder(nn.Module): | |||||
class SubLayer(nn.Module): | |||||
def __init__(self, input_size, output_size, key_size, value_size, num_atte): | |||||
super(TransformerEncoder.SubLayer, self).__init__() | |||||
self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte) | |||||
self.norm1 = LayerNormalization(output_size) | |||||
self.ffn = nn.Sequential(nn.Linear(output_size, output_size), | |||||
nn.ReLU(), | |||||
nn.Linear(output_size, output_size)) | |||||
self.norm2 = LayerNormalization(output_size) | |||||
def forward(self, input, seq_mask): | |||||
attention = self.atte(input) | |||||
norm_atte = self.norm1(attention + input) | |||||
output = self.ffn(norm_atte) | |||||
return self.norm2(output + norm_atte) | |||||
def __init__(self, num_layers, **kargs): | |||||
super(TransformerEncoder, self).__init__() | |||||
self.layers = nn.Sequential(*[self.SubLayer(**kargs) for _ in range(num_layers)]) | |||||
def forward(self, x, seq_mask=None): | |||||
return self.layers(x, seq_mask) | |||||
@@ -101,14 +101,14 @@ class VarRNNBase(nn.Module): | |||||
mask_x = input.new_ones((batch_size, self.input_size)) | mask_x = input.new_ones((batch_size, self.input_size)) | ||||
mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions)) | mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions)) | ||||
mask_h = input.new_ones((batch_size, self.hidden_size)) | |||||
mask_h_ones = input.new_ones((batch_size, self.hidden_size)) | |||||
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | ||||
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | ||||
nn.functional.dropout(mask_h, p=self.hidden_dropout, training=self.training, inplace=True) | |||||
hidden_list = [] | hidden_list = [] | ||||
for layer in range(self.num_layers): | for layer in range(self.num_layers): | ||||
output_list = [] | output_list = [] | ||||
mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | |||||
for direction in range(self.num_directions): | for direction in range(self.num_directions): | ||||
input_x = input if direction == 0 else flip(input, [0]) | input_x = input if direction == 0 else flip(input, [0]) | ||||
idx = self.num_directions * layer + direction | idx = self.num_directions * layer + direction | ||||
@@ -31,12 +31,12 @@ class GroupNorm(nn.Module): | |||||
class LayerNormalization(nn.Module): | class LayerNormalization(nn.Module): | ||||
""" Layer normalization module """ | """ Layer normalization module """ | ||||
def __init__(self, d_hid, eps=1e-3): | |||||
def __init__(self, layer_size, eps=1e-3): | |||||
super(LayerNormalization, self).__init__() | super(LayerNormalization, self).__init__() | ||||
self.eps = eps | self.eps = eps | ||||
self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True) | |||||
self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True) | |||||
self.a_2 = nn.Parameter(torch.ones(1, layer_size, requires_grad=True)) | |||||
self.b_2 = nn.Parameter(torch.zeros(1, layer_size, requires_grad=True)) | |||||
def forward(self, z): | def forward(self, z): | ||||
if z.size(1) == 1: | if z.size(1) == 1: | ||||
@@ -44,9 +44,8 @@ class LayerNormalization(nn.Module): | |||||
mu = torch.mean(z, keepdim=True, dim=-1) | mu = torch.mean(z, keepdim=True, dim=-1) | ||||
sigma = torch.std(z, keepdim=True, dim=-1) | sigma = torch.std(z, keepdim=True, dim=-1) | ||||
ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps) | |||||
ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out) | |||||
ln_out = (z - mu) / (sigma + self.eps) | |||||
ln_out = ln_out * self.a_2 + self.b_2 | |||||
return ln_out | return ln_out | ||||
@@ -77,11 +77,13 @@ def initial_parameter(net, initial_method=None): | |||||
def seq_mask(seq_len, max_len): | def seq_mask(seq_len, max_len): | ||||
"""Create sequence mask. | """Create sequence mask. | ||||
:param seq_len: list of int, the lengths of sequences in a batch. | |||||
:param seq_len: list or torch.Tensor, the lengths of sequences in a batch. | |||||
:param max_len: int, the maximum sequence length in a batch. | :param max_len: int, the maximum sequence length in a batch. | ||||
:return mask: torch.LongTensor, [batch_size, max_len] | :return mask: torch.LongTensor, [batch_size, max_len] | ||||
""" | """ | ||||
mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)] | |||||
mask = torch.stack(mask, 1) | |||||
return mask | |||||
if not isinstance(seq_len, torch.Tensor): | |||||
seq_len = torch.LongTensor(seq_len) | |||||
seq_len = seq_len.view(-1, 1).long() # [batch_size, 1] | |||||
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len] | |||||
return torch.gt(seq_len, seq_range) # [batch_size, max_len] |
@@ -1,37 +1,40 @@ | |||||
[train] | [train] | ||||
epochs = 50 | |||||
epochs = -1 | |||||
batch_size = 16 | batch_size = 16 | ||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
validate = true | validate = true | ||||
save_best_dev = false | |||||
save_best_dev = true | |||||
eval_sort_key = "UAS" | |||||
use_cuda = true | use_cuda = true | ||||
model_saved_path = "./save/" | model_saved_path = "./save/" | ||||
task = "parse" | |||||
print_every_step = 20 | |||||
use_golden_train=true | |||||
[test] | [test] | ||||
save_output = true | save_output = true | ||||
validate_in_training = true | validate_in_training = true | ||||
save_dev_input = false | save_dev_input = false | ||||
save_loss = true | save_loss = true | ||||
batch_size = 16 | |||||
batch_size = 64 | |||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
use_cuda = true | use_cuda = true | ||||
task = "parse" | |||||
[model] | [model] | ||||
word_vocab_size = -1 | word_vocab_size = -1 | ||||
word_emb_dim = 100 | word_emb_dim = 100 | ||||
pos_vocab_size = -1 | pos_vocab_size = -1 | ||||
pos_emb_dim = 100 | pos_emb_dim = 100 | ||||
word_hid_dim = 100 | |||||
pos_hid_dim = 100 | |||||
rnn_layers = 3 | rnn_layers = 3 | ||||
rnn_hidden_size = 400 | rnn_hidden_size = 400 | ||||
arc_mlp_size = 500 | arc_mlp_size = 500 | ||||
label_mlp_size = 100 | label_mlp_size = 100 | ||||
num_label = -1 | num_label = -1 | ||||
dropout = 0.33 | dropout = 0.33 | ||||
use_var_lstm=true | |||||
use_var_lstm=false | |||||
use_greedy_infer=false | use_greedy_infer=false | ||||
[optim] | [optim] | ||||
lr = 2e-3 | lr = 2e-3 | ||||
weight_decay = 5e-5 |
@@ -0,0 +1,83 @@ | |||||
import os | |||||
import sys | |||||
sys.path.extend(['/home/yfshao/workdir/dev_fastnlp']) | |||||
from fastNLP.api.processor import * | |||||
from fastNLP.models.biaffine_parser import BiaffineParser | |||||
from fastNLP.io.config_loader import ConfigSection, ConfigLoader | |||||
import _pickle as pickle | |||||
import torch | |||||
def _load(path): | |||||
with open(path, 'rb') as f: | |||||
obj = pickle.load(f) | |||||
return obj | |||||
def _load_all(src): | |||||
model_path = src | |||||
src = os.path.dirname(src) | |||||
word_v = _load(src+'/word_v.pkl') | |||||
pos_v = _load(src+'/pos_v.pkl') | |||||
tag_v = _load(src+'/tag_v.pkl') | |||||
pos_pp = torch.load(src+'/pos_pp.pkl')['pipeline'] | |||||
model_args = ConfigSection() | |||||
ConfigLoader.load_config('cfg.cfg', {'model': model_args}) | |||||
model_args['word_vocab_size'] = len(word_v) | |||||
model_args['pos_vocab_size'] = len(pos_v) | |||||
model_args['num_label'] = len(tag_v) | |||||
model = BiaffineParser(**model_args.data) | |||||
model.load_state_dict(torch.load(model_path)) | |||||
return { | |||||
'word_v': word_v, | |||||
'pos_v': pos_v, | |||||
'tag_v': tag_v, | |||||
'model': model, | |||||
'pos_pp':pos_pp, | |||||
} | |||||
def build(load_path, save_path): | |||||
BOS = '<BOS>' | |||||
NUM = '<NUM>' | |||||
_dict = _load_all(load_path) | |||||
word_vocab = _dict['word_v'] | |||||
pos_vocab = _dict['pos_v'] | |||||
tag_vocab = _dict['tag_v'] | |||||
pos_pp = _dict['pos_pp'] | |||||
model = _dict['model'] | |||||
print('load model from {}'.format(load_path)) | |||||
word_seq = 'raw_word_seq' | |||||
pos_seq = 'raw_pos_seq' | |||||
# build pipeline | |||||
# input | |||||
pipe = pos_pp | |||||
pipe.pipeline.pop(-1) | |||||
pipe.add_processor(Num2TagProcessor(NUM, 'word_list', word_seq)) | |||||
pipe.add_processor(PreAppendProcessor(BOS, word_seq)) | |||||
pipe.add_processor(PreAppendProcessor(BOS, 'pos_list', pos_seq)) | |||||
pipe.add_processor(IndexerProcessor(word_vocab, word_seq, 'word_seq')) | |||||
pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, 'pos_seq')) | |||||
pipe.add_processor(SeqLenProcessor('word_seq', 'word_seq_origin_len')) | |||||
pipe.add_processor(SetTensorProcessor({'word_seq':True, 'pos_seq':True, 'word_seq_origin_len':True}, default=False)) | |||||
pipe.add_processor(ModelProcessor(model, 'word_seq_origin_len')) | |||||
pipe.add_processor(SliceProcessor(1, None, None, 'head_pred', 'heads')) | |||||
pipe.add_processor(SliceProcessor(1, None, None, 'label_pred', 'label_pred')) | |||||
pipe.add_processor(Index2WordProcessor(tag_vocab, 'label_pred', 'labels')) | |||||
if not os.path.exists(save_path): | |||||
os.makedirs(save_path) | |||||
with open(save_path+'/pipeline.pkl', 'wb') as f: | |||||
torch.save({'pipeline': pipe}, f) | |||||
print('save pipeline in {}'.format(save_path)) | |||||
import argparse | |||||
parser = argparse.ArgumentParser(description='build pipeline for parser.') | |||||
parser.add_argument('--src', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/save') | |||||
parser.add_argument('--dst', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/pipe') | |||||
args = parser.parse_args() | |||||
build(args.src, args.dst) |
@@ -0,0 +1,114 @@ | |||||
import sys | |||||
sys.path.extend(['/home/yfshao/workdir/dev_fastnlp']) | |||||
import torch | |||||
import argparse | |||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument('--pipe', type=str, default='') | |||||
parser.add_argument('--gold_data', type=str, default='') | |||||
parser.add_argument('--new_data', type=str) | |||||
args = parser.parse_args() | |||||
pipe = torch.load(args.pipe)['pipeline'] | |||||
for p in pipe: | |||||
if p.field_name == 'word_list': | |||||
print(p.field_name) | |||||
p.field_name = 'gold_words' | |||||
elif p.field_name == 'pos_list': | |||||
print(p.field_name) | |||||
p.field_name = 'gold_pos' | |||||
data = ConllxDataLoader().load(args.gold_data) | |||||
ds = DataSet() | |||||
for ins1, ins2 in zip(add_seg_tag(data), data): | |||||
ds.append(Instance(words=ins1[0], tag=ins1[1], | |||||
gold_words=ins2[0], gold_pos=ins2[1], | |||||
gold_heads=ins2[2], gold_head_tags=ins2[3])) | |||||
ds = pipe(ds) | |||||
seg_threshold = 0. | |||||
pos_threshold = 0. | |||||
parse_threshold = 0.74 | |||||
def get_heads(ins, head_f, word_f): | |||||
head_pred = [] | |||||
for i, idx in enumerate(ins[head_f]): | |||||
j = idx - 1 if idx != 0 else i | |||||
head_pred.append(ins[word_f][j]) | |||||
return head_pred | |||||
def evaluate(ins): | |||||
seg_count = sum([1 for i, j in zip(ins['word_list'], ins['gold_words']) if i == j]) | |||||
pos_count = sum([1 for i, j in zip(ins['pos_list'], ins['gold_pos']) if i == j]) | |||||
head_count = sum([1 for i, j in zip(ins['heads'], ins['gold_heads']) if i == j]) | |||||
total = len(ins['gold_words']) | |||||
return seg_count / total, pos_count / total, head_count / total | |||||
def is_ok(x): | |||||
seg, pos, head = x[1] | |||||
return seg > seg_threshold and pos > pos_threshold and head > parse_threshold | |||||
res_list = [] | |||||
for i, ins in enumerate(ds): | |||||
res_list.append((i, evaluate(ins))) | |||||
res_list = list(filter(is_ok, res_list)) | |||||
print('{} {}'.format(len(ds), len(res_list))) | |||||
seg_cor, pos_cor, head_cor, label_cor, total = 0,0,0,0,0 | |||||
for i, _ in res_list: | |||||
ins = ds[i] | |||||
# print(i) | |||||
# print('gold_words:\t', ins['gold_words']) | |||||
# print('predict_words:\t', ins['word_list']) | |||||
# print('gold_tag:\t', ins['gold_pos']) | |||||
# print('predict_tag:\t', ins['pos_list']) | |||||
# print('gold_heads:\t', ins['gold_heads']) | |||||
# print('predict_heads:\t', ins['heads'].tolist()) | |||||
# print('gold_head_tags:\t', ins['gold_head_tags']) | |||||
# print('predict_labels:\t', ins['labels']) | |||||
# print() | |||||
head_pred = ins['heads'] | |||||
head_gold = ins['gold_heads'] | |||||
label_pred = ins['labels'] | |||||
label_gold = ins['gold_head_tags'] | |||||
total += len(head_gold) | |||||
seg_cor += sum([1 for i, j in zip(ins['word_list'], ins['gold_words']) if i == j]) | |||||
pos_cor += sum([1 for i, j in zip(ins['pos_list'], ins['gold_pos']) if i == j]) | |||||
length = len(head_gold) | |||||
for i in range(length): | |||||
head_cor += 1 if head_pred[i] == head_gold[i] else 0 | |||||
label_cor += 1 if head_pred[i] == head_gold[i] and label_gold[i] == label_pred[i] else 0 | |||||
print('SEG: {}, POS: {}, UAS: {}, LAS: {}'.format(seg_cor/total, pos_cor/total, head_cor/total, label_cor/total)) | |||||
colln_path = args.gold_data | |||||
new_colln_path = args.new_data | |||||
index_list = [x[0] for x in res_list] | |||||
with open(colln_path, 'r', encoding='utf-8') as f1, \ | |||||
open(new_colln_path, 'w', encoding='utf-8') as f2: | |||||
for idx, ins in enumerate(ds): | |||||
if idx in index_list: | |||||
length = len(ins['gold_words']) | |||||
pad = ['_' for _ in range(length)] | |||||
for x in zip( | |||||
map(str, range(1, length+1)), ins['gold_words'], ins['gold_words'], ins['gold_pos'], | |||||
pad, pad, map(str, ins['gold_heads']), ins['gold_head_tags']): | |||||
new_lines = '\t'.join(x) | |||||
f2.write(new_lines) | |||||
f2.write('\n') | |||||
f2.write('\n') |
@@ -3,34 +3,34 @@ import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | ||||
from collections import defaultdict | |||||
import math | |||||
import torch | import torch | ||||
import re | |||||
from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
from fastNLP.core.metrics import Evaluator | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.field import TextField, SeqLabelField | from fastNLP.core.field import TextField, SeqLabelField | ||||
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | |||||
from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.loader.model_loader import ModelLoader | |||||
from fastNLP.loader.embed_loader import EmbedLoader | |||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.io.model_loader import ModelLoader | |||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
from fastNLP.models.biaffine_parser import BiaffineParser | from fastNLP.models.biaffine_parser import BiaffineParser | ||||
from fastNLP.saver.model_saver import ModelSaver | |||||
from fastNLP.io.model_saver import ModelSaver | |||||
BOS = '<BOS>' | |||||
EOS = '<EOS>' | |||||
UNK = '<OOV>' | |||||
NUM = '<NUM>' | |||||
ENG = '<ENG>' | |||||
# not in the file's dir | # not in the file's dir | ||||
if len(os.path.dirname(__file__)) != 0: | if len(os.path.dirname(__file__)) != 0: | ||||
os.chdir(os.path.dirname(__file__)) | os.chdir(os.path.dirname(__file__)) | ||||
class MyDataLoader(object): | |||||
def __init__(self, pickle_path): | |||||
self.pickle_path = pickle_path | |||||
def load(self, path, word_v=None, pos_v=None, headtag_v=None): | |||||
class ConlluDataLoader(object): | |||||
def load(self, path): | |||||
datalist = [] | datalist = [] | ||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
sample = [] | sample = [] | ||||
@@ -49,23 +49,18 @@ class MyDataLoader(object): | |||||
for sample in datalist: | for sample in datalist: | ||||
# print(sample) | # print(sample) | ||||
res = self.get_one(sample) | res = self.get_one(sample) | ||||
if word_v is not None: | |||||
word_v.update(res[0]) | |||||
pos_v.update(res[1]) | |||||
headtag_v.update(res[3]) | |||||
ds.append(Instance(word_seq=TextField(res[0], is_target=False), | ds.append(Instance(word_seq=TextField(res[0], is_target=False), | ||||
pos_seq=TextField(res[1], is_target=False), | pos_seq=TextField(res[1], is_target=False), | ||||
head_indices=SeqLabelField(res[2], is_target=True), | head_indices=SeqLabelField(res[2], is_target=True), | ||||
head_labels=TextField(res[3], is_target=True), | |||||
seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False))) | |||||
head_labels=TextField(res[3], is_target=True))) | |||||
return ds | return ds | ||||
def get_one(self, sample): | def get_one(self, sample): | ||||
text = ['<root>'] | |||||
pos_tags = ['<root>'] | |||||
heads = [0] | |||||
head_tags = ['root'] | |||||
text = [] | |||||
pos_tags = [] | |||||
heads = [] | |||||
head_tags = [] | |||||
for w in sample: | for w in sample: | ||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | ||||
if t3 == '_': | if t3 == '_': | ||||
@@ -76,17 +71,60 @@ class MyDataLoader(object): | |||||
head_tags.append(t4) | head_tags.append(t4) | ||||
return (text, pos_tags, heads, head_tags) | return (text, pos_tags, heads, head_tags) | ||||
def index_data(self, dataset, word_v, pos_v, tag_v): | |||||
dataset.index_field('word_seq', word_v) | |||||
dataset.index_field('pos_seq', pos_v) | |||||
dataset.index_field('head_labels', tag_v) | |||||
class CTBDataLoader(object): | |||||
def load(self, data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f: | |||||
lines = f.readlines() | |||||
data = self.parse(lines) | |||||
return self.convert(data) | |||||
def parse(self, lines): | |||||
""" | |||||
[ | |||||
[word], [pos], [head_index], [head_tag] | |||||
] | |||||
""" | |||||
sample = [] | |||||
data = [] | |||||
for i, line in enumerate(lines): | |||||
line = line.strip() | |||||
if len(line) == 0 or i+1 == len(lines): | |||||
data.append(list(map(list, zip(*sample)))) | |||||
sample = [] | |||||
else: | |||||
sample.append(line.split()) | |||||
return data | |||||
def convert(self, data): | |||||
dataset = DataSet() | |||||
for sample in data: | |||||
word_seq = [BOS] + sample[0] + [EOS] | |||||
pos_seq = [BOS] + sample[1] + [EOS] | |||||
heads = [0] + list(map(int, sample[2])) + [0] | |||||
head_tags = [BOS] + sample[3] + [EOS] | |||||
dataset.append(Instance(word_seq=TextField(word_seq, is_target=False), | |||||
pos_seq=TextField(pos_seq, is_target=False), | |||||
gold_heads=SeqLabelField(heads, is_target=False), | |||||
head_indices=SeqLabelField(heads, is_target=True), | |||||
head_labels=TextField(head_tags, is_target=True))) | |||||
return dataset | |||||
# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | # datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | ||||
datadir = "/home/yfshao/UD_English-EWT" | |||||
# datadir = "/home/yfshao/UD_English-EWT" | |||||
# train_data_name = "en_ewt-ud-train.conllu" | |||||
# dev_data_name = "en_ewt-ud-dev.conllu" | |||||
# emb_file_name = '/home/yfshao/glove.6B.100d.txt' | |||||
# loader = ConlluDataLoader() | |||||
datadir = '/home/yfshao/workdir/parser-data/' | |||||
train_data_name = "train_ctb5.txt" | |||||
dev_data_name = "dev_ctb5.txt" | |||||
test_data_name = "test_ctb5.txt" | |||||
emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt" | |||||
# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec" | |||||
loader = CTBDataLoader() | |||||
cfgfile = './cfg.cfg' | cfgfile = './cfg.cfg' | ||||
train_data_name = "en_ewt-ud-train.conllu" | |||||
dev_data_name = "en_ewt-ud-dev.conllu" | |||||
emb_file_name = '/home/yfshao/glove.6B.100d.txt' | |||||
processed_datadir = './save' | processed_datadir = './save' | ||||
# Config Loader | # Config Loader | ||||
@@ -95,8 +133,12 @@ test_args = ConfigSection() | |||||
model_args = ConfigSection() | model_args = ConfigSection() | ||||
optim_args = ConfigSection() | optim_args = ConfigSection() | ||||
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args}) | ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args}) | ||||
print('trainre Args:', train_args.data) | |||||
print('test Args:', test_args.data) | |||||
print('optim Args:', optim_args.data) | |||||
# Data Loader | |||||
# Pickle Loader | |||||
def save_data(dirpath, **kwargs): | def save_data(dirpath, **kwargs): | ||||
import _pickle | import _pickle | ||||
if not os.path.exists(dirpath): | if not os.path.exists(dirpath): | ||||
@@ -117,38 +159,57 @@ def load_data(dirpath): | |||||
datas[name] = _pickle.load(f) | datas[name] = _pickle.load(f) | ||||
return datas | return datas | ||||
class MyTester(object): | |||||
def __init__(self, batch_size, use_cuda=False, **kwagrs): | |||||
self.batch_size = batch_size | |||||
self.use_cuda = use_cuda | |||||
def test(self, model, dataset): | |||||
self.model = model.cuda() if self.use_cuda else model | |||||
self.model.eval() | |||||
batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda) | |||||
eval_res = defaultdict(list) | |||||
i = 0 | |||||
for batch_x, batch_y in batchiter: | |||||
with torch.no_grad(): | |||||
pred_y = self.model(**batch_x) | |||||
eval_one = self.model.evaluate(**pred_y, **batch_y) | |||||
i += self.batch_size | |||||
for eval_name, tensor in eval_one.items(): | |||||
eval_res[eval_name].append(tensor) | |||||
tmp = {} | |||||
for eval_name, tensorlist in eval_res.items(): | |||||
tmp[eval_name] = torch.cat(tensorlist, dim=0) | |||||
self.res = self.model.metrics(**tmp) | |||||
def show_metrics(self): | |||||
s = "" | |||||
for name, val in self.res.items(): | |||||
s += '{}: {:.2f}\t'.format(name, val) | |||||
return s | |||||
loader = MyDataLoader('') | |||||
def P2(data, field, length): | |||||
ds = [ins for ins in data if ins[field].get_length() >= length] | |||||
data.clear() | |||||
data.extend(ds) | |||||
return ds | |||||
def P1(data, field): | |||||
def reeng(w): | |||||
return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG | |||||
def renum(w): | |||||
return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM | |||||
for ins in data: | |||||
ori = ins[field].contents() | |||||
s = list(map(renum, map(reeng, ori))) | |||||
if s != ori: | |||||
# print(ori) | |||||
# print(s) | |||||
# print() | |||||
ins[field] = ins[field].new(s) | |||||
return data | |||||
class ParserEvaluator(Evaluator): | |||||
def __init__(self, ignore_label): | |||||
super(ParserEvaluator, self).__init__() | |||||
self.ignore = ignore_label | |||||
def __call__(self, predict_list, truth_list): | |||||
head_all, label_all, total_all = 0, 0, 0 | |||||
for pred, truth in zip(predict_list, truth_list): | |||||
head, label, total = self.evaluate(**pred, **truth) | |||||
head_all += head | |||||
label_all += label | |||||
total_all += total | |||||
return {'UAS': head_all*1.0 / total_all, 'LAS': label_all*1.0 / total_all} | |||||
def evaluate(self, head_pred, label_pred, head_indices, head_labels, seq_mask, **_): | |||||
""" | |||||
Evaluate the performance of prediction. | |||||
:return : performance results. | |||||
head_pred_corrct: number of correct predicted heads. | |||||
label_pred_correct: number of correct predicted labels. | |||||
total_tokens: number of predicted tokens | |||||
""" | |||||
seq_mask *= (head_labels != self.ignore).long() | |||||
head_pred_correct = (head_pred == head_indices).long() * seq_mask | |||||
_, label_preds = torch.max(label_pred, dim=2) | |||||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct | |||||
return head_pred_correct.sum().item(), label_pred_correct.sum().item(), seq_mask.sum().item() | |||||
try: | try: | ||||
data_dict = load_data(processed_datadir) | data_dict = load_data(processed_datadir) | ||||
word_v = data_dict['word_v'] | word_v = data_dict['word_v'] | ||||
@@ -156,62 +217,90 @@ try: | |||||
tag_v = data_dict['tag_v'] | tag_v = data_dict['tag_v'] | ||||
train_data = data_dict['train_data'] | train_data = data_dict['train_data'] | ||||
dev_data = data_dict['dev_data'] | dev_data = data_dict['dev_data'] | ||||
test_data = data_dict['test_data'] | |||||
print('use saved pickles') | print('use saved pickles') | ||||
except Exception as _: | except Exception as _: | ||||
print('load raw data and preprocess') | print('load raw data and preprocess') | ||||
# use pretrain embedding | |||||
word_v = Vocabulary(need_default=True, min_freq=2) | word_v = Vocabulary(need_default=True, min_freq=2) | ||||
word_v.unknown_label = UNK | |||||
pos_v = Vocabulary(need_default=True) | pos_v = Vocabulary(need_default=True) | ||||
tag_v = Vocabulary(need_default=False) | tag_v = Vocabulary(need_default=False) | ||||
train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v) | |||||
train_data = loader.load(os.path.join(datadir, train_data_name)) | |||||
dev_data = loader.load(os.path.join(datadir, dev_data_name)) | dev_data = loader.load(os.path.join(datadir, dev_data_name)) | ||||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data) | |||||
test_data = loader.load(os.path.join(datadir, test_data_name)) | |||||
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v) | |||||
datasets = (train_data, dev_data, test_data) | |||||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data) | |||||
loader.index_data(train_data, word_v, pos_v, tag_v) | |||||
loader.index_data(dev_data, word_v, pos_v, tag_v) | |||||
print(len(train_data)) | |||||
print(len(dev_data)) | |||||
ep = train_args['epochs'] | |||||
train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep | |||||
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | |||||
print(len(word_v)) | |||||
print(embed.size()) | |||||
# Model | |||||
model_args['word_vocab_size'] = len(word_v) | model_args['word_vocab_size'] = len(word_v) | ||||
model_args['pos_vocab_size'] = len(pos_v) | model_args['pos_vocab_size'] = len(pos_v) | ||||
model_args['num_label'] = len(tag_v) | model_args['num_label'] = len(tag_v) | ||||
model = BiaffineParser(**model_args.data) | |||||
model.reset_parameters() | |||||
datasets = (train_data, dev_data, test_data) | |||||
for ds in datasets: | |||||
ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v) | |||||
ds.set_origin_len('word_seq') | |||||
if train_args['use_golden_train']: | |||||
train_data.set_target(gold_heads=False) | |||||
else: | |||||
train_data.set_target(gold_heads=None) | |||||
train_args.data.pop('use_golden_train') | |||||
ignore_label = pos_v['P'] | |||||
print(test_data[0]) | |||||
print(len(train_data)) | |||||
print(len(dev_data)) | |||||
print(len(test_data)) | |||||
def train(): | |||||
def train(path): | |||||
# Trainer | # Trainer | ||||
trainer = Trainer(**train_args.data) | trainer = Trainer(**train_args.data) | ||||
def _define_optim(obj): | def _define_optim(obj): | ||||
obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data) | |||||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4)) | |||||
lr = optim_args.data['lr'] | |||||
embed_params = set(obj._model.word_embedding.parameters()) | |||||
decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters()) | |||||
params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params] | |||||
obj._optimizer = torch.optim.Adam([ | |||||
{'params': list(embed_params), 'lr':lr*0.1}, | |||||
{'params': list(decay_params), **optim_args.data}, | |||||
{'params': params} | |||||
], lr=lr, betas=(0.9, 0.9)) | |||||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05)) | |||||
def _update(obj): | def _update(obj): | ||||
# torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0) | |||||
obj._scheduler.step() | obj._scheduler.step() | ||||
obj._optimizer.step() | obj._optimizer.step() | ||||
trainer.define_optimizer = lambda: _define_optim(trainer) | trainer.define_optimizer = lambda: _define_optim(trainer) | ||||
trainer.update = lambda: _update(trainer) | trainer.update = lambda: _update(trainer) | ||||
trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth) | |||||
trainer._create_validator = lambda x: MyTester(**test_args.data) | |||||
# Model | |||||
model = BiaffineParser(**model_args.data) | |||||
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))) | |||||
# use pretrain embedding | |||||
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | |||||
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | ||||
model.word_embedding.padding_idx = word_v.padding_idx | model.word_embedding.padding_idx = word_v.padding_idx | ||||
model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | ||||
model.pos_embedding.padding_idx = pos_v.padding_idx | model.pos_embedding.padding_idx = pos_v.padding_idx | ||||
model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0) | model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0) | ||||
try: | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print('model parameter loaded!') | |||||
except Exception as _: | |||||
print("No saved model. Continue.") | |||||
pass | |||||
# try: | |||||
# ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
# print('model parameter loaded!') | |||||
# except Exception as _: | |||||
# print("No saved model. Continue.") | |||||
# pass | |||||
# Start training | # Start training | ||||
trainer.train(model, train_data, dev_data) | trainer.train(model, train_data, dev_data) | ||||
@@ -223,24 +312,27 @@ def train(): | |||||
print("Model saved!") | print("Model saved!") | ||||
def test(): | |||||
def test(path): | |||||
# Tester | # Tester | ||||
tester = MyTester(**test_args.data) | |||||
tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)) | |||||
# Model | # Model | ||||
model = BiaffineParser(**model_args.data) | model = BiaffineParser(**model_args.data) | ||||
model.eval() | |||||
try: | try: | ||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
ModelLoader.load_pytorch(model, path) | |||||
print('model parameter loaded!') | print('model parameter loaded!') | ||||
except Exception as _: | except Exception as _: | ||||
print("No saved model. Abort test.") | print("No saved model. Abort test.") | ||||
raise | raise | ||||
# Start training | # Start training | ||||
print("Testing Train data") | |||||
tester.test(model, train_data) | |||||
print("Testing Dev data") | |||||
tester.test(model, dev_data) | tester.test(model, dev_data) | ||||
print(tester.show_metrics()) | |||||
print("Testing finished!") | |||||
print("Testing Test data") | |||||
tester.test(model, test_data) | |||||
@@ -248,13 +340,14 @@ if __name__ == "__main__": | |||||
import argparse | import argparse | ||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | ||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | ||||
parser.add_argument('--path', type=str, default='') | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if args.mode == 'train': | if args.mode == 'train': | ||||
train() | |||||
train(args.path) | |||||
elif args.mode == 'test': | elif args.mode == 'test': | ||||
test() | |||||
test(args.path) | |||||
elif args.mode == 'infer': | elif args.mode == 'infer': | ||||
infer() | |||||
pass | |||||
else: | else: | ||||
print('no mode specified for model!') | print('no mode specified for model!') | ||||
parser.print_help() | parser.print_help() |
@@ -0,0 +1,78 @@ | |||||
class ConllxDataLoader(object): | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
data = [self.get_one(sample) for sample in datalist] | |||||
return list(filter(lambda x: x is not None, data)) | |||||
def get_one(self, sample): | |||||
sample = list(map(list, zip(*sample))) | |||||
if len(sample) == 0: | |||||
return None | |||||
for w in sample[7]: | |||||
if w == '_': | |||||
print('Error Sample {}'.format(sample)) | |||||
return None | |||||
# return word_seq, pos_seq, head_seq, head_tag_seq | |||||
return sample[1], sample[3], list(map(int, sample[6])), sample[7] | |||||
class MyDataloader: | |||||
def load(self, data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f: | |||||
lines = f.readlines() | |||||
data = self.parse(lines) | |||||
return data | |||||
def parse(self, lines): | |||||
""" | |||||
[ | |||||
[word], [pos], [head_index], [head_tag] | |||||
] | |||||
""" | |||||
sample = [] | |||||
data = [] | |||||
for i, line in enumerate(lines): | |||||
line = line.strip() | |||||
if len(line) == 0 or i + 1 == len(lines): | |||||
data.append(list(map(list, zip(*sample)))) | |||||
sample = [] | |||||
else: | |||||
sample.append(line.split()) | |||||
if len(sample) > 0: | |||||
data.append(list(map(list, zip(*sample)))) | |||||
return data | |||||
def add_seg_tag(data): | |||||
""" | |||||
:param data: list of ([word], [pos], [heads], [head_tags]) | |||||
:return: list of ([word], [pos]) | |||||
""" | |||||
_processed = [] | |||||
for word_list, pos_list, _, _ in data: | |||||
new_sample = [] | |||||
for word, pos in zip(word_list, pos_list): | |||||
if len(word) == 1: | |||||
new_sample.append((word, 'S-' + pos)) | |||||
else: | |||||
new_sample.append((word[0], 'B-' + pos)) | |||||
for c in word[1:-1]: | |||||
new_sample.append((c, 'M-' + pos)) | |||||
new_sample.append((word[-1], 'E-' + pos)) | |||||
_processed.append(list(map(list, zip(*new_sample)))) | |||||
return _processed |
@@ -4,8 +4,9 @@ import torch.nn.functional as F | |||||
class CNN_text(nn.Module): | class CNN_text(nn.Module): | ||||
def __init__(self, kernel_h=[3, 4, 5], kernel_num=100, embed_num=1000, embed_dim=300, dropout=0.5, L2_constrain=3, | |||||
batchsize=50, pretrained_embeddings=None): | |||||
def __init__(self, kernel_h=[3, 4, 5], kernel_num=100, embed_num=1000, embed_dim=300, num_classes=2, dropout=0.5, | |||||
L2_constrain=3, | |||||
pretrained_embeddings=None): | |||||
super(CNN_text, self).__init__() | super(CNN_text, self).__init__() | ||||
self.embedding = nn.Embedding(embed_num, embed_dim) | self.embedding = nn.Embedding(embed_num, embed_dim) | ||||
@@ -15,11 +16,11 @@ class CNN_text(nn.Module): | |||||
# the network structure | # the network structure | ||||
# Conv2d: input- N,C,H,W output- (50,100,62,1) | # Conv2d: input- N,C,H,W output- (50,100,62,1) | ||||
self.conv1 = nn.ModuleList([nn.Conv2d(1, 100, (K, 300)) for K in kernel_h]) | |||||
self.fc1 = nn.Linear(300, 2) | |||||
self.conv1 = nn.ModuleList([nn.Conv2d(1, kernel_num, (K, embed_dim)) for K in kernel_h]) | |||||
self.fc1 = nn.Linear(len(kernel_h) * kernel_num, num_classes) | |||||
def max_pooling(self, x): | def max_pooling(self, x): | ||||
x = F.relu(conv(x)).squeeze(3) # N,C,L - (50,100,62) | |||||
x = F.relu(self.conv1(x)).squeeze(3) # N,C,L - (50,100,62) | |||||
x = F.max_pool1d(x, x.size(2)).squeeze(2) | x = F.max_pool1d(x, x.size(2)).squeeze(2) | ||||
# x.size(2)=62 squeeze: (50,100,1) -> (50,100) | # x.size(2)=62 squeeze: (50,100,1) -> (50,100) | ||||
return x | return x | ||||
@@ -33,3 +34,9 @@ class CNN_text(nn.Module): | |||||
x = self.dropout(x) | x = self.dropout(x) | ||||
x = self.fc1(x) | x = self.fc1(x) | ||||
return x | return x | ||||
if __name__ == '__main__': | |||||
model = CNN_text(kernel_h=[1, 2, 3, 4], embed_num=3, embed_dim=2) | |||||
x = torch.LongTensor([[1, 2, 1, 2, 0]]) | |||||
print(model(x)) |
@@ -1,10 +1,10 @@ | |||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from fastNLP.core.preprocess import ClassPreprocess as Preprocess | |||||
from fastNLP.core.trainer import ClassificationTrainer | from fastNLP.core.trainer import ClassificationTrainer | ||||
from fastNLP.loader.config_loader import ConfigLoader | |||||
from fastNLP.loader.config_loader import ConfigSection | |||||
from fastNLP.loader.dataset_loader import ClassDataSetLoader as Dataset_loader | |||||
from fastNLP.core.utils import ClassPreprocess as Preprocess | |||||
from fastNLP.io.config_loader import ConfigLoader | |||||
from fastNLP.io.config_loader import ConfigSection | |||||
from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader | |||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules.aggregator.self_attention import SelfAttention | from fastNLP.modules.aggregator.self_attention import SelfAttention | ||||
from fastNLP.modules.decoder.MLP import MLP | from fastNLP.modules.decoder.MLP import MLP | ||||
@@ -1,6 +1,6 @@ | |||||
[train] | [train] | ||||
epochs = 30 | |||||
batch_size = 64 | |||||
epochs = 40 | |||||
batch_size = 8 | |||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
validate = true | validate = true | ||||
save_best_dev = true | save_best_dev = true | ||||
@@ -0,0 +1,176 @@ | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.io.dataset_loader import DataSetLoader | |||||
def cut_long_sentence(sent, max_sample_length=200): | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
class NaiveCWSReader(DataSetLoader): | |||||
""" | |||||
这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了 | |||||
这是 fastNLP , 一个 非常 good 的 包 . | |||||
或者,即每个part后面还有一个pos tag | |||||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | |||||
""" | |||||
def __init__(self, in_word_splitter=None): | |||||
super().__init__() | |||||
self.in_word_splitter = in_word_splitter | |||||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||||
""" | |||||
允许使用的情况有(默认以\t或空格作为seg) | |||||
这是 fastNLP , 一个 非常 good 的 包 . | |||||
和 | |||||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | |||||
如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0] | |||||
:param filepath: | |||||
:param in_word_splitter: | |||||
:return: | |||||
""" | |||||
if in_word_splitter == None: | |||||
in_word_splitter = self.in_word_splitter | |||||
dataset = DataSet() | |||||
with open(filepath, 'r') as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if len(line.replace(' ', ''))==0: # 不能接受空行 | |||||
continue | |||||
if not in_word_splitter is None: | |||||
words = [] | |||||
for part in line.split(): | |||||
word = part.split(in_word_splitter)[0] | |||||
words.append(word) | |||||
line = ' '.join(words) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for sent in sents: | |||||
instance = Instance(raw_sentence=sent) | |||||
dataset.append(instance) | |||||
return dataset | |||||
class POSCWSReader(DataSetLoader): | |||||
""" | |||||
支持读取以下的情况, 即每一行是一个词, 用空行作为两句话的界限. | |||||
迈 N | |||||
向 N | |||||
充 N | |||||
... | |||||
泽 I-PER | |||||
民 I-PER | |||||
( N | |||||
一 N | |||||
九 N | |||||
... | |||||
:param filepath: | |||||
:return: | |||||
""" | |||||
def __init__(self, in_word_splitter=None): | |||||
super().__init__() | |||||
self.in_word_splitter = in_word_splitter | |||||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||||
if in_word_splitter is None: | |||||
in_word_splitter = self.in_word_splitter | |||||
dataset = DataSet() | |||||
with open(filepath, 'r') as f: | |||||
words = [] | |||||
for line in f: | |||||
line = line.strip() | |||||
if len(line) == 0: # new line | |||||
if len(words)==0: # 不能接受空行 | |||||
continue | |||||
line = ' '.join(words) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for sent in sents: | |||||
instance = Instance(raw_sentence=sent) | |||||
dataset.append(instance) | |||||
words = [] | |||||
else: | |||||
line = line.split()[0] | |||||
if in_word_splitter is None: | |||||
words.append(line) | |||||
else: | |||||
words.append(line.split(in_word_splitter)[0]) | |||||
return dataset | |||||
class ConlluCWSReader(object): | |||||
# 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BMES的tag)。 | |||||
def __init__(self): | |||||
pass | |||||
def load(self, path, cut_long_sent=False): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
line = ' '.join(res) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for raw_sentence in sents: | |||||
ds.append(Instance(raw_sentence=raw_sentence)) | |||||
return ds | |||||
def get_one(self, sample): | |||||
if len(sample)==0: | |||||
return None | |||||
text = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
return text | |||||
@@ -0,0 +1,172 @@ | |||||
from torch import nn | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from fastNLP.modules.decoder.MLP import MLP | |||||
from fastNLP.models.base_model import BaseModel | |||||
from reproduction.chinese_word_segment.utils import seq_lens_to_mask | |||||
class CWSBiLSTMEncoder(BaseModel): | |||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | |||||
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1): | |||||
super().__init__() | |||||
self.input_size = 0 | |||||
self.num_bigram_per_char = num_bigram_per_char | |||||
self.bidirectional = bidirectional | |||||
self.num_layers = num_layers | |||||
self.embed_drop_p = embed_drop_p | |||||
if self.bidirectional: | |||||
self.hidden_size = hidden_size//2 | |||||
self.num_directions = 2 | |||||
else: | |||||
self.hidden_size = hidden_size | |||||
self.num_directions = 1 | |||||
if not bigram_vocab_num is None: | |||||
assert not bigram_vocab_num is None, "Specify num_bigram_per_char." | |||||
if vocab_num is not None: | |||||
self.char_embedding = nn.Embedding(num_embeddings=vocab_num, embedding_dim=embed_dim) | |||||
self.input_size += embed_dim | |||||
if bigram_vocab_num is not None: | |||||
self.bigram_embedding = nn.Embedding(num_embeddings=bigram_vocab_num, embedding_dim=bigram_embed_dim) | |||||
self.input_size += self.num_bigram_per_char*bigram_embed_dim | |||||
if not self.embed_drop_p is None: | |||||
self.embedding_drop = nn.Dropout(p=self.embed_drop_p) | |||||
self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, bidirectional=self.bidirectional, | |||||
batch_first=True, num_layers=self.num_layers) | |||||
self.reset_parameters() | |||||
def reset_parameters(self): | |||||
for name, param in self.named_parameters(): | |||||
if 'bias_hh' in name: | |||||
nn.init.constant_(param, 0) | |||||
elif 'bias_ih' in name: | |||||
nn.init.constant_(param, 1) | |||||
else: | |||||
nn.init.xavier_uniform_(param) | |||||
def init_embedding(self, embedding, embed_name): | |||||
if embed_name == 'bigram': | |||||
self.bigram_embedding.weight.data = torch.from_numpy(embedding) | |||||
elif embed_name == 'char': | |||||
self.char_embedding.weight.data = torch.from_numpy(embedding) | |||||
def forward(self, chars, bigrams=None, seq_lens=None): | |||||
batch_size, max_len = chars.size() | |||||
x_tensor = self.char_embedding(chars) | |||||
if not bigrams is None: | |||||
bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1) | |||||
x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2) | |||||
sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True) | |||||
packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True) | |||||
outputs, _ = self.lstm(packed_x) | |||||
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) | |||||
_, desorted_indices = torch.sort(sorted_indices, descending=False) | |||||
outputs = outputs[desorted_indices] | |||||
return outputs | |||||
class CWSBiLSTMSegApp(BaseModel): | |||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | |||||
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=2): | |||||
super(CWSBiLSTMSegApp, self).__init__() | |||||
self.tag_size = tag_size | |||||
self.encoder_model = CWSBiLSTMEncoder(vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char, | |||||
hidden_size, bidirectional, embed_drop_p, num_layers) | |||||
size_layer = [hidden_size, 200, tag_size] | |||||
self.decoder_model = MLP(size_layer) | |||||
def forward(self, chars, seq_lens, bigrams=None): | |||||
device = self.parameters().__next__().device | |||||
chars = chars.to(device).long() | |||||
if not bigrams is None: | |||||
bigrams = bigrams.to(device).long() | |||||
else: | |||||
bigrams = None | |||||
seq_lens = seq_lens.to(device).long() | |||||
feats = self.encoder_model(chars, bigrams, seq_lens) | |||||
probs = self.decoder_model(feats) | |||||
pred_dict = {} | |||||
pred_dict['seq_lens'] = seq_lens | |||||
pred_dict['pred_probs'] = probs | |||||
return pred_dict | |||||
def predict(self, chars, seq_lens, bigrams=None): | |||||
pred_dict = self.forward(chars, seq_lens, bigrams) | |||||
pred_probs = pred_dict['pred_probs'] | |||||
_, pred_tags = pred_probs.max(dim=-1) | |||||
return {'pred_tags': pred_tags} | |||||
from fastNLP.modules.decoder.CRF import ConditionalRandomField | |||||
class CWSBiLSTMCRF(BaseModel): | |||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | |||||
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=4): | |||||
super(CWSBiLSTMCRF, self).__init__() | |||||
self.tag_size = tag_size | |||||
self.encoder_model = CWSBiLSTMEncoder(vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char, | |||||
hidden_size, bidirectional, embed_drop_p, num_layers) | |||||
size_layer = [hidden_size, 200, tag_size] | |||||
self.decoder_model = MLP(size_layer) | |||||
self.crf = ConditionalRandomField(tag_size=tag_size, include_start_end_trans=False) | |||||
def forward(self, chars, tags, seq_lens, bigrams=None): | |||||
device = self.parameters().__next__().device | |||||
chars = chars.to(device).long() | |||||
if not bigrams is None: | |||||
bigrams = bigrams.to(device).long() | |||||
else: | |||||
bigrams = None | |||||
seq_lens = seq_lens.to(device).long() | |||||
masks = seq_lens_to_mask(seq_lens) | |||||
feats = self.encoder_model(chars, bigrams, seq_lens) | |||||
feats = self.decoder_model(feats) | |||||
losses = self.crf(feats, tags, masks) | |||||
pred_dict = {} | |||||
pred_dict['seq_lens'] = seq_lens | |||||
pred_dict['loss'] = torch.mean(losses) | |||||
return pred_dict | |||||
def predict(self, chars, seq_lens, bigrams=None): | |||||
device = self.parameters().__next__().device | |||||
chars = chars.to(device).long() | |||||
if not bigrams is None: | |||||
bigrams = bigrams.to(device).long() | |||||
else: | |||||
bigrams = None | |||||
seq_lens = seq_lens.to(device).long() | |||||
masks = seq_lens_to_mask(seq_lens) | |||||
feats = self.encoder_model(chars, bigrams, seq_lens) | |||||
feats = self.decoder_model(feats) | |||||
probs = self.crf.viterbi_decode(feats, masks, get_score=False) | |||||
return {'pred_tags': probs} | |||||
@@ -0,0 +1,284 @@ | |||||
import re | |||||
from fastNLP.core.field import SeqLabelField | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.api.processor import Processor | |||||
from reproduction.chinese_word_segment.process.span_converter import SpanConverter | |||||
_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' | |||||
class SpeicalSpanProcessor(Processor): | |||||
# 这个类会将句子中的special span转换为对应的内容。 | |||||
def __init__(self, field_name, new_added_field_name=None): | |||||
super(SpeicalSpanProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.span_converters = [] | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
sentence = ins[self.field_name] | |||||
for span_converter in self.span_converters: | |||||
sentence = span_converter.find_certain_span_and_replace(sentence) | |||||
ins[self.new_added_field_name] = sentence | |||||
return dataset | |||||
def add_span_converter(self, converter): | |||||
assert isinstance(converter, SpanConverter), "Only SpanConverterBase is allowed, not {}."\ | |||||
.format(type(converter)) | |||||
self.span_converters.append(converter) | |||||
class CWSCharSegProcessor(Processor): | |||||
def __init__(self, field_name, new_added_field_name): | |||||
super(CWSCharSegProcessor, self).__init__(field_name, new_added_field_name) | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
sentence = ins[self.field_name] | |||||
chars = self._split_sent_into_chars(sentence) | |||||
ins[self.new_added_field_name] = chars | |||||
return dataset | |||||
def _split_sent_into_chars(self, sentence): | |||||
sp_tag_match_iter = re.finditer(_SPECIAL_TAG_PATTERN, sentence) | |||||
sp_spans = [match_span.span() for match_span in sp_tag_match_iter] | |||||
sp_span_idx = 0 | |||||
in_span_flag = False | |||||
chars = [] | |||||
num_spans = len(sp_spans) | |||||
for idx, char in enumerate(sentence): | |||||
if sp_span_idx<num_spans and idx == sp_spans[sp_span_idx][0]: | |||||
in_span_flag = True | |||||
elif in_span_flag and sp_span_idx<num_spans and idx == sp_spans[sp_span_idx][1] - 1: | |||||
chars.append(sentence[sp_spans[sp_span_idx] | |||||
[0]:sp_spans[sp_span_idx][1]]) | |||||
in_span_flag = False | |||||
sp_span_idx += 1 | |||||
elif not in_span_flag: | |||||
# TODO 需要谨慎考虑如何处理空格的问题 | |||||
if char != ' ': | |||||
chars.append(char) | |||||
else: | |||||
pass | |||||
return chars | |||||
class CWSTagProcessor(Processor): | |||||
def __init__(self, field_name, new_added_field_name=None): | |||||
super(CWSTagProcessor, self).__init__(field_name, new_added_field_name) | |||||
def _generate_tag(self, sentence): | |||||
sp_tag_match_iter = re.finditer(_SPECIAL_TAG_PATTERN, sentence) | |||||
sp_spans = [match_span.span() for match_span in sp_tag_match_iter] | |||||
sp_span_idx = 0 | |||||
in_span_flag = False | |||||
tag_list = [] | |||||
word_len = 0 | |||||
num_spans = len(sp_spans) | |||||
for idx, char in enumerate(sentence): | |||||
if sp_span_idx<num_spans and idx == sp_spans[sp_span_idx][0]: | |||||
in_span_flag = True | |||||
elif in_span_flag and sp_span_idx<num_spans and idx == sp_spans[sp_span_idx][1] - 1: | |||||
word_len += 1 | |||||
in_span_flag = False | |||||
sp_span_idx += 1 | |||||
elif not in_span_flag: | |||||
if char == ' ': | |||||
if word_len!=0: | |||||
tag_list.extend(self._tags_from_word_len(word_len)) | |||||
word_len = 0 | |||||
else: | |||||
word_len += 1 | |||||
else: | |||||
pass | |||||
if word_len!=0: | |||||
tag_list.extend(self._tags_from_word_len(word_len)) | |||||
return tag_list | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
sentence = ins[self.field_name] | |||||
tag_list = self._generate_tag(sentence) | |||||
ins[self.new_added_field_name] = tag_list | |||||
dataset.set_target(**{self.new_added_field_name:True}) | |||||
dataset._set_need_tensor(**{self.new_added_field_name:True}) | |||||
return dataset | |||||
def _tags_from_word_len(self, word_len): | |||||
raise NotImplementedError | |||||
class CWSBMESTagProcessor(CWSTagProcessor): | |||||
def __init__(self, field_name, new_added_field_name=None): | |||||
super(CWSBMESTagProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.tag_size = 4 | |||||
def _tags_from_word_len(self, word_len): | |||||
tag_list = [] | |||||
if word_len == 1: | |||||
tag_list.append(3) | |||||
else: | |||||
tag_list.append(0) | |||||
for _ in range(word_len-2): | |||||
tag_list.append(1) | |||||
tag_list.append(2) | |||||
return tag_list | |||||
class CWSSegAppTagProcessor(CWSTagProcessor): | |||||
def __init__(self, field_name, new_added_field_name=None): | |||||
super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.tag_size = 2 | |||||
def _tags_from_word_len(self, word_len): | |||||
tag_list = [] | |||||
for _ in range(word_len-1): | |||||
tag_list.append(0) | |||||
tag_list.append(1) | |||||
return tag_list | |||||
class BigramProcessor(Processor): | |||||
def __init__(self, field_name, new_added_fielf_name=None): | |||||
super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
characters = ins[self.field_name] | |||||
bigrams = self._generate_bigram(characters) | |||||
ins[self.new_added_field_name] = bigrams | |||||
return dataset | |||||
def _generate_bigram(self, characters): | |||||
pass | |||||
class Pre2Post2BigramProcessor(BigramProcessor): | |||||
def __init__(self, field_name, new_added_fielf_name=None): | |||||
super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) | |||||
def _generate_bigram(self, characters): | |||||
bigrams = [] | |||||
characters = ['<SOS>', '<SOS>'] + characters + ['<EOS>', '<EOS>'] | |||||
for idx in range(2, len(characters)-2): | |||||
cur_char = characters[idx] | |||||
pre_pre_char = characters[idx-2] | |||||
pre_char = characters[idx-1] | |||||
post_char = characters[idx+1] | |||||
post_post_char = characters[idx+2] | |||||
pre_pre_cur_bigram = pre_pre_char + cur_char | |||||
pre_cur_bigram = pre_char + cur_char | |||||
cur_post_bigram = cur_char + post_char | |||||
cur_post_post_bigram = cur_char + post_post_char | |||||
bigrams.extend([pre_pre_char, pre_char, post_char, post_post_char, | |||||
pre_pre_cur_bigram, pre_cur_bigram, | |||||
cur_post_bigram, cur_post_post_bigram]) | |||||
return bigrams | |||||
# 这里需要建立vocabulary了,但是遇到了以下的问题 | |||||
# (1) 如果使用Processor的方式的话,但是在这种情况返回的不是dataset。所以建立vocabulary的工作用另外的方式实现,不借用 | |||||
# Processor了 | |||||
class VocabProcessor(Processor): | |||||
def __init__(self, field_name, min_count=1, max_vocab_size=None): | |||||
super(VocabProcessor, self).__init__(field_name, None) | |||||
self.vocab = Vocabulary(min_freq=min_count, max_size=max_vocab_size) | |||||
def process(self, *datasets): | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
tokens = ins[self.field_name] | |||||
self.vocab.update(tokens) | |||||
def get_vocab(self): | |||||
self.vocab.build_vocab() | |||||
return self.vocab | |||||
def get_vocab_size(self): | |||||
return len(self.vocab) | |||||
class SeqLenProcessor(Processor): | |||||
def __init__(self, field_name, new_added_field_name='seq_lens'): | |||||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
length = len(ins[self.field_name]) | |||||
ins[self.new_added_field_name] = length | |||||
dataset._set_need_tensor(**{self.new_added_field_name:True}) | |||||
return dataset | |||||
class SegApp2OutputProcessor(Processor): | |||||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | |||||
super(SegApp2OutputProcessor, self).__init__(None, None) | |||||
self.chars_field_name = chars_field_name | |||||
self.tag_field_name = tag_field_name | |||||
self.new_added_field_name = new_added_field_name | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
pred_tags = ins[self.tag_field_name] | |||||
chars = ins[self.chars_field_name] | |||||
words = [] | |||||
start_idx = 0 | |||||
for idx, tag in enumerate(pred_tags): | |||||
if tag==1: | |||||
# 当前没有考虑将原文替换回去 | |||||
words.append(''.join(chars[start_idx:idx+1])) | |||||
start_idx = idx + 1 | |||||
ins[self.new_added_field_name] = ' '.join(words) | |||||
class BMES2OutputProcessor(Processor): | |||||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | |||||
super(BMES2OutputProcessor, self).__init__(None, None) | |||||
self.chars_field_name = chars_field_name | |||||
self.tag_field_name = tag_field_name | |||||
self.new_added_field_name = new_added_field_name | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
pred_tags = ins[self.tag_field_name] | |||||
chars = ins[self.chars_field_name] | |||||
words = [] | |||||
start_idx = 0 | |||||
for idx, tag in enumerate(pred_tags): | |||||
if tag==3: | |||||
# 当前没有考虑将原文替换回去 | |||||
words.extend(chars[start_idx:idx+1]) | |||||
start_idx = idx + 1 | |||||
elif tag==2: | |||||
words.append(''.join(chars[start_idx:idx+1])) | |||||
start_idx = idx + 1 | |||||
ins[self.new_added_field_name] = ' '.join(words) |
@@ -0,0 +1,185 @@ | |||||
import re | |||||
class SpanConverter: | |||||
def __init__(self, replace_tag, pattern): | |||||
super(SpanConverter, self).__init__() | |||||
self.replace_tag = replace_tag | |||||
self.pattern = pattern | |||||
def find_certain_span_and_replace(self, sentence): | |||||
replaced_sentence = '' | |||||
prev_end = 0 | |||||
for match in re.finditer(self.pattern, sentence): | |||||
start, end = match.span() | |||||
span = sentence[start:end] | |||||
replaced_sentence += sentence[prev_end:start] + \ | |||||
self.span_to_special_tag(span) | |||||
prev_end = end | |||||
replaced_sentence += sentence[prev_end:] | |||||
return replaced_sentence | |||||
def span_to_special_tag(self, span): | |||||
return self.replace_tag | |||||
def find_certain_span(self, sentence): | |||||
spans = [] | |||||
for match in re.finditer(self.pattern, sentence): | |||||
spans.append(match.span()) | |||||
return spans | |||||
class AlphaSpanConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<ALPHA>' | |||||
# 理想状态下仅处理纯为字母的情况, 但不处理<[a-zA-Z]+>(因为这应该是特殊的tag). | |||||
pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%.!<\\-"])' | |||||
super(AlphaSpanConverter, self).__init__(replace_tag, pattern) | |||||
class DigitSpanConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<NUM>' | |||||
pattern = '\d[\d\\.]*(?=[\u4e00-\u9fff ,%.!<-])' | |||||
super(DigitSpanConverter, self).__init__(replace_tag, pattern) | |||||
def span_to_special_tag(self, span): | |||||
# return self.special_tag | |||||
if span[0] == '0' and len(span) > 2: | |||||
return '<NUM>' | |||||
decimal_point_count = 0 # one might have more than one decimal pointers | |||||
for idx, char in enumerate(span): | |||||
if char == '.' or char == '﹒' or char == '·': | |||||
decimal_point_count += 1 | |||||
if span[-1] == '.' or span[-1] == '﹒' or span[ | |||||
-1] == '·': # last digit being decimal point means this is not a number | |||||
if decimal_point_count == 1: | |||||
return span | |||||
else: | |||||
return '<UNKDGT>' | |||||
if decimal_point_count == 1: | |||||
return '<DEC>' | |||||
elif decimal_point_count > 1: | |||||
return '<UNKDGT>' | |||||
else: | |||||
return '<NUM>' | |||||
class TimeConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<TOC>' | |||||
pattern = '\d+[::∶][\d::∶]+(?=[\u4e00-\u9fff ,%.!<-])' | |||||
super().__init__(replace_tag, pattern) | |||||
class MixNumAlphaConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<MIX>' | |||||
pattern = None | |||||
super().__init__(replace_tag, pattern) | |||||
def find_certain_span_and_replace(self, sentence): | |||||
replaced_sentence = '' | |||||
start = 0 | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
for idx in range(len(sentence)): | |||||
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): | |||||
if not matching_flag: | |||||
replaced_sentence += sentence[start:idx] | |||||
start = idx | |||||
if re.match('[0-9]', sentence[idx]): | |||||
number_flag = True | |||||
elif re.match('[\'′&\\-]', sentence[idx]): | |||||
link_flag = True | |||||
elif re.match('/', sentence[idx]): | |||||
slash_flag = True | |||||
elif re.match('[\\(\\)]', sentence[idx]): | |||||
bracket_flag = True | |||||
else: | |||||
alpha_flag = True | |||||
matching_flag = True | |||||
elif re.match('[\\.]', sentence[idx]): | |||||
pass | |||||
else: | |||||
if matching_flag: | |||||
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ | |||||
or (slash_flag and alpha_flag) or (link_flag and number_flag) \ | |||||
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): | |||||
span = sentence[start:idx] | |||||
start = idx | |||||
replaced_sentence += self.span_to_special_tag(span) | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
replaced_sentence += sentence[start:] | |||||
return replaced_sentence | |||||
def find_certain_span(self, sentence): | |||||
spans = [] | |||||
start = 0 | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
for idx in range(len(sentence)): | |||||
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): | |||||
if not matching_flag: | |||||
start = idx | |||||
if re.match('[0-9]', sentence[idx]): | |||||
number_flag = True | |||||
elif re.match('[\'′&\\-]', sentence[idx]): | |||||
link_flag = True | |||||
elif re.match('/', sentence[idx]): | |||||
slash_flag = True | |||||
elif re.match('[\\(\\)]', sentence[idx]): | |||||
bracket_flag = True | |||||
else: | |||||
alpha_flag = True | |||||
matching_flag = True | |||||
elif re.match('[\\.]', sentence[idx]): | |||||
pass | |||||
else: | |||||
if matching_flag: | |||||
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ | |||||
or (slash_flag and alpha_flag) or (link_flag and number_flag) \ | |||||
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): | |||||
spans.append((start, idx)) | |||||
start = idx | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
return spans | |||||
class EmailConverter(SpanConverter): | |||||
def __init__(self): | |||||
replaced_tag = "<EML>" | |||||
pattern = '[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\\-"$])' | |||||
super(EmailConverter, self).__init__(replaced_tag, pattern) |
@@ -3,17 +3,16 @@ import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
from fastNLP.loader.dataset_loader import BaseLoader, TokenizeDataSetLoader | |||||
from fastNLP.core.preprocess import load_pickle | |||||
from fastNLP.saver.model_saver import ModelSaver | |||||
from fastNLP.loader.model_loader import ModelLoader | |||||
from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader | |||||
from fastNLP.core.utils import load_pickle | |||||
from fastNLP.io.model_saver import ModelSaver | |||||
from fastNLP.io.model_loader import ModelLoader | |||||
from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | from fastNLP.models.sequence_modeling import AdvSeqLabel | ||||
from fastNLP.core.predictor import SeqLabelInfer | from fastNLP.core.predictor import SeqLabelInfer | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.preprocess import save_pickle | |||||
from fastNLP.core.utils import save_pickle | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | from fastNLP.core.metrics import SeqLabelEvaluator | ||||
# not in the file's dir | # not in the file's dir | ||||
@@ -0,0 +1,151 @@ | |||||
import torch | |||||
def seq_lens_to_mask(seq_lens): | |||||
batch_size = seq_lens.size(0) | |||||
max_len = seq_lens.max() | |||||
indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device) | |||||
masks = indexes.lt(seq_lens.unsqueeze(1)) | |||||
return masks | |||||
from itertools import chain | |||||
def refine_ys_on_seq_len(ys, seq_lens): | |||||
refined_ys = [] | |||||
for b_idx, length in enumerate(seq_lens): | |||||
refined_ys.append(list(ys[b_idx][:length])) | |||||
return refined_ys | |||||
def flat_nested_list(nested_list): | |||||
return list(chain(*nested_list)) | |||||
def calculate_pre_rec_f1(model, batcher, type='segapp'): | |||||
true_ys, pred_ys = decode_iterator(model, batcher) | |||||
true_ys = flat_nested_list(true_ys) | |||||
pred_ys = flat_nested_list(pred_ys) | |||||
cor_num = 0 | |||||
start = 0 | |||||
if type=='segapp': | |||||
yp_wordnum = pred_ys.count(1) | |||||
yt_wordnum = true_ys.count(1) | |||||
if true_ys[0]==1 and pred_ys[0]==1: | |||||
cor_num += 1 | |||||
start = 1 | |||||
for i in range(1, len(true_ys)): | |||||
if true_ys[i] == 1: | |||||
flag = True | |||||
if true_ys[start-1] != pred_ys[start-1]: | |||||
flag = False | |||||
else: | |||||
for j in range(start, i + 1): | |||||
if true_ys[j] != pred_ys[j]: | |||||
flag = False | |||||
break | |||||
if flag: | |||||
cor_num += 1 | |||||
start = i + 1 | |||||
elif type=='bmes': | |||||
yp_wordnum = pred_ys.count(2) + pred_ys.count(3) | |||||
yt_wordnum = true_ys.count(2) + true_ys.count(3) | |||||
for i in range(len(true_ys)): | |||||
if true_ys[i] == 2 or true_ys[i] == 3: | |||||
flag = True | |||||
for j in range(start, i + 1): | |||||
if true_ys[j] != pred_ys[j]: | |||||
flag = False | |||||
break | |||||
if flag: | |||||
cor_num += 1 | |||||
start = i + 1 | |||||
P = cor_num / (float(yp_wordnum) + 1e-6) | |||||
R = cor_num / (float(yt_wordnum) + 1e-6) | |||||
F = 2 * P * R / (P + R + 1e-6) | |||||
# print(cor_num, yt_wordnum, yp_wordnum) | |||||
return P, R, F | |||||
def decode_iterator(model, batcher): | |||||
true_ys = [] | |||||
pred_ys = [] | |||||
seq_lens = [] | |||||
with torch.no_grad(): | |||||
model.eval() | |||||
for batch_x, batch_y in batcher: | |||||
pred_dict = model.predict(**batch_x) | |||||
seq_len = batch_x['seq_lens'].cpu().numpy() | |||||
pred_y = pred_dict['pred_tags'] | |||||
true_y = batch_y['tags'] | |||||
pred_y = pred_y.cpu().numpy() | |||||
true_y = true_y.cpu().numpy() | |||||
true_ys.extend(true_y.tolist()) | |||||
pred_ys.extend(pred_y.tolist()) | |||||
seq_lens.extend(list(seq_len)) | |||||
model.train() | |||||
true_ys = refine_ys_on_seq_len(true_ys, seq_lens) | |||||
pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens) | |||||
return true_ys, pred_ys | |||||
from torch import nn | |||||
import torch.nn.functional as F | |||||
class FocalLoss(nn.Module): | |||||
r""" | |||||
This criterion is a implemenation of Focal Loss, which is proposed in | |||||
Focal Loss for Dense Object Detection. | |||||
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class]) | |||||
The losses are averaged across observations for each minibatch. | |||||
Args: | |||||
alpha(1D Tensor, Variable) : the scalar factor for this criterion | |||||
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), | |||||
putting more focus on hard, misclassified examples | |||||
size_average(bool): size_average(bool): By default, the losses are averaged over observations for each minibatch. | |||||
However, if the field size_average is set to False, the losses are | |||||
instead summed for each minibatch. | |||||
""" | |||||
def __init__(self, class_num, gamma=2, size_average=True, reduce=False): | |||||
super(FocalLoss, self).__init__() | |||||
self.gamma = gamma | |||||
self.class_num = class_num | |||||
self.size_average = size_average | |||||
self.reduce = reduce | |||||
def forward(self, inputs, targets): | |||||
N = inputs.size(0) | |||||
C = inputs.size(1) | |||||
P = F.softmax(inputs, dim=-1) | |||||
class_mask = inputs.data.new(N, C).fill_(0) | |||||
class_mask.requires_grad = True | |||||
ids = targets.view(-1, 1) | |||||
class_mask = class_mask.scatter(1, ids.data, 1.) | |||||
probs = (P * class_mask).sum(1).view(-1, 1) | |||||
log_p = probs.log() | |||||
batch_loss = - (torch.pow((1 - probs), self.gamma)) * log_p | |||||
if self.reduce: | |||||
if self.size_average: | |||||
loss = batch_loss.mean() | |||||
else: | |||||
loss = batch_loss.sum() | |||||
return loss | |||||
return batch_loss |
@@ -0,0 +1,89 @@ | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
def cut_long_sentence(sent, max_sample_length=200): | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
class ConlluPOSReader(object): | |||||
# 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BMES的tag)。 | |||||
def __init__(self): | |||||
pass | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
char_seq = [] | |||||
pos_seq = [] | |||||
for word, tag in zip(res[0], res[1]): | |||||
if len(word)==1: | |||||
char_seq.append(word) | |||||
pos_seq.append('S-{}'.format(tag)) | |||||
elif len(word)>1: | |||||
pos_seq.append('B-{}'.format(tag)) | |||||
for _ in range(len(word)-2): | |||||
pos_seq.append('M-{}'.format(tag)) | |||||
pos_seq.append('E-{}'.format(tag)) | |||||
char_seq.extend(list(word)) | |||||
else: | |||||
raise ValueError("Zero length of word detected.") | |||||
ds.append(Instance(words=char_seq, | |||||
tag=pos_seq)) | |||||
return ds | |||||
def get_one(self, sample): | |||||
if len(sample)==0: | |||||
return None | |||||
text = [] | |||||
pos_tags = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
pos_tags.append(t2) | |||||
return text, pos_tags | |||||
if __name__ == '__main__': | |||||
reader = ConlluPOSReader() | |||||
d = reader.load('/home/hyan/train.conllx') | |||||
print('reader') |
@@ -1,14 +1,18 @@ | |||||
[train] | [train] | ||||
epochs = 30 | |||||
batch_size = 64 | |||||
epochs = 6 | |||||
batch_size = 32 | |||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
validate = true | validate = true | ||||
save_best_dev = true | save_best_dev = true | ||||
model_saved_path = "./save/" | model_saved_path = "./save/" | ||||
rnn_hidden_units = 100 | |||||
word_emb_dim = 100 | |||||
valid_step = 250 | |||||
eval_sort_key = 'accuracy' | |||||
[model] | |||||
rnn_hidden_units = 300 | |||||
word_emb_dim = 300 | |||||
dropout = 0.5 | |||||
use_crf = true | use_crf = true | ||||
use_cuda = true | |||||
print_every_step = 10 | print_every_step = 10 | ||||
[test] | [test] | ||||
@@ -0,0 +1,131 @@ | |||||
from collections import Counter | |||||
from fastNLP.api.processor import Processor | |||||
from fastNLP.core.dataset import DataSet | |||||
class CombineWordAndPosProcessor(Processor): | |||||
def __init__(self, word_field_name, pos_field_name): | |||||
super(CombineWordAndPosProcessor, self).__init__(None, None) | |||||
self.word_field_name = word_field_name | |||||
self.pos_field_name = pos_field_name | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
chars = ins[self.word_field_name] | |||||
bmes_pos = ins[self.pos_field_name] | |||||
word_list = [] | |||||
pos_list = [] | |||||
pos_stack_cnt = Counter() | |||||
char_stack = [] | |||||
for char, p in zip(chars, bmes_pos): | |||||
parts = p.split('-') | |||||
pre = parts[0] | |||||
post = parts[1] | |||||
if pre.lower() == 's': | |||||
if len(pos_stack_cnt) != 0: | |||||
pos = pos_stack_cnt.most_common(1)[0][0] | |||||
pos_list.append(pos) | |||||
word_list.append(''.join(char_stack)) | |||||
pos_list.append(post) | |||||
word_list.append(char) | |||||
char_stack.clear() | |||||
pos_stack_cnt.clear() | |||||
elif pre.lower() == 'e': | |||||
pos_stack_cnt.update([post]) | |||||
char_stack.append(char) | |||||
pos = pos_stack_cnt.most_common(1)[0][0] | |||||
pos_list.append(pos) | |||||
word_list.append(''.join(char_stack)) | |||||
char_stack.clear() | |||||
pos_stack_cnt.clear() | |||||
elif pre.lower() == 'b': | |||||
if len(pos_stack_cnt) != 0: | |||||
pos = pos_stack_cnt.most_common(1)[0][0] | |||||
pos_list.append(pos) | |||||
word_list.append(''.join(char_stack)) | |||||
char_stack.clear() | |||||
pos_stack_cnt.clear() | |||||
char_stack.append(char) | |||||
pos_stack_cnt.update([post]) | |||||
else: | |||||
char_stack.append(char) | |||||
pos_stack_cnt.update([post]) | |||||
ins['word_list'] = word_list | |||||
ins['pos_list'] = pos_list | |||||
return dataset | |||||
class PosOutputStrProcessor(Processor): | |||||
def __init__(self, word_field_name, pos_field_name): | |||||
super(PosOutputStrProcessor, self).__init__(None, None) | |||||
self.word_field_name = word_field_name | |||||
self.pos_field_name = pos_field_name | |||||
self.sep = '_' | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
word_list = ins[self.word_field_name] | |||||
pos_list = ins[self.pos_field_name] | |||||
word_pos_list = [] | |||||
for word, pos in zip(word_list, pos_list): | |||||
word_pos_list.append(word + self.sep + pos) | |||||
#TODO 应该可以定制 | |||||
ins['word_pos_output'] = ' '.join(word_pos_list) | |||||
return dataset | |||||
if __name__ == '__main__': | |||||
chars = ['迈', '向', '充', '满', '希', '望', '的', '新', '世', '纪', '—', '—', '一', '九', '九', '八', '年', '新', '年', '讲', '话', '(', '附', '图', '片', '1', '张', ')'] | |||||
bmes_pos = ['B-v', 'E-v', 'B-v', 'E-v', 'B-n', 'E-n', 'S-u', 'S-a', 'B-n', 'E-n', 'B-w', 'E-w', 'B-t', 'M-t', 'M-t', 'M-t', 'E-t', 'B-t', 'E-t', 'B-n', 'E-n', 'S-w', 'S-v', 'B-n', 'E-n', 'S-m', 'S-q', 'S-w'] | |||||
word_list = [] | |||||
pos_list = [] | |||||
pos_stack_cnt = Counter() | |||||
char_stack = [] | |||||
for char, p in zip(''.join(chars), bmes_pos): | |||||
parts = p.split('-') | |||||
pre = parts[0] | |||||
post = parts[1] | |||||
if pre.lower() == 's': | |||||
if len(pos_stack_cnt) != 0: | |||||
pos = pos_stack_cnt.most_common(1)[0][0] | |||||
pos_list.append(pos) | |||||
word_list.append(''.join(char_stack)) | |||||
pos_list.append(post) | |||||
word_list.append(char) | |||||
char_stack.clear() | |||||
pos_stack_cnt.clear() | |||||
elif pre.lower() == 'e': | |||||
pos_stack_cnt.update([post]) | |||||
char_stack.append(char) | |||||
pos = pos_stack_cnt.most_common(1)[0][0] | |||||
pos_list.append(pos) | |||||
word_list.append(''.join(char_stack)) | |||||
char_stack.clear() | |||||
pos_stack_cnt.clear() | |||||
elif pre.lower() == 'b': | |||||
if len(pos_stack_cnt) != 0: | |||||
pos = pos_stack_cnt.most_common(1)[0][0] | |||||
pos_list.append(pos) | |||||
word_list.append(''.join(char_stack)) | |||||
char_stack.clear() | |||||
pos_stack_cnt.clear() | |||||
char_stack.append(char) | |||||
pos_stack_cnt.update([post]) | |||||
else: | |||||
char_stack.append(char) | |||||
pos_stack_cnt.update([post]) | |||||
print(word_list) | |||||
print(pos_list) |
@@ -1,146 +0,0 @@ | |||||
import os | |||||
import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.core.trainer import SeqLabelTrainer | |||||
from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader, BaseLoader | |||||
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | |||||
from fastNLP.saver.model_saver import ModelSaver | |||||
from fastNLP.loader.model_loader import ModelLoader | |||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | |||||
from fastNLP.core.predictor import SeqLabelInfer | |||||
# not in the file's dir | |||||
if len(os.path.dirname(__file__)) != 0: | |||||
os.chdir(os.path.dirname(__file__)) | |||||
datadir = "/home/zyfeng/data/" | |||||
cfgfile = './pos_tag.cfg' | |||||
data_name = "CWS_POS_TAG_NER_people_daily.txt" | |||||
pos_tag_data_path = os.path.join(datadir, data_name) | |||||
pickle_path = "save" | |||||
data_infer_path = os.path.join(datadir, "infer.utf8") | |||||
def infer(): | |||||
# Config Loader | |||||
test_args = ConfigSection() | |||||
ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args}) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
test_args["vocab_size"] = len(word2index) | |||||
index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
test_args["num_classes"] = len(index2label) | |||||
# Define the same model | |||||
model = AdvSeqLabel(test_args) | |||||
try: | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print('model loaded!') | |||||
except Exception as e: | |||||
print('cannot load model!') | |||||
raise | |||||
# Data Loader | |||||
raw_data_loader = BaseLoader(data_infer_path) | |||||
infer_data = raw_data_loader.load_lines() | |||||
print('data loaded') | |||||
# Inference interface | |||||
infer = SeqLabelInfer(pickle_path) | |||||
results = infer.predict(model, infer_data) | |||||
print(results) | |||||
print("Inference finished!") | |||||
def train(): | |||||
# Config Loader | |||||
train_args = ConfigSection() | |||||
test_args = ConfigSection() | |||||
ConfigLoader("good_name").load_config(cfgfile, {"train": train_args, "test": test_args}) | |||||
# Data Loader | |||||
loader = PeopleDailyCorpusLoader() | |||||
train_data, _ = loader.load() | |||||
# Preprocessor | |||||
preprocessor = SeqLabelPreprocess() | |||||
data_train, data_dev = preprocessor.run(train_data, pickle_path=pickle_path, train_dev_split=0.3) | |||||
train_args["vocab_size"] = preprocessor.vocab_size | |||||
train_args["num_classes"] = preprocessor.num_classes | |||||
# Trainer | |||||
trainer = SeqLabelTrainer(**train_args.data) | |||||
# Model | |||||
model = AdvSeqLabel(train_args) | |||||
try: | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print('model parameter loaded!') | |||||
except Exception as e: | |||||
print("No saved model. Continue.") | |||||
pass | |||||
# Start training | |||||
trainer.train(model, data_train, data_dev) | |||||
print("Training finished!") | |||||
# Saver | |||||
saver = ModelSaver("./save/saved_model.pkl") | |||||
saver.save_pytorch(model) | |||||
print("Model saved!") | |||||
def test(): | |||||
# Config Loader | |||||
test_args = ConfigSection() | |||||
ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args}) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
test_args["vocab_size"] = len(word2index) | |||||
index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
test_args["num_classes"] = len(index2label) | |||||
# load dev data | |||||
dev_data = load_pickle(pickle_path, "data_dev.pkl") | |||||
# Define the same model | |||||
model = AdvSeqLabel(test_args) | |||||
# Dump trained parameters into the model | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print("model loaded!") | |||||
# Tester | |||||
tester = SeqLabelTester(**test_args.data) | |||||
# Start testing | |||||
tester.test(model, dev_data) | |||||
# print test results | |||||
print(tester.show_metrics()) | |||||
print("model tested!") | |||||
if __name__ == "__main__": | |||||
import argparse | |||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | |||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | |||||
args = parser.parse_args() | |||||
if args.mode == 'train': | |||||
train() | |||||
elif args.mode == 'test': | |||||
test() | |||||
elif args.mode == 'infer': | |||||
infer() | |||||
else: | |||||
print('no mode specified for model!') | |||||
parser.print_help() |
@@ -13,7 +13,7 @@ with open('requirements.txt', encoding='utf-8') as f: | |||||
setup( | setup( | ||||
name='fastNLP', | name='fastNLP', | ||||
version='0.1.0', | |||||
version='0.1.1', | |||||
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | ||||
long_description=readme, | long_description=readme, | ||||
license=license, | license=license, | ||||
@@ -1,53 +1,33 @@ | |||||
import unittest | import unittest | ||||
import torch | |||||
import numpy as np | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.instance import Instance | |||||
raw_texts = ["i am a cat", | |||||
"this is a test of new batch", | |||||
"ha ha", | |||||
"I am a good boy .", | |||||
"This is the most beautiful girl ." | |||||
] | |||||
texts = [text.strip().split() for text in raw_texts] | |||||
labels = [0, 1, 0, 0, 1] | |||||
# prepare vocabulary | |||||
vocab = {} | |||||
for text in texts: | |||||
for tokens in text: | |||||
if tokens not in vocab: | |||||
vocab[tokens] = len(vocab) | |||||
from fastNLP.core.dataset import construct_dataset | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
class TestCase1(unittest.TestCase): | class TestCase1(unittest.TestCase): | ||||
def test(self): | |||||
data = DataSet() | |||||
for text, label in zip(texts, labels): | |||||
x = TextField(text, is_target=False) | |||||
y = LabelField(label, is_target=True) | |||||
ins = Instance(text=x, label=y) | |||||
data.append(ins) | |||||
# use vocabulary to index data | |||||
data.index_field("text", vocab) | |||||
# define naive sampler for batch class | |||||
class SeqSampler: | |||||
def __call__(self, dataset): | |||||
return list(range(len(dataset))) | |||||
# use batch to iterate dataset | |||||
data_iterator = Batch(data, 2, SeqSampler(), False) | |||||
total_data = 0 | |||||
for batch_x, batch_y in data_iterator: | |||||
total_data += batch_x["text"].size(0) | |||||
self.assertTrue(batch_x["text"].size(0) == 2 or total_data == len(raw_texts)) | |||||
self.assertTrue(isinstance(batch_x, dict)) | |||||
self.assertTrue(isinstance(batch_x["text"], torch.LongTensor)) | |||||
self.assertTrue(isinstance(batch_y, dict)) | |||||
self.assertTrue(isinstance(batch_y["label"], torch.LongTensor)) | |||||
def test_simple(self): | |||||
dataset = construct_dataset( | |||||
[["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | |||||
dataset.set_target() | |||||
batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
cnt = 0 | |||||
for _, _ in batch: | |||||
cnt += 1 | |||||
self.assertEqual(cnt, 10) | |||||
def test_dataset_batching(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
ds.set_input(x=True) | |||||
ds.set_target(y=True) | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
for x, y in iter: | |||||
self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray)) | |||||
self.assertEqual(len(x["x"]), 4) | |||||
self.assertEqual(len(y["y"]), 4) | |||||
self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4]) | |||||
self.assertListEqual(list(y["y"][-1]), [5, 6]) |
@@ -1,54 +1,75 @@ | |||||
import unittest | import unittest | ||||
from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq_dataset | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
class TestDataSet(unittest.TestCase): | class TestDataSet(unittest.TestCase): | ||||
labeled_data_list = [ | |||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
] | |||||
unlabeled_data_list = [ | |||||
["a", "b", "e", "d"], | |||||
["a", "b", "e", "d"], | |||||
["a", "b", "e", "d"] | |||||
] | |||||
word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} | |||||
label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} | |||||
def test_case_1(self): | |||||
data_set = convert_seq2seq_dataset(self.labeled_data_list) | |||||
data_set.index_field("word_seq", self.word_vocab) | |||||
data_set.index_field("label_seq", self.label_vocab) | |||||
self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
self.assertTrue(len(data_set) > 0) | |||||
self.assertTrue(hasattr(data_set[0], "fields")) | |||||
self.assertTrue("word_seq" in data_set[0].fields) | |||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
[self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||||
self.assertTrue("label_seq" in data_set[0].fields) | |||||
self.assertTrue(hasattr(data_set[0].fields["label_seq"], "text")) | |||||
self.assertTrue(hasattr(data_set[0].fields["label_seq"], "_index")) | |||||
self.assertEqual(data_set[0].fields["label_seq"].text, self.labeled_data_list[0][1]) | |||||
self.assertEqual(data_set[0].fields["label_seq"]._index, | |||||
[self.label_vocab[c] for c in self.labeled_data_list[0][1]]) | |||||
def test_case_2(self): | |||||
data_set = convert_seq_dataset(self.unlabeled_data_list) | |||||
data_set.index_field("word_seq", self.word_vocab) | |||||
self.assertEqual(len(data_set), len(self.unlabeled_data_list)) | |||||
self.assertTrue(len(data_set) > 0) | |||||
self.assertTrue(hasattr(data_set[0], "fields")) | |||||
self.assertTrue("word_seq" in data_set[0].fields) | |||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
self.assertEqual(data_set[0].fields["word_seq"].text, self.unlabeled_data_list[0]) | |||||
self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
[self.word_vocab[c] for c in self.unlabeled_data_list[0]]) | |||||
def test_init_v1(self): | |||||
ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) | |||||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||||
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||||
def test_init_v2(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||||
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||||
def test_init_assert(self): | |||||
with self.assertRaises(AssertionError): | |||||
_ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100}) | |||||
with self.assertRaises(AssertionError): | |||||
_ = DataSet([[1, 2, 3, 4]] * 10) | |||||
with self.assertRaises(ValueError): | |||||
_ = DataSet(0.00001) | |||||
def test_append(self): | |||||
dd = DataSet() | |||||
for _ in range(3): | |||||
dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6])) | |||||
self.assertEqual(len(dd), 3) | |||||
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) | |||||
self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) | |||||
def test_add_append(self): | |||||
dd = DataSet() | |||||
dd.add_field("x", [[1, 2, 3]] * 10) | |||||
dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||||
dd.add_field("z", [[5, 6]] * 10) | |||||
self.assertEqual(len(dd), 10) | |||||
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10) | |||||
self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) | |||||
self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) | |||||
def test_delete_field(self): | |||||
dd = DataSet() | |||||
dd.add_field("x", [[1, 2, 3]] * 10) | |||||
dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||||
dd.delete_field("x") | |||||
self.assertFalse("x" in dd.field_arrays) | |||||
self.assertTrue("y" in dd.field_arrays) | |||||
def test_getitem(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
ins_1, ins_0 = ds[0], ds[1] | |||||
self.assertTrue(isinstance(ins_1, DataSet.Instance) and isinstance(ins_0, DataSet.Instance)) | |||||
self.assertEqual(ins_1["x"], [1, 2, 3, 4]) | |||||
self.assertEqual(ins_1["y"], [5, 6]) | |||||
self.assertEqual(ins_0["x"], [1, 2, 3, 4]) | |||||
self.assertEqual(ins_0["y"], [5, 6]) | |||||
sub_ds = ds[:10] | |||||
self.assertTrue(isinstance(sub_ds, DataSet)) | |||||
self.assertEqual(len(sub_ds), 10) | |||||
field = ds["x"] | |||||
self.assertEqual(field, ds.field_arrays["x"]) | |||||
def test_apply(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx") | |||||
self.assertTrue("rx" in ds.field_arrays) | |||||
self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) |
@@ -1,42 +0,0 @@ | |||||
import unittest | |||||
from fastNLP.core.field import CharTextField, LabelField, SeqLabelField | |||||
class TestField(unittest.TestCase): | |||||
def test_char_field(self): | |||||
text = "PhD applicants must submit a Research Plan and a resume " \ | |||||
"specify your class ranking written in English and a list of research" \ | |||||
" publications if any".split() | |||||
max_word_len = max([len(w) for w in text]) | |||||
field = CharTextField(text, max_word_len, is_target=False) | |||||
all_char = set() | |||||
for word in text: | |||||
all_char.update([ch for ch in word]) | |||||
char_vocab = {ch: idx + 1 for idx, ch in enumerate(all_char)} | |||||
self.assertEqual(field.index(char_vocab), | |||||
[[char_vocab[ch] for ch in word] + [0] * (max_word_len - len(word)) for word in text]) | |||||
self.assertEqual(field.get_length(), len(text)) | |||||
self.assertEqual(field.contents(), text) | |||||
tensor = field.to_tensor(50) | |||||
self.assertEqual(tuple(tensor.shape), (50, max_word_len)) | |||||
def test_label_field(self): | |||||
label = LabelField("A", is_target=True) | |||||
self.assertEqual(label.get_length(), 1) | |||||
self.assertEqual(label.index({"A": 10}), 10) | |||||
label = LabelField(30, is_target=True) | |||||
self.assertEqual(label.get_length(), 1) | |||||
tensor = label.to_tensor(0) | |||||
self.assertEqual(tensor.shape, ()) | |||||
self.assertEqual(int(tensor), 30) | |||||
def test_seq_label_field(self): | |||||
seq = ["a", "b", "c", "d", "a", "c", "a", "b"] | |||||
field = SeqLabelField(seq) | |||||
vocab = {"a": 10, "b": 20, "c": 30, "d": 40} | |||||
self.assertEqual(field.index(vocab), [vocab[x] for x in seq]) | |||||
tensor = field.to_tensor(10) | |||||
self.assertEqual(tuple(tensor.shape), (10,)) |
@@ -0,0 +1,22 @@ | |||||
import unittest | |||||
import numpy as np | |||||
from fastNLP.core.fieldarray import FieldArray | |||||
class TestFieldArray(unittest.TestCase): | |||||
def test(self): | |||||
fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | |||||
self.assertEqual(len(fa), 5) | |||||
fa.append(6) | |||||
self.assertEqual(len(fa), 6) | |||||
self.assertEqual(fa[-1], 6) | |||||
self.assertEqual(fa[0], 1) | |||||
fa[-1] = 60 | |||||
self.assertEqual(fa[-1], 60) | |||||
self.assertEqual(fa.get(0), 1) | |||||
self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray)) | |||||
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) |
@@ -0,0 +1,29 @@ | |||||
import unittest | |||||
from fastNLP.core.instance import Instance | |||||
class TestCase(unittest.TestCase): | |||||
def test_init(self): | |||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||||
ins = Instance(x=[1, 2, 3], y=[4, 5, 6]) | |||||
self.assertTrue(isinstance(ins.fields, dict)) | |||||
self.assertEqual(ins.fields, fields) | |||||
ins = Instance(**fields) | |||||
self.assertEqual(ins.fields, fields) | |||||
def test_add_field(self): | |||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||||
ins = Instance(**fields) | |||||
ins.add_field("z", [1, 1, 1]) | |||||
fields.update({"z": [1, 1, 1]}) | |||||
self.assertEqual(ins.fields, fields) | |||||
def test_get_item(self): | |||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||||
ins = Instance(**fields) | |||||
self.assertEqual(ins["x"], [1, 2, 3]) | |||||
self.assertEqual(ins["y"], [4, 5, 6]) | |||||
self.assertEqual(ins["z"], [1, 1, 1]) |
@@ -1,14 +1,5 @@ | |||||
import os | |||||
import unittest | import unittest | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.trainer import SeqLabelTrainer | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
import fastNLP.core.loss as loss | import fastNLP.core.loss as loss | ||||
import math | import math | ||||
import torch as tc | import torch as tc | ||||
@@ -1,100 +0,0 @@ | |||||
import os | |||||
import sys | |||||
sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path | |||||
from fastNLP.core import metrics | |||||
# from sklearn import metrics as skmetrics | |||||
import unittest | |||||
from numpy import random | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
import torch | |||||
def generate_fake_label(low, high, size): | |||||
return random.randint(low, high, size), random.randint(low, high, size) | |||||
class TestEvaluator(unittest.TestCase): | |||||
def test_a(self): | |||||
evaluator = SeqLabelEvaluator() | |||||
pred = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]] | |||||
truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4])}] | |||||
ans = evaluator(pred, truth) | |||||
print(ans) | |||||
def test_b(self): | |||||
evaluator = SeqLabelEvaluator() | |||||
pred = [[1, 2, 3, 4, 5, 0, 0], [1, 2, 3, 4, 5, 0, 0]] | |||||
truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3, 0, 0])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4, 0, 0])}] | |||||
ans = evaluator(pred, truth) | |||||
print(ans) | |||||
class TestMetrics(unittest.TestCase): | |||||
delta = 1e-5 | |||||
# test for binary, multiclass, multilabel | |||||
data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)] | |||||
fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types] | |||||
def test_accuracy_score(self): | |||||
for y_true, y_pred in self.fake_data: | |||||
for normalize in [True, False]: | |||||
for sample_weight in [None, random.rand(y_true.shape[0])]: | |||||
test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | |||||
# ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | |||||
# self.assertAlmostEqual(test, ans, delta=self.delta) | |||||
def test_recall_score(self): | |||||
for y_true, y_pred in self.fake_data: | |||||
# print(y_true.shape) | |||||
labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||||
test = metrics.recall_score(y_true, y_pred, labels=labels, average=None) | |||||
if not isinstance(test, list): | |||||
test = list(test) | |||||
# ans = skmetrics.recall_score(y_true, y_pred,labels=labels, average=None) | |||||
# ans = list(ans) | |||||
# for a, b in zip(test, ans): | |||||
# # print('{}, {}'.format(a, b)) | |||||
# self.assertAlmostEqual(a, b, delta=self.delta) | |||||
# test binary | |||||
y_true, y_pred = generate_fake_label(0, 2, 1000) | |||||
test = metrics.recall_score(y_true, y_pred) | |||||
# ans = skmetrics.recall_score(y_true, y_pred) | |||||
# self.assertAlmostEqual(ans, test, delta=self.delta) | |||||
def test_precision_score(self): | |||||
for y_true, y_pred in self.fake_data: | |||||
# print(y_true.shape) | |||||
labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||||
test = metrics.precision_score(y_true, y_pred, labels=labels, average=None) | |||||
# ans = skmetrics.precision_score(y_true, y_pred,labels=labels, average=None) | |||||
# ans, test = list(ans), list(test) | |||||
# for a, b in zip(test, ans): | |||||
# # print('{}, {}'.format(a, b)) | |||||
# self.assertAlmostEqual(a, b, delta=self.delta) | |||||
# test binary | |||||
y_true, y_pred = generate_fake_label(0, 2, 1000) | |||||
test = metrics.precision_score(y_true, y_pred) | |||||
# ans = skmetrics.precision_score(y_true, y_pred) | |||||
# self.assertAlmostEqual(ans, test, delta=self.delta) | |||||
def test_f1_score(self): | |||||
for y_true, y_pred in self.fake_data: | |||||
# print(y_true.shape) | |||||
labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||||
test = metrics.f1_score(y_true, y_pred, labels=labels, average=None) | |||||
# ans = skmetrics.f1_score(y_true, y_pred,labels=labels, average=None) | |||||
# ans, test = list(ans), list(test) | |||||
# for a, b in zip(test, ans): | |||||
# # print('{}, {}'.format(a, b)) | |||||
# self.assertAlmostEqual(a, b, delta=self.delta) | |||||
# test binary | |||||
y_true, y_pred = generate_fake_label(0, 2, 1000) | |||||
test = metrics.f1_score(y_true, y_pred) | |||||
# ans = skmetrics.f1_score(y_true, y_pred) | |||||
# self.assertAlmostEqual(ans, test, delta=self.delta) | |||||
if __name__ == '__main__': | |||||
unittest.main() |
@@ -1,79 +1,6 @@ | |||||
import os | |||||
import unittest | import unittest | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.predictor import Predictor | |||||
from fastNLP.core.preprocess import save_pickle | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.loader.base_loader import BaseLoader | |||||
from fastNLP.loader.dataset_loader import convert_seq_dataset | |||||
from fastNLP.models.cnn_text_classification import CNNText | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
class TestPredictor(unittest.TestCase): | class TestPredictor(unittest.TestCase): | ||||
def test_seq_label(self): | |||||
model_args = { | |||||
"vocab_size": 10, | |||||
"word_emb_dim": 100, | |||||
"rnn_hidden_units": 100, | |||||
"num_classes": 5 | |||||
} | |||||
infer_data = [ | |||||
['a', 'b', 'c', 'd', 'e'], | |||||
['a', '@', 'c', 'd', 'e'], | |||||
['a', 'b', '#', 'd', 'e'], | |||||
['a', 'b', 'c', '?', 'e'], | |||||
['a', 'b', 'c', 'd', '$'], | |||||
['!', 'b', 'c', 'd', 'e'] | |||||
] | |||||
vocab = Vocabulary() | |||||
vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
class_vocab = Vocabulary() | |||||
class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} | |||||
os.system("mkdir save") | |||||
save_pickle(class_vocab, "./save/", "label2id.pkl") | |||||
save_pickle(vocab, "./save/", "word2id.pkl") | |||||
model = CNNText(model_args) | |||||
import fastNLP.core.predictor as pre | |||||
predictor = Predictor("./save/", pre.text_classify_post_processor) | |||||
# Load infer data | |||||
infer_data_set = convert_seq_dataset(infer_data) | |||||
infer_data_set.index_field("word_seq", vocab) | |||||
results = predictor.predict(network=model, data=infer_data_set) | |||||
self.assertTrue(isinstance(results, list)) | |||||
self.assertGreater(len(results), 0) | |||||
self.assertEqual(len(results), len(infer_data)) | |||||
for res in results: | |||||
self.assertTrue(isinstance(res, str)) | |||||
self.assertTrue(res in class_vocab.word2idx) | |||||
del model, predictor | |||||
infer_data_set.set_origin_len("word_seq") | |||||
model = SeqLabeling(model_args) | |||||
predictor = Predictor("./save/", pre.seq_label_post_processor) | |||||
results = predictor.predict(network=model, data=infer_data_set) | |||||
self.assertTrue(isinstance(results, list)) | |||||
self.assertEqual(len(results), len(infer_data)) | |||||
for i in range(len(infer_data)): | |||||
res = results[i] | |||||
self.assertTrue(isinstance(res, list)) | |||||
self.assertEqual(len(res), len(infer_data[i])) | |||||
os.system("rm -rf save") | |||||
print("pickle path deleted") | |||||
class TestPredictor2(unittest.TestCase): | |||||
def test_text_classify(self): | |||||
# TODO | |||||
def test(self): | |||||
pass | pass |
@@ -1,44 +1,42 @@ | |||||
import unittest | |||||
import torch | import torch | ||||
from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \ | from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \ | ||||
k_means_1d, k_means_bucketing, simple_sort_bucketing | k_means_1d, k_means_bucketing, simple_sort_bucketing | ||||
def test_convert_to_torch_tensor(): | |||||
data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]] | |||||
ans = convert_to_torch_tensor(data, False) | |||||
assert isinstance(ans, torch.Tensor) | |||||
assert tuple(ans.shape) == (3, 5) | |||||
def test_sequential_sampler(): | |||||
sampler = SequentialSampler() | |||||
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||||
for idx, i in enumerate(sampler(data)): | |||||
assert idx == i | |||||
def test_random_sampler(): | |||||
sampler = RandomSampler() | |||||
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||||
ans = [data[i] for i in sampler(data)] | |||||
assert len(ans) == len(data) | |||||
for d in ans: | |||||
assert d in data | |||||
def test_k_means(): | |||||
centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5) | |||||
centroids, assign = list(centroids), list(assign) | |||||
assert len(centroids) == 2 | |||||
assert len(assign) == 10 | |||||
def test_k_means_bucketing(): | |||||
res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None]) | |||||
assert len(res) == 2 | |||||
def test_simple_sort_bucketing(): | |||||
_ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10]) | |||||
assert len(_) == 10 | |||||
class TestSampler(unittest.TestCase): | |||||
def test_convert_to_torch_tensor(self): | |||||
data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]] | |||||
ans = convert_to_torch_tensor(data, False) | |||||
assert isinstance(ans, torch.Tensor) | |||||
assert tuple(ans.shape) == (3, 5) | |||||
def test_sequential_sampler(self): | |||||
sampler = SequentialSampler() | |||||
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||||
for idx, i in enumerate(sampler(data)): | |||||
assert idx == i | |||||
def test_random_sampler(self): | |||||
sampler = RandomSampler() | |||||
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||||
ans = [data[i] for i in sampler(data)] | |||||
assert len(ans) == len(data) | |||||
for d in ans: | |||||
assert d in data | |||||
def test_k_means(self): | |||||
centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5) | |||||
centroids, assign = list(centroids), list(assign) | |||||
assert len(centroids) == 2 | |||||
assert len(assign) == 10 | |||||
def test_k_means_bucketing(self): | |||||
res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None]) | |||||
assert len(res) == 2 | |||||
def test_simple_sort_bucketing(self): | |||||
_ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10]) | |||||
assert len(_) == 10 |
@@ -1,57 +1,9 @@ | |||||
import os | |||||
import unittest | import unittest | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
data_name = "pku_training.utf8" | data_name = "pku_training.utf8" | ||||
pickle_path = "data_for_tests" | pickle_path = "data_for_tests" | ||||
class TestTester(unittest.TestCase): | class TestTester(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
model_args = { | |||||
"vocab_size": 10, | |||||
"word_emb_dim": 100, | |||||
"rnn_hidden_units": 100, | |||||
"num_classes": 5 | |||||
} | |||||
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | |||||
"save_loss": True, "batch_size": 2, "pickle_path": "./save/", | |||||
"use_cuda": False, "print_every_step": 1, "evaluator": SeqLabelEvaluator()} | |||||
train_data = [ | |||||
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']], | |||||
[['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
] | |||||
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | |||||
data_set = DataSet() | |||||
for example in train_data: | |||||
text, label = example[0], example[1] | |||||
x = TextField(text, False) | |||||
x_len = LabelField(len(text), is_target=False) | |||||
y = TextField(label, is_target=True) | |||||
ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||||
data_set.append(ins) | |||||
data_set.index_field("word_seq", vocab) | |||||
data_set.index_field("truth", label_vocab) | |||||
model = SeqLabeling(model_args) | |||||
tester = SeqLabelTester(**valid_args) | |||||
tester.test(network=model, dev_data=data_set) | |||||
# If this can run, everything is OK. | |||||
os.system("rm -rf save") | |||||
print("pickle path deleted") | |||||
pass |
@@ -1,57 +1,6 @@ | |||||
import os | |||||
import unittest | import unittest | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.loss import Loss | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.trainer import SeqLabelTrainer | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
class TestTrainer(unittest.TestCase): | class TestTrainer(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
args = {"epochs": 3, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | |||||
"save_best_dev": True, "model_name": "default_model_name.pkl", | |||||
"loss": Loss("cross_entropy"), | |||||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | |||||
"vocab_size": 10, | |||||
"word_emb_dim": 100, | |||||
"rnn_hidden_units": 100, | |||||
"num_classes": 5, | |||||
"evaluator": SeqLabelEvaluator() | |||||
} | |||||
trainer = SeqLabelTrainer(**args) | |||||
train_data = [ | |||||
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']], | |||||
[['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
] | |||||
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | |||||
data_set = DataSet() | |||||
for example in train_data: | |||||
text, label = example[0], example[1] | |||||
x = TextField(text, False) | |||||
x_len = LabelField(len(text), is_target=False) | |||||
y = TextField(label, is_target=False) | |||||
ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||||
data_set.append(ins) | |||||
data_set.index_field("word_seq", vocab) | |||||
data_set.index_field("truth", label_vocab) | |||||
model = SeqLabeling(args) | |||||
trainer.train(network=model, train_data=data_set, dev_data=data_set) | |||||
# If this can run, everything is OK. | |||||
os.system("rm -rf save") | |||||
print("pickle path deleted") | |||||
pass |
@@ -1,31 +0,0 @@ | |||||
import unittest | |||||
from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX | |||||
class TestVocabulary(unittest.TestCase): | |||||
def test_vocab(self): | |||||
import _pickle as pickle | |||||
import os | |||||
vocab = Vocabulary() | |||||
filename = 'vocab' | |||||
vocab.update(filename) | |||||
vocab.update([filename, ['a'], [['b']], ['c']]) | |||||
idx = vocab[filename] | |||||
before_pic = (vocab.to_word(idx), vocab[filename]) | |||||
with open(filename, 'wb') as f: | |||||
pickle.dump(vocab, f) | |||||
with open(filename, 'rb') as f: | |||||
vocab = pickle.load(f) | |||||
os.remove(filename) | |||||
vocab.build_reverse_vocab() | |||||
after_pic = (vocab.to_word(idx), vocab[filename]) | |||||
TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8} | |||||
TRUE_DICT.update(DEFAULT_WORD_TO_INDEX) | |||||
TRUE_IDXDICT = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'} | |||||
self.assertEqual(before_pic, after_pic) | |||||
self.assertDictEqual(TRUE_DICT, vocab.word2idx) | |||||
self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word) | |||||
if __name__ == '__main__': | |||||
unittest.main() |
@@ -0,0 +1,61 @@ | |||||
import unittest | |||||
from collections import Counter | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | |||||
"works", "well", "in", "most", "cases", "scales", "well"] | |||||
counter = Counter(text) | |||||
class TestAdd(unittest.TestCase): | |||||
def test_add(self): | |||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
for word in text: | |||||
vocab.add(word) | |||||
self.assertEqual(vocab.word_count, counter) | |||||
def test_add_word(self): | |||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
for word in text: | |||||
vocab.add_word(word) | |||||
self.assertEqual(vocab.word_count, counter) | |||||
def test_add_word_lst(self): | |||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab.add_word_lst(text) | |||||
self.assertEqual(vocab.word_count, counter) | |||||
def test_update(self): | |||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab.update(text) | |||||
self.assertEqual(vocab.word_count, counter) | |||||
class TestIndexing(unittest.TestCase): | |||||
def test_len(self): | |||||
vocab = Vocabulary(need_default=False, max_size=None, min_freq=None) | |||||
vocab.update(text) | |||||
self.assertEqual(len(vocab), len(counter)) | |||||
def test_contains(self): | |||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab.update(text) | |||||
self.assertTrue(text[-1] in vocab) | |||||
self.assertFalse("~!@#" in vocab) | |||||
self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1])) | |||||
self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | |||||
def test_index(self): | |||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab.update(text) | |||||
res = [vocab[w] for w in set(text)] | |||||
self.assertEqual(len(res), len(set(res))) | |||||
res = [vocab.to_index(w) for w in set(text)] | |||||
self.assertEqual(len(res), len(set(res))) | |||||
def test_to_word(self): | |||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab.update(text) | |||||
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) |
@@ -1,13 +1,13 @@ | |||||
import os | import os | ||||
import unittest | import unittest | ||||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||||
from fastNLP.saver.config_saver import ConfigSaver | |||||
from fastNLP.io.config_loader import ConfigSection, ConfigLoader | |||||
from fastNLP.io.config_saver import ConfigSaver | |||||
class TestConfigSaver(unittest.TestCase): | class TestConfigSaver(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
config_file_dir = "test/loader/" | |||||
config_file_dir = "test/io/" | |||||
config_file_name = "config" | config_file_name = "config" | ||||
config_file_path = os.path.join(config_file_dir, config_file_name) | config_file_path = os.path.join(config_file_dir, config_file_name) | ||||
@@ -1,53 +0,0 @@ | |||||
import configparser | |||||
import json | |||||
import os | |||||
import unittest | |||||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||||
class TestConfigLoader(unittest.TestCase): | |||||
def test_case_ConfigLoader(self): | |||||
def read_section_from_config(config_path, section_name): | |||||
dict = {} | |||||
if not os.path.exists(config_path): | |||||
raise FileNotFoundError("config file {} NOT found.".format(config_path)) | |||||
cfg = configparser.ConfigParser() | |||||
cfg.read(config_path) | |||||
if section_name not in cfg: | |||||
raise AttributeError("config file {} do NOT have section {}".format( | |||||
config_path, section_name | |||||
)) | |||||
gen_sec = cfg[section_name] | |||||
for s in gen_sec.keys(): | |||||
try: | |||||
val = json.loads(gen_sec[s]) | |||||
dict[s] = val | |||||
except Exception as e: | |||||
raise AttributeError("json can NOT load {} in section {}, config file {}".format( | |||||
s, section_name, config_path | |||||
)) | |||||
return dict | |||||
test_arg = ConfigSection() | |||||
ConfigLoader().load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) | |||||
section = read_section_from_config(os.path.join("./test/loader", "config"), "test") | |||||
for sec in section: | |||||
if (sec not in test_arg) or (section[sec] != test_arg[sec]): | |||||
raise AttributeError("ERROR") | |||||
for sec in test_arg.__dict__.keys(): | |||||
if (sec not in section) or (section[sec] != test_arg[sec]): | |||||
raise AttributeError("ERROR") | |||||
try: | |||||
not_exist = test_arg["NOT EXIST"] | |||||
except Exception as e: | |||||
pass | |||||
print("pass config test!") | |||||
@@ -1,53 +0,0 @@ | |||||
import os | |||||
import unittest | |||||
from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \ | |||||
PeopleDailyCorpusLoader, ConllLoader | |||||
from fastNLP.core.dataset import DataSet | |||||
class TestDatasetLoader(unittest.TestCase): | |||||
def test_case_1(self): | |||||
data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF""" | |||||
lines = data.split("\n") | |||||
answer = POSDataSetLoader.parse(lines) | |||||
truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]] | |||||
self.assertListEqual(answer, truth, "POS Dataset Loader") | |||||
def test_case_TokenizeDatasetLoader(self): | |||||
loader = TokenizeDataSetLoader() | |||||
filepath = "./test/data_for_tests/cws_pku_utf_8" | |||||
data = loader.load(filepath, max_seq_len=32) | |||||
assert len(data) > 0 | |||||
data1 = DataSet() | |||||
data1.read_tokenize(filepath, max_seq_len=32) | |||||
assert len(data1) > 0 | |||||
print("pass TokenizeDataSetLoader test!") | |||||
def test_case_POSDatasetLoader(self): | |||||
loader = POSDataSetLoader() | |||||
filepath = "./test/data_for_tests/people.txt" | |||||
data = loader.load("./test/data_for_tests/people.txt") | |||||
datas = loader.load_lines("./test/data_for_tests/people.txt") | |||||
data1 = DataSet().read_pos(filepath) | |||||
assert len(data1) > 0 | |||||
print("pass POSDataSetLoader test!") | |||||
def test_case_LMDatasetLoader(self): | |||||
loader = LMDataSetLoader() | |||||
data = loader.load("./test/data_for_tests/charlm.txt") | |||||
datas = loader.load_lines("./test/data_for_tests/charlm.txt") | |||||
print("pass TokenizeDataSetLoader test!") | |||||
def test_PeopleDailyCorpusLoader(self): | |||||
loader = PeopleDailyCorpusLoader() | |||||
_, _ = loader.load("./test/data_for_tests/people_daily_raw.txt") | |||||
def test_ConllLoader(self): | |||||
loader = ConllLoader() | |||||
_ = loader.load("./test/data_for_tests/conll_example.txt") | |||||
if __name__ == '__main__': | |||||
unittest.main() |
@@ -1,33 +0,0 @@ | |||||
import unittest | |||||
import os | |||||
import torch | |||||
from fastNLP.loader.embed_loader import EmbedLoader | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
class TestEmbedLoader(unittest.TestCase): | |||||
glove_path = './test/data_for_tests/glove.6B.50d_test.txt' | |||||
pkl_path = './save' | |||||
raw_texts = ["i am a cat", | |||||
"this is a test of new batch", | |||||
"ha ha", | |||||
"I am a good boy .", | |||||
"This is the most beautiful girl ." | |||||
] | |||||
texts = [text.strip().split() for text in raw_texts] | |||||
vocab = Vocabulary() | |||||
vocab.update(texts) | |||||
def test1(self): | |||||
emb, _ = EmbedLoader.load_embedding(50, self.glove_path, 'glove', self.vocab, self.pkl_path) | |||||
self.assertTrue(emb.shape[0] == (len(self.vocab))) | |||||
self.assertTrue(emb.shape[1] == 50) | |||||
os.remove(self.pkl_path) | |||||
def test2(self): | |||||
try: | |||||
_ = EmbedLoader.load_embedding(100, self.glove_path, 'glove', self.vocab, self.pkl_path) | |||||
self.fail(msg="load dismatch embedding") | |||||
except ValueError: | |||||
pass |
@@ -1,150 +0,0 @@ | |||||
import os | |||||
import sys | |||||
sys.path.append("..") | |||||
import argparse | |||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.loader.dataset_loader import BaseLoader | |||||
from fastNLP.saver.model_saver import ModelSaver | |||||
from fastNLP.loader.model_loader import ModelLoader | |||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
from fastNLP.core.predictor import SeqLabelInfer | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.preprocess import save_pickle, load_pickle | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") | |||||
parser.add_argument("-t", "--train", type=str, default="../data_for_tests/people.txt", | |||||
help="path to the training data") | |||||
parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file") | |||||
parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model") | |||||
parser.add_argument("-i", "--infer", type=str, default="../data_for_tests/people_infer.txt", | |||||
help="data used for inference") | |||||
args = parser.parse_args() | |||||
pickle_path = args.save | |||||
model_name = args.model_name | |||||
config_dir = args.config | |||||
data_path = args.train | |||||
data_infer_path = args.infer | |||||
def infer(): | |||||
# Load infer configuration, the same as test | |||||
test_args = ConfigSection() | |||||
ConfigLoader().load_config(config_dir, {"POS_infer": test_args}) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
word_vocab = load_pickle(pickle_path, "word2id.pkl") | |||||
label_vocab = load_pickle(pickle_path, "label2id.pkl") | |||||
test_args["vocab_size"] = len(word_vocab) | |||||
test_args["num_classes"] = len(label_vocab) | |||||
print("vocabularies loaded") | |||||
# Define the same model | |||||
model = SeqLabeling(test_args) | |||||
print("model defined") | |||||
# Dump trained parameters into the model | |||||
ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | |||||
print("model loaded!") | |||||
# Data Loader | |||||
infer_data = SeqLabelDataSet(load_func=BaseLoader.load) | |||||
infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True) | |||||
print("data set prepared") | |||||
# Inference interface | |||||
infer = SeqLabelInfer(pickle_path) | |||||
results = infer.predict(model, infer_data) | |||||
for res in results: | |||||
print(res) | |||||
print("Inference finished!") | |||||
def train_and_test(): | |||||
# Config Loader | |||||
trainer_args = ConfigSection() | |||||
model_args = ConfigSection() | |||||
ConfigLoader().load_config(config_dir, { | |||||
"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | |||||
data_set = SeqLabelDataSet() | |||||
data_set.load(data_path) | |||||
train_set, dev_set = data_set.split(0.3, shuffle=True) | |||||
model_args["vocab_size"] = len(data_set.word_vocab) | |||||
model_args["num_classes"] = len(data_set.label_vocab) | |||||
save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl") | |||||
save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl") | |||||
""" | |||||
trainer = SeqLabelTrainer( | |||||
epochs=trainer_args["epochs"], | |||||
batch_size=trainer_args["batch_size"], | |||||
validate=False, | |||||
use_cuda=trainer_args["use_cuda"], | |||||
pickle_path=pickle_path, | |||||
save_best_dev=trainer_args["save_best_dev"], | |||||
model_name=model_name, | |||||
optimizer=Optimizer("SGD", lr=0.01, momentum=0.9), | |||||
) | |||||
""" | |||||
# Model | |||||
model = SeqLabeling(model_args) | |||||
model.fit(train_set, dev_set, | |||||
epochs=trainer_args["epochs"], | |||||
batch_size=trainer_args["batch_size"], | |||||
validate=False, | |||||
use_cuda=trainer_args["use_cuda"], | |||||
pickle_path=pickle_path, | |||||
save_best_dev=trainer_args["save_best_dev"], | |||||
model_name=model_name, | |||||
optimizer=Optimizer("SGD", lr=0.01, momentum=0.9)) | |||||
# Start training | |||||
# trainer.train(model, train_set, dev_set) | |||||
print("Training finished!") | |||||
# Saver | |||||
saver = ModelSaver(os.path.join(pickle_path, model_name)) | |||||
saver.save_pytorch(model) | |||||
print("Model saved!") | |||||
del model | |||||
change_field_is_target(dev_set, "truth", True) | |||||
# Define the same model | |||||
model = SeqLabeling(model_args) | |||||
# Dump trained parameters into the model | |||||
ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | |||||
print("model loaded!") | |||||
# Load test configuration | |||||
tester_args = ConfigSection() | |||||
ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
# Tester | |||||
tester = SeqLabelTester(batch_size=4, | |||||
use_cuda=False, | |||||
pickle_path=pickle_path, | |||||
model_name="seq_label_in_test.pkl", | |||||
evaluator=SeqLabelEvaluator() | |||||
) | |||||
# Start testing with validation data | |||||
tester.test(model, dev_set) | |||||
print("model tested!") | |||||
if __name__ == "__main__": | |||||
train_and_test() | |||||
infer() |
@@ -1,25 +0,0 @@ | |||||
import unittest | |||||
import numpy as np | |||||
import torch | |||||
from fastNLP.models.char_language_model import CharLM | |||||
class TestCharLM(unittest.TestCase): | |||||
def test_case_1(self): | |||||
char_emb_dim = 50 | |||||
word_emb_dim = 50 | |||||
vocab_size = 1000 | |||||
num_char = 24 | |||||
max_word_len = 21 | |||||
num_seq = 64 | |||||
seq_len = 32 | |||||
model = CharLM(char_emb_dim, word_emb_dim, vocab_size, num_char) | |||||
x = torch.from_numpy(np.random.randint(0, num_char, size=(num_seq, seq_len, max_word_len + 2))) | |||||
self.assertEqual(tuple(x.shape), (num_seq, seq_len, max_word_len + 2)) | |||||
y = model(x) | |||||
self.assertEqual(tuple(y.shape), (num_seq * seq_len, vocab_size)) |
@@ -1,112 +0,0 @@ | |||||
import os | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.predictor import SeqLabelInfer | |||||
from fastNLP.core.preprocess import save_pickle, load_pickle | |||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.core.trainer import SeqLabelTrainer | |||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader, RawDataSetLoader | |||||
from fastNLP.loader.model_loader import ModelLoader | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
from fastNLP.saver.model_saver import ModelSaver | |||||
data_name = "pku_training.utf8" | |||||
cws_data_path = "./test/data_for_tests/cws_pku_utf_8" | |||||
pickle_path = "./save/" | |||||
data_infer_path = "./test/data_for_tests/people_infer.txt" | |||||
config_path = "./test/data_for_tests/config" | |||||
def infer(): | |||||
# Load infer configuration, the same as test | |||||
test_args = ConfigSection() | |||||
ConfigLoader().load_config(config_path, {"POS_infer": test_args}) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
test_args["vocab_size"] = len(word2index) | |||||
index2label = load_pickle(pickle_path, "label2id.pkl") | |||||
test_args["num_classes"] = len(index2label) | |||||
# Define the same model | |||||
model = SeqLabeling(test_args) | |||||
# Dump trained parameters into the model | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print("model loaded!") | |||||
# Load infer data | |||||
infer_data = RawDataSetLoader().load(data_infer_path) | |||||
infer_data.index_field("word_seq", word2index) | |||||
infer_data.set_origin_len("word_seq") | |||||
# inference | |||||
infer = SeqLabelInfer(pickle_path) | |||||
results = infer.predict(model, infer_data) | |||||
print(results) | |||||
def train_test(): | |||||
# Config Loader | |||||
train_args = ConfigSection() | |||||
ConfigLoader().load_config(config_path, {"POS_infer": train_args}) | |||||
# define dataset | |||||
data_train = TokenizeDataSetLoader().load(cws_data_path) | |||||
word_vocab = Vocabulary() | |||||
label_vocab = Vocabulary() | |||||
data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab) | |||||
data_train.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab) | |||||
data_train.set_origin_len("word_seq") | |||||
data_train.rename_field("label_seq", "truth").set_target(truth=False) | |||||
train_args["vocab_size"] = len(word_vocab) | |||||
train_args["num_classes"] = len(label_vocab) | |||||
save_pickle(word_vocab, pickle_path, "word2id.pkl") | |||||
save_pickle(label_vocab, pickle_path, "label2id.pkl") | |||||
# Trainer | |||||
trainer = SeqLabelTrainer(**train_args.data) | |||||
# Model | |||||
model = SeqLabeling(train_args) | |||||
# Start training | |||||
trainer.train(model, data_train) | |||||
# Saver | |||||
saver = ModelSaver("./save/saved_model.pkl") | |||||
saver.save_pytorch(model) | |||||
del model, trainer | |||||
# Define the same model | |||||
model = SeqLabeling(train_args) | |||||
# Dump trained parameters into the model | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
# Load test configuration | |||||
test_args = ConfigSection() | |||||
ConfigLoader().load_config(config_path, {"POS_infer": test_args}) | |||||
test_args["evaluator"] = SeqLabelEvaluator() | |||||
# Tester | |||||
tester = SeqLabelTester(**test_args.data) | |||||
# Start testing | |||||
data_train.set_target(truth=True) | |||||
tester.test(model, data_train) | |||||
def test(): | |||||
os.makedirs("save", exist_ok=True) | |||||
train_test() | |||||
infer() | |||||
os.system("rm -rf save") | |||||
if __name__ == "__main__": | |||||
train_test() | |||||
infer() |
@@ -1,86 +0,0 @@ | |||||
import os | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.preprocess import save_pickle | |||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.core.trainer import SeqLabelTrainer | |||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.loader.model_loader import ModelLoader | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
from fastNLP.saver.model_saver import ModelSaver | |||||
pickle_path = "./seq_label/" | |||||
model_name = "seq_label_model.pkl" | |||||
config_dir = "test/data_for_tests/config" | |||||
data_path = "test/data_for_tests/people.txt" | |||||
data_infer_path = "test/data_for_tests/people_infer.txt" | |||||
def test_training(): | |||||
# Config Loader | |||||
trainer_args = ConfigSection() | |||||
model_args = ConfigSection() | |||||
ConfigLoader().load_config(config_dir, { | |||||
"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | |||||
data_set = TokenizeDataSetLoader().load(data_path) | |||||
word_vocab = Vocabulary() | |||||
label_vocab = Vocabulary() | |||||
data_set.update_vocab(word_seq=word_vocab, label_seq=label_vocab) | |||||
data_set.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab) | |||||
data_set.set_origin_len("word_seq") | |||||
data_set.rename_field("label_seq", "truth").set_target(truth=False) | |||||
data_train, data_dev = data_set.split(0.3, shuffle=True) | |||||
model_args["vocab_size"] = len(word_vocab) | |||||
model_args["num_classes"] = len(label_vocab) | |||||
save_pickle(word_vocab, pickle_path, "word2id.pkl") | |||||
save_pickle(label_vocab, pickle_path, "label2id.pkl") | |||||
trainer = SeqLabelTrainer( | |||||
epochs=trainer_args["epochs"], | |||||
batch_size=trainer_args["batch_size"], | |||||
validate=False, | |||||
use_cuda=False, | |||||
pickle_path=pickle_path, | |||||
save_best_dev=trainer_args["save_best_dev"], | |||||
model_name=model_name, | |||||
optimizer=Optimizer("SGD", lr=0.01, momentum=0.9), | |||||
) | |||||
# Model | |||||
model = SeqLabeling(model_args) | |||||
# Start training | |||||
trainer.train(model, data_train, data_dev) | |||||
# Saver | |||||
saver = ModelSaver(os.path.join(pickle_path, model_name)) | |||||
saver.save_pytorch(model) | |||||
del model, trainer | |||||
# Define the same model | |||||
model = SeqLabeling(model_args) | |||||
# Dump trained parameters into the model | |||||
ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | |||||
# Load test configuration | |||||
tester_args = ConfigSection() | |||||
ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
# Tester | |||||
tester = SeqLabelTester(batch_size=4, | |||||
use_cuda=False, | |||||
pickle_path=pickle_path, | |||||
model_name="seq_label_in_test.pkl", | |||||
evaluator=SeqLabelEvaluator() | |||||
) | |||||
# Start testing with validation data | |||||
data_dev.set_target(truth=True) | |||||
tester.test(model, data_dev) |
@@ -1,107 +0,0 @@ | |||||
# Python: 3.5 | |||||
# encoding: utf-8 | |||||
import argparse | |||||
import os | |||||
import sys | |||||
sys.path.append("..") | |||||
from fastNLP.core.predictor import ClassificationInfer | |||||
from fastNLP.core.trainer import ClassificationTrainer | |||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.loader.dataset_loader import ClassDataSetLoader | |||||
from fastNLP.loader.model_loader import ModelLoader | |||||
from fastNLP.models.cnn_text_classification import CNNText | |||||
from fastNLP.saver.model_saver import ModelSaver | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.loss import Loss | |||||
from fastNLP.core.dataset import TextClassifyDataSet | |||||
from fastNLP.core.preprocess import save_pickle, load_pickle | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files") | |||||
parser.add_argument("-t", "--train", type=str, default="../data_for_tests/text_classify.txt", | |||||
help="path to the training data") | |||||
parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file") | |||||
parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model") | |||||
args = parser.parse_args() | |||||
save_dir = args.save | |||||
train_data_dir = args.train | |||||
model_name = args.model_name | |||||
config_dir = args.config | |||||
def infer(): | |||||
# load dataset | |||||
print("Loading data...") | |||||
word_vocab = load_pickle(save_dir, "word2id.pkl") | |||||
label_vocab = load_pickle(save_dir, "label2id.pkl") | |||||
print("vocabulary size:", len(word_vocab)) | |||||
print("number of classes:", len(label_vocab)) | |||||
infer_data = TextClassifyDataSet(load_func=ClassDataSetLoader.load) | |||||
infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}) | |||||
model_args = ConfigSection() | |||||
model_args["vocab_size"] = len(word_vocab) | |||||
model_args["num_classes"] = len(label_vocab) | |||||
ConfigLoader.load_config(config_dir, {"text_class_model": model_args}) | |||||
# construct model | |||||
print("Building model...") | |||||
cnn = CNNText(model_args) | |||||
# Dump trained parameters into the model | |||||
ModelLoader.load_pytorch(cnn, os.path.join(save_dir, model_name)) | |||||
print("model loaded!") | |||||
infer = ClassificationInfer(pickle_path=save_dir) | |||||
results = infer.predict(cnn, infer_data) | |||||
print(results) | |||||
def train(): | |||||
train_args, model_args = ConfigSection(), ConfigSection() | |||||
ConfigLoader.load_config(config_dir, {"text_class": train_args}) | |||||
# load dataset | |||||
print("Loading data...") | |||||
data = TextClassifyDataSet(load_func=ClassDataSetLoader.load) | |||||
data.load(train_data_dir) | |||||
print("vocabulary size:", len(data.word_vocab)) | |||||
print("number of classes:", len(data.label_vocab)) | |||||
save_pickle(data.word_vocab, save_dir, "word2id.pkl") | |||||
save_pickle(data.label_vocab, save_dir, "label2id.pkl") | |||||
model_args["num_classes"] = len(data.label_vocab) | |||||
model_args["vocab_size"] = len(data.word_vocab) | |||||
# construct model | |||||
print("Building model...") | |||||
model = CNNText(model_args) | |||||
# train | |||||
print("Training...") | |||||
trainer = ClassificationTrainer(epochs=train_args["epochs"], | |||||
batch_size=train_args["batch_size"], | |||||
validate=train_args["validate"], | |||||
use_cuda=train_args["use_cuda"], | |||||
pickle_path=save_dir, | |||||
save_best_dev=train_args["save_best_dev"], | |||||
model_name=model_name, | |||||
loss=Loss("cross_entropy"), | |||||
optimizer=Optimizer("SGD", lr=0.001, momentum=0.9)) | |||||
trainer.train(model, data) | |||||
print("Training finished!") | |||||
saver = ModelSaver(os.path.join(save_dir, model_name)) | |||||
saver.save_pytorch(model) | |||||
print("Model saved!") | |||||
if __name__ == "__main__": | |||||
train() | |||||
infer() |
@@ -14,7 +14,7 @@ class TestGroupNorm(unittest.TestCase): | |||||
class TestLayerNormalization(unittest.TestCase): | class TestLayerNormalization(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
ln = LayerNormalization(d_hid=5, eps=2e-3) | |||||
ln = LayerNormalization(layer_size=5, eps=2e-3) | |||||
x = torch.randn((20, 50, 5)) | x = torch.randn((20, 50, 5)) | ||||
y = ln(x) | y = ln(x) | ||||
@@ -1,213 +0,0 @@ | |||||
# encoding: utf-8 | |||||
import os | |||||
from fastNLP.core.preprocess import save_pickle | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.fastnlp import FastNLP | |||||
from fastNLP.fastnlp import interpret_word_seg_results, interpret_cws_pos_results | |||||
from fastNLP.models.cnn_text_classification import CNNText | |||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | |||||
from fastNLP.saver.model_saver import ModelSaver | |||||
PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/fastNLP/reproduction/chinese_word_segment/save/" | |||||
PATH_TO_POS_TAG_PICKLE_FILES = "/home/zyfeng/data/crf_seg/" | |||||
PATH_TO_TEXT_CLASSIFICATION_PICKLE_FILES = "/home/zyfeng/data/text_classify/" | |||||
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||||
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||||
DEFAULT_RESERVED_LABEL = ['<reserved-2>', | |||||
'<reserved-3>', | |||||
'<reserved-4>'] # dict index = 2~4 | |||||
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, | |||||
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, | |||||
DEFAULT_RESERVED_LABEL[2]: 4} | |||||
def word_seg(model_dir, config, section): | |||||
nlp = FastNLP(model_dir=model_dir) | |||||
nlp.load("cws_basic_model", config_file=config, section_name=section) | |||||
text = ["这是最好的基于深度学习的中文分词系统。", | |||||
"大王叫我来巡山。", | |||||
"我党多年来致力于改善人民生活水平。"] | |||||
results = nlp.run(text) | |||||
print(results) | |||||
for example in results: | |||||
words, labels = [], [] | |||||
for res in example: | |||||
words.append(res[0]) | |||||
labels.append(res[1]) | |||||
print(interpret_word_seg_results(words, labels)) | |||||
def mock_cws(): | |||||
os.makedirs("mock", exist_ok=True) | |||||
text = ["这是最好的基于深度学习的中文分词系统。", | |||||
"大王叫我来巡山。", | |||||
"我党多年来致力于改善人民生活水平。"] | |||||
word2id = Vocabulary() | |||||
word_list = [ch for ch in "".join(text)] | |||||
word2id.update(word_list) | |||||
save_pickle(word2id, "./mock/", "word2id.pkl") | |||||
class2id = Vocabulary(need_default=False) | |||||
label_list = ['B', 'M', 'E', 'S'] | |||||
class2id.update(label_list) | |||||
save_pickle(class2id, "./mock/", "label2id.pkl") | |||||
model_args = {"vocab_size": len(word2id), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(class2id)} | |||||
config_file = """ | |||||
[test_section] | |||||
vocab_size = {} | |||||
word_emb_dim = 50 | |||||
rnn_hidden_units = 50 | |||||
num_classes = {} | |||||
""".format(len(word2id), len(class2id)) | |||||
with open("mock/test.cfg", "w", encoding="utf-8") as f: | |||||
f.write(config_file) | |||||
model = AdvSeqLabel(model_args) | |||||
ModelSaver("mock/cws_basic_model_v_0.pkl").save_pytorch(model) | |||||
def test_word_seg(): | |||||
# fake the model and pickles | |||||
print("start mocking") | |||||
mock_cws() | |||||
# run the inference codes | |||||
print("start testing") | |||||
word_seg("./mock/", "test.cfg", "test_section") | |||||
# clean up environments | |||||
print("clean up") | |||||
os.system("rm -rf mock") | |||||
def pos_tag(model_dir, config, section): | |||||
nlp = FastNLP(model_dir=model_dir) | |||||
nlp.load("pos_tag_model", config_file=config, section_name=section) | |||||
text = ["这是最好的基于深度学习的中文分词系统。", | |||||
"大王叫我来巡山。", | |||||
"我党多年来致力于改善人民生活水平。"] | |||||
results = nlp.run(text) | |||||
for example in results: | |||||
words, labels = [], [] | |||||
for res in example: | |||||
words.append(res[0]) | |||||
labels.append(res[1]) | |||||
try: | |||||
print(interpret_cws_pos_results(words, labels)) | |||||
except RuntimeError: | |||||
print("inconsistent pos tags. this is for test only.") | |||||
def mock_pos_tag(): | |||||
os.makedirs("mock", exist_ok=True) | |||||
text = ["这是最好的基于深度学习的中文分词系统。", | |||||
"大王叫我来巡山。", | |||||
"我党多年来致力于改善人民生活水平。"] | |||||
vocab = Vocabulary() | |||||
word_list = [ch for ch in "".join(text)] | |||||
vocab.update(word_list) | |||||
save_pickle(vocab, "./mock/", "word2id.pkl") | |||||
idx2label = Vocabulary(need_default=False) | |||||
label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv'] | |||||
idx2label.update(label_list) | |||||
save_pickle(idx2label, "./mock/", "label2id.pkl") | |||||
model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)} | |||||
config_file = """ | |||||
[test_section] | |||||
vocab_size = {} | |||||
word_emb_dim = 50 | |||||
rnn_hidden_units = 50 | |||||
num_classes = {} | |||||
""".format(len(vocab), len(idx2label)) | |||||
with open("mock/test.cfg", "w", encoding="utf-8") as f: | |||||
f.write(config_file) | |||||
model = AdvSeqLabel(model_args) | |||||
ModelSaver("mock/pos_tag_model_v_0.pkl").save_pytorch(model) | |||||
def test_pos_tag(): | |||||
mock_pos_tag() | |||||
pos_tag("./mock/", "test.cfg", "test_section") | |||||
os.system("rm -rf mock") | |||||
def text_classify(model_dir, config, section): | |||||
nlp = FastNLP(model_dir=model_dir) | |||||
nlp.load("text_classify_model", config_file=config, section_name=section) | |||||
text = [ | |||||
"世界物联网大会明日在京召开龙头股启动在即", | |||||
"乌鲁木齐市新增一处城市中心旅游目的地", | |||||
"朱元璋的大明朝真的源于明教吗?——告诉你一个真实的“明教”"] | |||||
results = nlp.run(text) | |||||
print(results) | |||||
def mock_text_classify(): | |||||
os.makedirs("mock", exist_ok=True) | |||||
text = ["世界物联网大会明日在京召开龙头股启动在即", | |||||
"乌鲁木齐市新增一处城市中心旅游目的地", | |||||
"朱元璋的大明朝真的源于明教吗?——告诉你一个真实的“明教”" | |||||
] | |||||
vocab = Vocabulary() | |||||
word_list = [ch for ch in "".join(text)] | |||||
vocab.update(word_list) | |||||
save_pickle(vocab, "./mock/", "word2id.pkl") | |||||
idx2label = Vocabulary(need_default=False) | |||||
label_list = ['class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F'] | |||||
idx2label.update(label_list) | |||||
save_pickle(idx2label, "./mock/", "label2id.pkl") | |||||
model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)} | |||||
config_file = """ | |||||
[test_section] | |||||
vocab_size = {} | |||||
word_emb_dim = 50 | |||||
rnn_hidden_units = 50 | |||||
num_classes = {} | |||||
""".format(len(vocab), len(idx2label)) | |||||
with open("mock/test.cfg", "w", encoding="utf-8") as f: | |||||
f.write(config_file) | |||||
model = CNNText(model_args) | |||||
ModelSaver("mock/text_class_model_v0.pkl").save_pytorch(model) | |||||
def test_text_classify(): | |||||
mock_text_classify() | |||||
text_classify("./mock/", "test.cfg", "test_section") | |||||
os.system("rm -rf mock") | |||||
def test_word_seg_interpret(): | |||||
foo = [[('这', 'S'), ('是', 'S'), ('最', 'S'), ('好', 'S'), ('的', 'S'), ('基', 'B'), ('于', 'E'), ('深', 'B'), ('度', 'E'), | |||||
('学', 'B'), ('习', 'E'), ('的', 'S'), ('中', 'B'), ('文', 'E'), ('分', 'B'), ('词', 'E'), ('系', 'B'), ('统', 'E'), | |||||
('。', 'S')]] | |||||
chars = [x[0] for x in foo[0]] | |||||
labels = [x[1] for x in foo[0]] | |||||
print(interpret_word_seg_results(chars, labels)) | |||||
def test_interpret_cws_pos_results(): | |||||
foo = [ | |||||
[('这', 'S-r'), ('是', 'S-v'), ('最', 'S-d'), ('好', 'S-a'), ('的', 'S-u'), ('基', 'B-p'), ('于', 'E-p'), ('深', 'B-d'), | |||||
('度', 'E-d'), ('学', 'B-v'), ('习', 'E-v'), ('的', 'S-u'), ('中', 'B-nz'), ('文', 'E-nz'), ('分', 'B-vn'), | |||||
('词', 'E-vn'), ('系', 'B-n'), ('统', 'E-n'), ('。', 'S-w')] | |||||
] | |||||
chars = [x[0] for x in foo[0]] | |||||
labels = [x[1] for x in foo[0]] | |||||
print(interpret_cws_pos_results(chars, labels)) | |||||
if __name__ == "__main__": | |||||
test_word_seg() | |||||
test_pos_tag() | |||||
test_text_classify() | |||||
test_word_seg_interpret() | |||||
test_interpret_cws_pos_results() |