@@ -0,0 +1,17 @@
Description: briefly describe what this PR does.
Main reason: why this change is made.
Checklist: confirm that each item below is done.
Please feel free to remove inapplicable items for your PR.
- [ ] The PR title starts with [$CATEGORY] (such as [Models], [Modules], [Core], [io], [Doc], each corresponding to a sub-module)
- [ ] Changes are complete (i.e. I finished coding on this PR)
- [ ] All changes have test coverage. For changes to reusable parts such as core/ and modules/, test code must be provided; for other parts it is recommended.
- [ ] Code is well-documented (the documentation is extracted automatically from the docstrings)
- [ ] To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change (if examples need fixing, contact a core developer)
Changes: describe the changes item by item
- Switch to sparse_coo_matrix for torch v1.0. #282
- Fix a bug where an nx graph is not properly converted to a dgl graph. #286
@@ -6,17 +6,43 @@
![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)
fastNLP is a modular Natural Language Processing system based on PyTorch, for fast development of NLP tools. It divides the NLP model based on deep learning into different modules. These modules fall into 4 categories: encoder, interaction, aggregation and decoder, while each category contains different implemented modules. Encoder modules encode the input into some abstract representation, interaction modules make the information in the representation interact with each other, aggregation modules aggregate and reduce information, and decoder modules decode the representation into the output. Most current NLP models could be built on these modules, which vastly simplifies the process of developing NLP models. The architecture of fastNLP is as the figure below:
FastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP models.
![](https://github.com/fastnlp/fastNLP/raw/master/docs/source/figures/procedures.PNG)
![](https://github.com/fastnlp/fastNLP/raw/master/docs/source/figures/text_classification.png)
A deep learning NLP model is the composition of three types of modules:
<table>
<tr>
<td><b> module type </b></td>
<td><b> functionality </b></td>
<td><b> example </b></td>
</tr>
<tr>
<td> encoder </td>
<td> encode the input into some abstract representation </td>
<td> embedding, RNN, CNN, transformer </td>
</tr>
<tr>
<td> aggregator </td>
<td> aggregate and reduce information </td>
<td> self-attention, max-pooling </td>
</tr>
<tr>
<td> decoder </td>
<td> decode the representation into the output </td>
<td> MLP, CRF </td>
</tr>
</table>
For example:
![](docs/source/figures/text_classification.png)
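A minimal sketch of composing such modules into a text classifier is shown below. It follows the text-classification example elsewhere in this PR; the exact module signatures (e.g. `encoder.Embedding(nums=..., dims=...)`) come from that older example and may differ in the current API.

```python
from fastNLP.models.base_model import BaseModel
from fastNLP.modules import aggregator, decoder, encoder


class TextClassifier(BaseModel):
    """CNN text classifier assembled from fastNLP modules (illustrative sketch)."""

    def __init__(self, num_classes, vocab_size):
        super(TextClassifier, self).__init__()
        self.emb = encoder.Embedding(nums=vocab_size, dims=300)                    # encoder: token ids -> embeddings
        self.enc = encoder.Conv(in_channels=300, out_channels=100, kernel_size=3)  # encoder: contextualize the sequence
        self.agg = aggregator.MaxPool()                                            # aggregator: reduce over the sequence
        self.dec = decoder.MLP(size_layer=[100, num_classes])                      # decoder: representation -> class scores

    def forward(self, x):  # x: [batch, seq_len] of token indices
        return self.dec(self.agg(self.enc(self.emb(x))))
```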
## Requirements
- Python>=3.6
- numpy>=1.14.2
- torch==0.4.0
- torchvision>=0.1.8
- torch>=0.4.0
- tensorboardX
- tqdm>=4.28.1
## Resources
@@ -39,12 +65,12 @@ pip install fastNLP
<td> an open-source NLP library </td>
</tr>
<tr>
<td><b> fastNLP.core </b></td>
<td> trainer, tester, predictor </td>
<td><b> fastNLP.api </b></td>
<td> APIs for end-to-end prediction </td>
</tr>
<tr>
<td><b> fastNLP.loader </b></td>
<td> all kinds of loaders/readers </td>
<td><b> fastNLP.core </b></td>
<td> data representation & train/test procedure </td>
</tr>
<tr>
<td><b> fastNLP.models </b></td>
@@ -55,11 +81,7 @@ pip install fastNLP
<td> a collection of PyTorch sub-models/components/wheels </td>
</tr>
<tr>
<td><b> fastNLP.saver </b></td>
<td> all kinds of savers/writers </td>
</tr>
<tr>
<td><b> fastNLP.fastnlp </b></td>
<td> a high-level interface for prediction </td>
<td><b> fastNLP.io </b></td>
<td> readers & savers </td>
</tr>
</table>
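As a quick end-to-end example of the new `fastNLP.api` package (a sketch based on the `api.py` module in this PR; its hosted `model_urls` are still a TODO, so a hypothetical locally saved model path is used instead):

```python
from fastNLP.api.api import CWS, POS  # module path assumed from this PR's layout

# Pre-trained download URLs are not filled in yet, so point at locally saved pipelines (hypothetical paths).
cws = CWS(model_path='/path/to/cws_model.pkl', device='cpu')
pos = POS(model_path='/path/to/pos_model.pkl', device='cpu')

sentences = ['那么这款无人机到底有多厉害?']
print(cws.predict(sentences))  # word segmentation, one result per sentence
print(pos.predict(sentences))  # part-of-speech tags, one sequence per sentence
```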
@@ -1 +1,2 @@ | |||||
# FastNLP Quick Tutorial | |||||
# FastNLP Quick Tutorial | |||||
@@ -1,75 +0,0 @@ | |||||
from fastNLP.core.loss import Loss | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.predictor import ClassificationInfer | |||||
from fastNLP.core.preprocess import ClassPreprocess | |||||
from fastNLP.core.trainer import ClassificationTrainer | |||||
from fastNLP.loader.dataset_loader import ClassDataSetLoader | |||||
from fastNLP.models.base_model import BaseModel | |||||
from fastNLP.modules import aggregator | |||||
from fastNLP.modules import decoder | |||||
from fastNLP.modules import encoder | |||||
class ClassificationModel(BaseModel): | |||||
""" | |||||
Simple text classification model based on CNN. | |||||
""" | |||||
def __init__(self, num_classes, vocab_size): | |||||
super(ClassificationModel, self).__init__() | |||||
self.emb = encoder.Embedding(nums=vocab_size, dims=300) | |||||
self.enc = encoder.Conv( | |||||
in_channels=300, out_channels=100, kernel_size=3) | |||||
self.agg = aggregator.MaxPool() | |||||
self.dec = decoder.MLP(size_layer=[100, num_classes]) | |||||
def forward(self, x): | |||||
x = self.emb(x) # [N,L] -> [N,L,C] | |||||
x = self.enc(x) # [N,L,C_in] -> [N,L,C_out] | |||||
x = self.agg(x) # [N,L,C] -> [N,C] | |||||
x = self.dec(x) # [N,C] -> [N, N_class] | |||||
return x | |||||
data_dir = 'save/' # directory to save data and model | |||||
train_path = './data_for_tests/text_classify.txt' # training set file | |||||
# load dataset | |||||
ds_loader = ClassDataSetLoader() | |||||
data = ds_loader.load() | |||||
# pre-process dataset | |||||
pre = ClassPreprocess() | |||||
train_set, dev_set = pre.run(data, train_dev_split=0.3, pickle_path=data_dir) | |||||
n_classes, vocab_size = pre.num_classes, pre.vocab_size | |||||
# construct model | |||||
model_args = { | |||||
'num_classes': n_classes, | |||||
'vocab_size': vocab_size | |||||
} | |||||
model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size) | |||||
# construct trainer | |||||
train_args = { | |||||
"epochs": 3, | |||||
"batch_size": 16, | |||||
"pickle_path": data_dir, | |||||
"validate": False, | |||||
"save_best_dev": False, | |||||
"model_saved_path": None, | |||||
"use_cuda": True, | |||||
"loss": Loss("cross_entropy"), | |||||
"optimizer": Optimizer("Adam", lr=0.001) | |||||
} | |||||
trainer = ClassificationTrainer(**train_args) | |||||
# start training | |||||
trainer.train(model, train_data=train_set, dev_data=dev_set) | |||||
# predict using model | |||||
data_infer = [x[0] for x in data] | |||||
infer = ClassificationInfer(data_dir) | |||||
labels_pred = infer.predict(model.cpu(), data_infer) | |||||
print(labels_pred) |
@@ -0,0 +1,3 @@ | |||||
from .core import * | |||||
from . import models | |||||
from . import modules |
@@ -0,0 +1,314 @@ | |||||
import warnings | |||||
import torch | |||||
warnings.filterwarnings('ignore') | |||||
import os | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.api.model_zoo import load_url | |||||
from fastNLP.api.processor import ModelProcessor | |||||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader | |||||
from reproduction.pos_tag_model.pos_io.pos_reader import ConlluPOSReader | |||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.batch import Batch | |||||
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 | |||||
from fastNLP.api.pipeline import Pipeline | |||||
from fastNLP.core.metrics import SeqLabelEvaluator2 | |||||
from fastNLP.core.tester import Tester | |||||
# TODO add pretrain urls | |||||
model_urls = { | |||||
} | |||||
class API: | |||||
def __init__(self): | |||||
self.pipeline = None | |||||
def predict(self, *args, **kwargs): | |||||
raise NotImplementedError | |||||
def load(self, path, device): | |||||
if os.path.exists(os.path.expanduser(path)): | |||||
_dict = torch.load(path, map_location='cpu') | |||||
else: | |||||
_dict = load_url(path, map_location='cpu') | |||||
self.pipeline = _dict['pipeline'] | |||||
self._dict = _dict | |||||
for processor in self.pipeline.pipeline: | |||||
if isinstance(processor, ModelProcessor): | |||||
processor.set_model_device(device) | |||||
class POS(API): | |||||
"""FastNLP API for Part-Of-Speech tagging. | |||||
""" | |||||
def __init__(self, model_path=None, device='cpu'): | |||||
super(POS, self).__init__() | |||||
if model_path is None: | |||||
model_path = model_urls['pos'] | |||||
self.load(model_path, device) | |||||
def predict(self, content): | |||||
""" | |||||
:param content: list of list of str. Each string is a token(word). | |||||
:return answer: list of list of str. Each string is a tag. | |||||
""" | |||||
if not hasattr(self, 'pipeline'): | |||||
raise ValueError("You have to load model first.") | |||||
sentence_list = [] | |||||
# 1. Check the type of the input content
if isinstance(content, str): | |||||
sentence_list.append(content) | |||||
elif isinstance(content, list): | |||||
sentence_list = content | |||||
# 2. Build a DataSet
dataset = DataSet() | |||||
dataset.add_field('words', sentence_list) | |||||
# 3. Run the pipeline
self.pipeline(dataset) | |||||
output = dataset['word_pos_output'].content | |||||
if isinstance(content, str): | |||||
return output[0] | |||||
elif isinstance(content, list): | |||||
return output | |||||
def test(self, filepath): | |||||
tag_proc = self._dict['tag_indexer'] | |||||
model = self.pipeline.pipeline[2].model | |||||
pipeline = self.pipeline.pipeline[0:2] | |||||
pipeline.append(tag_proc) | |||||
pp = Pipeline(pipeline) | |||||
reader = ConlluPOSReader() | |||||
te_dataset = reader.load(filepath) | |||||
evaluator = SeqLabelEvaluator2('word_seq_origin_len') | |||||
end_tagidx_set = set() | |||||
tag_proc.vocab.build_vocab() | |||||
for key, value in tag_proc.vocab.word2idx.items(): | |||||
if key.startswith('E-'): | |||||
end_tagidx_set.add(value) | |||||
if key.startswith('S-'): | |||||
end_tagidx_set.add(value) | |||||
evaluator.end_tagidx_set = end_tagidx_set | |||||
default_valid_args = {"batch_size": 64, | |||||
"use_cuda": True, "evaluator": evaluator} | |||||
pp(te_dataset) | |||||
te_dataset.set_target(truth=True) | |||||
tester = Tester(**default_valid_args) | |||||
test_result = tester.test(model, te_dataset) | |||||
f1 = round(test_result['F'] * 100, 2) | |||||
pre = round(test_result['P'] * 100, 2) | |||||
rec = round(test_result['R'] * 100, 2) | |||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | |||||
return f1, pre, rec | |||||
class CWS(API): | |||||
def __init__(self, model_path=None, device='cpu'): | |||||
super(CWS, self).__init__() | |||||
if model_path is None: | |||||
model_path = model_urls['cws'] | |||||
self.load(model_path, device) | |||||
def predict(self, content): | |||||
if not hasattr(self, 'pipeline'): | |||||
raise ValueError("You have to load model first.") | |||||
sentence_list = [] | |||||
# 1. Check the type of the input content
if isinstance(content, str): | |||||
sentence_list.append(content) | |||||
elif isinstance(content, list): | |||||
sentence_list = content | |||||
# 2. Build a DataSet
dataset = DataSet() | |||||
dataset.add_field('raw_sentence', sentence_list) | |||||
# 3. Run the pipeline
self.pipeline(dataset) | |||||
output = dataset['output'].content | |||||
if isinstance(content, str): | |||||
return output[0] | |||||
elif isinstance(content, list): | |||||
return output | |||||
def test(self, filepath): | |||||
tag_proc = self._dict['tag_indexer'] | |||||
cws_model = self.pipeline.pipeline[-2].model | |||||
pipeline = self.pipeline.pipeline[:5] | |||||
pipeline.insert(1, tag_proc) | |||||
pp = Pipeline(pipeline) | |||||
reader = ConlluCWSReader() | |||||
# te_filename = '/home/hyan/ctb3/test.conllx' | |||||
te_dataset = reader.load(filepath) | |||||
pp(te_dataset) | |||||
batch_size = 64 | |||||
te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False) | |||||
pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher, type='bmes') | |||||
f1 = round(f1 * 100, 2) | |||||
pre = round(pre * 100, 2) | |||||
rec = round(rec * 100, 2) | |||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | |||||
return f1, pre, rec | |||||
class Parser(API): | |||||
def __init__(self, model_path=None, device='cpu'): | |||||
super(Parser, self).__init__() | |||||
if model_path is None: | |||||
model_path = model_urls['parser'] | |||||
self.load(model_path, device) | |||||
def predict(self, content): | |||||
if not hasattr(self, 'pipeline'): | |||||
raise ValueError("You have to load model first.") | |||||
sentence_list = [] | |||||
# 1. Check the type of the input content
if isinstance(content, str): | |||||
sentence_list.append(content) | |||||
elif isinstance(content, list): | |||||
sentence_list = content | |||||
# 2. Build a DataSet
dataset = DataSet() | |||||
dataset.add_field('words', sentence_list) | |||||
# dataset.add_field('tag', sentence_list) | |||||
# 3. Run the pipeline
self.pipeline(dataset) | |||||
for ins in dataset: | |||||
ins['heads'] = ins['heads'].tolist() | |||||
return dataset['heads'], dataset['labels'] | |||||
def test(self, filepath): | |||||
data = ConllxDataLoader().load(filepath) | |||||
ds = DataSet() | |||||
for ins1, ins2 in zip(add_seg_tag(data), data): | |||||
ds.append(Instance(words=ins1[0], tag=ins1[1], | |||||
gold_words=ins2[0], gold_pos=ins2[1], | |||||
gold_heads=ins2[2], gold_head_tags=ins2[3])) | |||||
pp = self.pipeline | |||||
for p in pp: | |||||
if p.field_name == 'word_list': | |||||
p.field_name = 'gold_words' | |||||
elif p.field_name == 'pos_list': | |||||
p.field_name = 'gold_pos' | |||||
pp(ds) | |||||
head_cor, label_cor, total = 0, 0, 0 | |||||
for ins in ds: | |||||
head_gold = ins['gold_heads'] | |||||
head_pred = ins['heads'] | |||||
length = len(head_gold) | |||||
total += length | |||||
for i in range(length): | |||||
head_cor += 1 if head_pred[i] == head_gold[i] else 0 | |||||
uas = head_cor / total | |||||
print('uas:{:.2f}'.format(uas)) | |||||
for p in pp: | |||||
if p.field_name == 'gold_words': | |||||
p.field_name = 'word_list' | |||||
elif p.field_name == 'gold_pos': | |||||
p.field_name = 'pos_list' | |||||
return uas | |||||
class Analyzer: | |||||
def __init__(self, device='cpu'): | |||||
self.cws = CWS(device=device) | |||||
self.pos = POS(device=device) | |||||
self.parser = Parser(device=device) | |||||
def predict(self, content, seg=False, pos=False, parser=False): | |||||
if seg is False and pos is False and parser is False: | |||||
seg = True | |||||
output_dict = {} | |||||
if seg: | |||||
seg_output = self.cws.predict(content) | |||||
output_dict['seg'] = seg_output | |||||
if pos: | |||||
pos_output = self.pos.predict(content) | |||||
output_dict['pos'] = pos_output | |||||
if parser: | |||||
parser_output = self.parser.predict(content) | |||||
output_dict['parser'] = parser_output | |||||
return output_dict | |||||
def test(self, filepath): | |||||
output_dict = {} | |||||
if self.cws:  # Analyzer stores the segmenter as self.cws; self.seg does not exist
seg_output = self.cws.test(filepath) | |||||
output_dict['seg'] = seg_output | |||||
if self.pos: | |||||
pos_output = self.pos.test(filepath) | |||||
output_dict['pos'] = pos_output | |||||
if self.parser: | |||||
parser_output = self.parser.test(filepath) | |||||
output_dict['parser'] = parser_output | |||||
return output_dict | |||||
if __name__ == "__main__": | |||||
# pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl' | |||||
# pos = POS(device='cpu') | |||||
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' , | |||||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
# '那么这款无人机到底有多厉害?'] | |||||
# print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll')) | |||||
# print(pos.predict(s)) | |||||
# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' | |||||
# cws = CWS(device='cpu') | |||||
# s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' , | |||||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
# '那么这款无人机到底有多厉害?'] | |||||
# print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll')) | |||||
# print(cws.predict(s)) | |||||
parser = Parser(device='cpu') | |||||
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) | |||||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
print(parser.predict(s)) |
@@ -0,0 +1,181 @@ | |||||
import re | |||||
class SpanConverter: | |||||
def __init__(self, replace_tag, pattern): | |||||
super(SpanConverter, self).__init__() | |||||
self.replace_tag = replace_tag | |||||
self.pattern = pattern | |||||
def find_certain_span_and_replace(self, sentence): | |||||
replaced_sentence = '' | |||||
prev_end = 0 | |||||
for match in re.finditer(self.pattern, sentence): | |||||
start, end = match.span() | |||||
span = sentence[start:end] | |||||
replaced_sentence += sentence[prev_end:start] + self.span_to_special_tag(span) | |||||
prev_end = end | |||||
replaced_sentence += sentence[prev_end:] | |||||
return replaced_sentence | |||||
def span_to_special_tag(self, span): | |||||
return self.replace_tag | |||||
def find_certain_span(self, sentence): | |||||
spans = [] | |||||
for match in re.finditer(self.pattern, sentence): | |||||
spans.append(match.span()) | |||||
return spans | |||||
class AlphaSpanConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<ALPHA>' | |||||
# Ideally this should only match spans that are purely letters, but it must not match <[a-zA-Z]+> (those should be special tags).
pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%.!<\\-"])' | |||||
super(AlphaSpanConverter, self).__init__(replace_tag, pattern) | |||||
class DigitSpanConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<NUM>' | |||||
pattern = '\d[\d\\.]*(?=[\u4e00-\u9fff ,%.!<-])' | |||||
super(DigitSpanConverter, self).__init__(replace_tag, pattern) | |||||
def span_to_special_tag(self, span): | |||||
# return self.special_tag | |||||
if span[0] == '0' and len(span) > 2: | |||||
return '<NUM>' | |||||
decimal_point_count = 0  # a span may contain more than one decimal point
for idx, char in enumerate(span): | |||||
if char == '.' or char == '﹒' or char == '·': | |||||
decimal_point_count += 1 | |||||
if span[-1] == '.' or span[-1] == '﹒' or span[-1] == '·': | |||||
# last digit being decimal point means this is not a number | |||||
if decimal_point_count == 1: | |||||
return span | |||||
else: | |||||
return '<UNKDGT>' | |||||
if decimal_point_count == 1: | |||||
return '<DEC>' | |||||
elif decimal_point_count > 1: | |||||
return '<UNKDGT>' | |||||
else: | |||||
return '<NUM>' | |||||
class TimeConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<TOC>' | |||||
pattern = '\d+[::∶][\d::∶]+(?=[\u4e00-\u9fff ,%.!<-])' | |||||
super().__init__(replace_tag, pattern) | |||||
class MixNumAlphaConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<MIX>' | |||||
pattern = None | |||||
super().__init__(replace_tag, pattern) | |||||
def find_certain_span_and_replace(self, sentence): | |||||
replaced_sentence = '' | |||||
start = 0 | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
for idx in range(len(sentence)): | |||||
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): | |||||
if not matching_flag: | |||||
replaced_sentence += sentence[start:idx] | |||||
start = idx | |||||
if re.match('[0-9]', sentence[idx]): | |||||
number_flag = True | |||||
elif re.match('[\'′&\\-]', sentence[idx]): | |||||
link_flag = True | |||||
elif re.match('/', sentence[idx]): | |||||
slash_flag = True | |||||
elif re.match('[\\(\\)]', sentence[idx]): | |||||
bracket_flag = True | |||||
else: | |||||
alpha_flag = True | |||||
matching_flag = True | |||||
elif re.match('[\\.]', sentence[idx]): | |||||
pass | |||||
else: | |||||
if matching_flag: | |||||
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ | |||||
or (slash_flag and alpha_flag) or (link_flag and number_flag) \ | |||||
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): | |||||
span = sentence[start:idx] | |||||
start = idx | |||||
replaced_sentence += self.span_to_special_tag(span) | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
replaced_sentence += sentence[start:] | |||||
return replaced_sentence | |||||
def find_certain_span(self, sentence): | |||||
spans = [] | |||||
start = 0 | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
for idx in range(len(sentence)): | |||||
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): | |||||
if not matching_flag: | |||||
start = idx | |||||
if re.match('[0-9]', sentence[idx]): | |||||
number_flag = True | |||||
elif re.match('[\'′&\\-]', sentence[idx]): | |||||
link_flag = True | |||||
elif re.match('/', sentence[idx]): | |||||
slash_flag = True | |||||
elif re.match('[\\(\\)]', sentence[idx]): | |||||
bracket_flag = True | |||||
else: | |||||
alpha_flag = True | |||||
matching_flag = True | |||||
elif re.match('[\\.]', sentence[idx]): | |||||
pass | |||||
else: | |||||
if matching_flag: | |||||
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ | |||||
or (slash_flag and alpha_flag) or (link_flag and number_flag) \ | |||||
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): | |||||
spans.append((start, idx)) | |||||
start = idx | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
return spans | |||||
class EmailConverter(SpanConverter): | |||||
def __init__(self): | |||||
replaced_tag = "<EML>" | |||||
pattern = '[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\\-"$])' | |||||
super(EmailConverter, self).__init__(replaced_tag, pattern) |
@@ -0,0 +1,134 @@ | |||||
import hashlib | |||||
import os | |||||
import re | |||||
import shutil | |||||
import sys | |||||
import tempfile | |||||
import torch | |||||
try: | |||||
from requests.utils import urlparse | |||||
from requests import get as urlopen | |||||
requests_available = True | |||||
except ImportError: | |||||
requests_available = False | |||||
if sys.version_info[0] == 2: | |||||
from urlparse import urlparse # noqa f811 | |||||
from urllib2 import urlopen # noqa f811 | |||||
else: | |||||
from urllib.request import urlopen | |||||
from urllib.parse import urlparse | |||||
try: | |||||
from tqdm import tqdm | |||||
except ImportError: | |||||
tqdm = None # defined below | |||||
# matches bfd8deac from resnet18-bfd8deac.pth | |||||
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') | |||||
def load_url(url, model_dir=None, map_location=None, progress=True): | |||||
r"""Loads the Torch serialized object at the given URL. | |||||
If the object is already present in `model_dir`, it's deserialized and | |||||
returned. The filename part of the URL should follow the naming convention | |||||
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more | |||||
digits of the SHA256 hash of the contents of the file. The hash is used to | |||||
ensure unique names and to verify the contents of the file. | |||||
The default value of `model_dir` is ``$fastNLP_HOME/models`` where
``$fastNLP_HOME`` defaults to ``~/.fastNLP``. The default directory can be
overridden with the ``$fastNLP_MODEL_ZOO`` environment variable.
Args: | |||||
url (string): URL of the object to download | |||||
model_dir (string, optional): directory in which to save the object | |||||
map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) | |||||
progress (bool, optional): whether or not to display a progress bar to stderr | |||||
Example: | |||||
# >>> state_dict = model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') | |||||
""" | |||||
if model_dir is None: | |||||
torch_home = os.path.expanduser(os.getenv('fastNLP_HOME', '~/.fastNLP')) | |||||
model_dir = os.getenv('fastNLP_MODEL_ZOO', os.path.join(torch_home, 'models')) | |||||
if not os.path.exists(model_dir): | |||||
os.makedirs(model_dir) | |||||
parts = urlparse(url) | |||||
filename = os.path.basename(parts.path) | |||||
cached_file = os.path.join(model_dir, filename) | |||||
if not os.path.exists(cached_file): | |||||
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) | |||||
# hash_prefix = HASH_REGEX.search(filename).group(1) | |||||
_download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) | |||||
return torch.load(cached_file, map_location=map_location) | |||||
def _download_url_to_file(url, dst, hash_prefix, progress): | |||||
if requests_available: | |||||
u = urlopen(url, stream=True) | |||||
file_size = int(u.headers["Content-Length"]) | |||||
u = u.raw | |||||
else: | |||||
u = urlopen(url) | |||||
meta = u.info() | |||||
if hasattr(meta, 'getheaders'): | |||||
file_size = int(meta.getheaders("Content-Length")[0]) | |||||
else: | |||||
file_size = int(meta.get_all("Content-Length")[0]) | |||||
f = tempfile.NamedTemporaryFile(delete=False) | |||||
try: | |||||
if hash_prefix is not None: | |||||
sha256 = hashlib.sha256() | |||||
with tqdm(total=file_size, disable=not progress) as pbar: | |||||
while True: | |||||
buffer = u.read(8192) | |||||
if len(buffer) == 0: | |||||
break | |||||
f.write(buffer) | |||||
if hash_prefix is not None: | |||||
sha256.update(buffer) | |||||
pbar.update(len(buffer)) | |||||
f.close() | |||||
if hash_prefix is not None: | |||||
digest = sha256.hexdigest() | |||||
if digest[:len(hash_prefix)] != hash_prefix: | |||||
raise RuntimeError('invalid hash value (expected "{}", got "{}")' | |||||
.format(hash_prefix, digest)) | |||||
shutil.move(f.name, dst) | |||||
finally: | |||||
f.close() | |||||
if os.path.exists(f.name): | |||||
os.remove(f.name) | |||||
if tqdm is None: | |||||
# fake tqdm if it's not installed | |||||
class tqdm(object): | |||||
def __init__(self, total, disable=False): | |||||
self.total = total | |||||
self.disable = disable | |||||
self.n = 0 | |||||
def update(self, n): | |||||
if self.disable: | |||||
return | |||||
self.n += n | |||||
sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total))) | |||||
sys.stderr.flush() | |||||
def __enter__(self): | |||||
return self | |||||
def __exit__(self, exc_type, exc_val, exc_tb): | |||||
if self.disable: | |||||
return | |||||
sys.stderr.write('\n') | |||||
@@ -0,0 +1,33 @@ | |||||
from fastNLP.api.processor import Processor | |||||
class Pipeline: | |||||
""" | |||||
Pipeline takes a DataSet object as input, runs multiple processors sequentially, and | |||||
outputs a DataSet object. | |||||
""" | |||||
def __init__(self, processors=None): | |||||
self.pipeline = [] | |||||
if isinstance(processors, list): | |||||
for proc in processors: | |||||
assert isinstance(proc, Processor), "Must be a Processor, not {}.".format(type(proc)) | |||||
self.pipeline = processors | |||||
def add_processor(self, processor): | |||||
assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor)) | |||||
self.pipeline.append(processor) | |||||
def process(self, dataset): | |||||
assert len(self.pipeline) != 0, "You need to add some processor first." | |||||
for proc in self.pipeline: | |||||
dataset = proc(dataset) | |||||
return dataset | |||||
def __call__(self, *args, **kwargs): | |||||
return self.process(*args, **kwargs) | |||||
def __getitem__(self, item): | |||||
return self.pipeline[item] |
@@ -0,0 +1,288 @@ | |||||
import re | |||||
from collections import defaultdict | |||||
import torch | |||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
class Processor(object): | |||||
def __init__(self, field_name, new_added_field_name): | |||||
self.field_name = field_name | |||||
if new_added_field_name is None: | |||||
self.new_added_field_name = field_name | |||||
else: | |||||
self.new_added_field_name = new_added_field_name | |||||
def process(self, *args, **kwargs): | |||||
raise NotImplementedError | |||||
def __call__(self, *args, **kwargs): | |||||
return self.process(*args, **kwargs) | |||||
class FullSpaceToHalfSpaceProcessor(Processor): | |||||
"""全角转半角,以字符为处理单元 | |||||
""" | |||||
def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True, | |||||
change_space=True): | |||||
super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None) | |||||
self.change_alpha = change_alpha | |||||
self.change_digit = change_digit | |||||
self.change_punctuation = change_punctuation | |||||
self.change_space = change_space | |||||
FH_SPACE = [(u" ", u" ")] | |||||
FH_NUM = [ | |||||
(u"0", u"0"), (u"1", u"1"), (u"2", u"2"), (u"3", u"3"), (u"4", u"4"), | |||||
(u"5", u"5"), (u"6", u"6"), (u"7", u"7"), (u"8", u"8"), (u"9", u"9")] | |||||
FH_ALPHA = [ | |||||
(u"a", u"a"), (u"b", u"b"), (u"c", u"c"), (u"d", u"d"), (u"e", u"e"), | |||||
(u"f", u"f"), (u"g", u"g"), (u"h", u"h"), (u"i", u"i"), (u"j", u"j"), | |||||
(u"k", u"k"), (u"l", u"l"), (u"m", u"m"), (u"n", u"n"), (u"o", u"o"), | |||||
(u"p", u"p"), (u"q", u"q"), (u"r", u"r"), (u"s", u"s"), (u"t", u"t"), | |||||
(u"u", u"u"), (u"v", u"v"), (u"w", u"w"), (u"x", u"x"), (u"y", u"y"), | |||||
(u"z", u"z"), | |||||
(u"A", u"A"), (u"B", u"B"), (u"C", u"C"), (u"D", u"D"), (u"E", u"E"), | |||||
(u"F", u"F"), (u"G", u"G"), (u"H", u"H"), (u"I", u"I"), (u"J", u"J"), | |||||
(u"K", u"K"), (u"L", u"L"), (u"M", u"M"), (u"N", u"N"), (u"O", u"O"), | |||||
(u"P", u"P"), (u"Q", u"Q"), (u"R", u"R"), (u"S", u"S"), (u"T", u"T"), | |||||
(u"U", u"U"), (u"V", u"V"), (u"W", u"W"), (u"X", u"X"), (u"Y", u"Y"), | |||||
(u"Z", u"Z")] | |||||
# Use punctuation conversion with caution: a date-like string such as "5.12特大地震" may be altered by the full-width to half-width conversion
FH_PUNCTUATION = [ | |||||
(u'%', u'%'), (u'!', u'!'), (u'"', u'\"'), (u''', u'\''), (u'#', u'#'), | |||||
(u'¥', u'$'), (u'&', u'&'), (u'(', u'('), (u')', u')'), (u'*', u'*'), | |||||
(u'+', u'+'), (u',', u','), (u'-', u'-'), (u'.', u'.'), (u'/', u'/'), | |||||
(u':', u':'), (u';', u';'), (u'<', u'<'), (u'=', u'='), (u'>', u'>'), | |||||
(u'?', u'?'), (u'@', u'@'), (u'[', u'['), (u']', u']'), (u'\', u'\\'), | |||||
(u'^', u'^'), (u'_', u'_'), (u'`', u'`'), (u'~', u'~'), (u'{', u'{'), | |||||
(u'}', u'}'), (u'|', u'|')] | |||||
FHs = [] | |||||
if self.change_alpha: | |||||
FHs = FH_ALPHA | |||||
if self.change_digit: | |||||
FHs += FH_NUM | |||||
if self.change_punctuation: | |||||
FHs += FH_PUNCTUATION | |||||
if self.change_space: | |||||
FHs += FH_SPACE | |||||
self.convert_map = {k: v for k, v in FHs} | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
sentence = ins[self.field_name] | |||||
new_sentence = [None] * len(sentence) | |||||
for idx, char in enumerate(sentence): | |||||
if char in self.convert_map: | |||||
char = self.convert_map[char] | |||||
new_sentence[idx] = char | |||||
ins[self.field_name] = ''.join(new_sentence) | |||||
return dataset | |||||
class PreAppendProcessor(Processor): | |||||
def __init__(self, data, field_name, new_added_field_name=None): | |||||
super(PreAppendProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.data = data | |||||
def process(self, dataset): | |||||
for ins in dataset: | |||||
sent = ins[self.field_name] | |||||
ins[self.new_added_field_name] = [self.data] + sent | |||||
return dataset | |||||
class SliceProcessor(Processor): | |||||
def __init__(self, start, end, step, field_name, new_added_field_name=None): | |||||
super(SliceProcessor, self).__init__(field_name, new_added_field_name) | |||||
for o in (start, end, step): | |||||
assert isinstance(o, int) or o is None | |||||
self.slice = slice(start, end, step) | |||||
def process(self, dataset): | |||||
for ins in dataset: | |||||
sent = ins[self.field_name] | |||||
ins[self.new_added_field_name] = sent[self.slice] | |||||
return dataset | |||||
class Num2TagProcessor(Processor): | |||||
def __init__(self, tag, field_name, new_added_field_name=None): | |||||
super(Num2TagProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.tag = tag | |||||
self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)' | |||||
def process(self, dataset): | |||||
for ins in dataset: | |||||
s = ins[self.field_name] | |||||
new_s = [None] * len(s) | |||||
for i, w in enumerate(s): | |||||
if re.search(self.pattern, w) is not None: | |||||
w = self.tag | |||||
new_s[i] = w | |||||
ins[self.new_added_field_name] = new_s | |||||
return dataset | |||||
class IndexerProcessor(Processor): | |||||
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): | |||||
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | |||||
super(IndexerProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.vocab = vocab | |||||
self.delete_old_field = delete_old_field | |||||
self.is_input = is_input | |||||
def set_vocab(self, vocab): | |||||
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | |||||
self.vocab = vocab | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
tokens = ins[self.field_name] | |||||
index = [self.vocab.to_index(token) for token in tokens] | |||||
ins[self.new_added_field_name] = index | |||||
if self.is_input: | |||||
dataset.set_input(self.new_added_field_name) | |||||
if self.delete_old_field: | |||||
dataset.delete_field(self.field_name) | |||||
return dataset | |||||
class VocabProcessor(Processor): | |||||
"""Build vocabulary with a field in the data set. | |||||
""" | |||||
def __init__(self, field_name): | |||||
super(VocabProcessor, self).__init__(field_name, None) | |||||
self.vocab = Vocabulary() | |||||
def process(self, *datasets): | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
tokens = ins[self.field_name] | |||||
self.vocab.update(tokens) | |||||
def get_vocab(self): | |||||
self.vocab.build_vocab() | |||||
return self.vocab | |||||
class SeqLenProcessor(Processor): | |||||
def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): | |||||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.is_input = is_input | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
length = len(ins[self.field_name]) | |||||
ins[self.new_added_field_name] = length | |||||
if self.is_input: | |||||
dataset.set_input(self.new_added_field_name) | |||||
return dataset | |||||
class ModelProcessor(Processor): | |||||
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | |||||
""" | |||||
Run the model over the dataset in batches and strip the padding from its outputs.
:param seq_len_field_name: name of the field holding the true sequence lengths
:param batch_size: batch size used when iterating over the dataset
""" | |||||
super(ModelProcessor, self).__init__(None, None) | |||||
self.batch_size = batch_size | |||||
self.seq_len_field_name = seq_len_field_name | |||||
self.model = model | |||||
def process(self, dataset): | |||||
self.model.eval() | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False) | |||||
batch_output = defaultdict(list) | |||||
with torch.no_grad(): | |||||
for batch_x, _ in data_iterator: | |||||
prediction = self.model.predict(**batch_x) | |||||
seq_lens = batch_x[self.seq_len_field_name].cpu().numpy().tolist() | |||||
for key, value in prediction.items(): | |||||
tmp_batch = [] | |||||
value = value.cpu().numpy() | |||||
if len(value.shape) == 1 or (len(value.shape)==2 and value.shape[1]==1): | |||||
batch_output[key].extend(value.tolist()) | |||||
else: | |||||
for idx, seq_len in enumerate(seq_lens): | |||||
tmp_batch.append(value[idx, :seq_len]) | |||||
batch_output[key].extend(tmp_batch) | |||||
batch_output[self.seq_len_field_name].extend(seq_lens) | |||||
# TODO: with the current implementation, downstream processors need to know the keys of the model's output
for field_name, fields in batch_output.items(): | |||||
dataset.add_field(field_name, fields, need_tensor=False, is_target=False) | |||||
return dataset | |||||
def set_model(self, model): | |||||
self.model = model | |||||
def set_model_device(self, device): | |||||
device = torch.device(device) | |||||
self.model.to(device) | |||||
class Index2WordProcessor(Processor): | |||||
def __init__(self, vocab, field_name, new_added_field_name): | |||||
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.vocab = vocab | |||||
def process(self, dataset): | |||||
for ins in dataset: | |||||
new_sent = [self.vocab.to_word(w) for w in ins[self.field_name]] | |||||
ins[self.new_added_field_name] = new_sent | |||||
return dataset | |||||
class SetTensorProcessor(Processor): | |||||
# TODO: remove it. It is strange. | |||||
def __init__(self, field_dict, default=False): | |||||
super(SetTensorProcessor, self).__init__(None, None) | |||||
self.field_dict = field_dict | |||||
self.default = default | |||||
def process(self, dataset): | |||||
set_dict = {name: self.default for name in dataset.get_all_fields().keys()} | |||||
set_dict.update(self.field_dict) | |||||
dataset._set_need_tensor(**set_dict) | |||||
return dataset | |||||
class SetIsTargetProcessor(Processor): | |||||
# TODO; remove it. | |||||
def __init__(self, field_dict, default=False): | |||||
super(SetIsTargetProcessor, self).__init__(None, None) | |||||
self.field_dict = field_dict | |||||
self.default = default | |||||
def process(self, dataset): | |||||
set_dict = {name: self.default for name in dataset.get_all_fields().keys()} | |||||
set_dict.update(self.field_dict) | |||||
dataset.set_target(**set_dict) | |||||
return dataset |
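Taken together, the `Processor` subclasses above and the `Pipeline` class added earlier in this PR chain per-field transformations over a `DataSet`. A minimal sketch (assuming `train_data` is an existing `DataSet` with a tokenized `words` field; the processor names and signatures are the ones defined in this file):

```python
from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import IndexerProcessor, SeqLenProcessor, VocabProcessor

# Build a vocabulary over the 'words' field of an existing DataSet (train_data is assumed to exist).
vocab_proc = VocabProcessor('words')
vocab_proc.process(train_data)
vocab = vocab_proc.get_vocab()

# Index the tokens and record the original sequence lengths.
pp = Pipeline([
    IndexerProcessor(vocab, 'words', 'word_seq'),
    SeqLenProcessor('word_seq', 'word_seq_origin_len'),
])
train_data = pp(train_data)  # __call__ runs every processor in order and returns the DataSet
```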
@@ -0,0 +1,13 @@ | |||||
from .batch import Batch | |||||
# from .dataset import DataSet | |||||
from .fieldarray import FieldArray | |||||
from .instance import Instance | |||||
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | |||||
from .metrics import AccuracyMetric | |||||
from .optimizer import Optimizer, SGD, Adam | |||||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler | |||||
from .tester import Tester | |||||
from .trainer import Trainer | |||||
from .vocabulary import Vocabulary | |||||
from ..io.dataset_loader import DataSet | |||||
@@ -1,5 +1,4 @@ | |||||
from collections import defaultdict | |||||
import numpy as np | |||||
import torch | import torch | ||||
@@ -7,25 +6,27 @@ class Batch(object): | |||||
"""Batch is an iterable object which iterates over mini-batches. | """Batch is an iterable object which iterates over mini-batches. | ||||
:: | :: | ||||
for batch_x, batch_y in Batch(data_set): | |||||
for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()): | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size, sampler, use_cuda): | |||||
def __init__(self, dataset, batch_size, sampler, as_numpy=False): | |||||
""" | """ | ||||
:param dataset: a DataSet object | :param dataset: a DataSet object | ||||
:param batch_size: int, the size of the batch | :param batch_size: int, the size of the batch | ||||
:param sampler: a Sampler object | :param sampler: a Sampler object | ||||
:param use_cuda: bool, whether to use GPU | |||||
:param as_numpy: bool. If True, return Numpy array. Otherwise, return torch tensors. | |||||
""" | """ | ||||
self.dataset = dataset | self.dataset = dataset | ||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.sampler = sampler | self.sampler = sampler | ||||
self.use_cuda = use_cuda | |||||
self.as_numpy = as_numpy | |||||
self.idx_list = None | self.idx_list = None | ||||
self.curidx = 0 | self.curidx = 0 | ||||
self.num_batches = len(dataset)//batch_size + int(len(dataset)%batch_size!=0) | |||||
def __iter__(self): | def __iter__(self): | ||||
self.idx_list = self.sampler(self.dataset) | self.idx_list = self.sampler(self.dataset) | ||||
@@ -34,41 +35,35 @@ class Batch(object): | |||||
return self | return self | ||||
def __next__(self): | def __next__(self): | ||||
""" | |||||
:return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||||
E.g. | |||||
:: | |||||
{'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]}) | |||||
batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||||
All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True. | |||||
""" | |||||
if self.curidx >= len(self.idx_list): | if self.curidx >= len(self.idx_list): | ||||
raise StopIteration | raise StopIteration | ||||
else: | else: | ||||
endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | ||||
padding_length = {field_name: max(field_length[self.curidx: endidx]) | |||||
for field_name, field_length in self.lengths.items()} | |||||
batch_x, batch_y = defaultdict(list), defaultdict(list) | |||||
# transform index to tensor and do padding for sequences | |||||
for idx in range(self.curidx, endidx): | |||||
x, y = self.dataset.to_tensor(idx, padding_length) | |||||
for name, tensor in x.items(): | |||||
batch_x[name].append(tensor) | |||||
for name, tensor in y.items(): | |||||
batch_y[name].append(tensor) | |||||
# combine instances to form a batch | |||||
for batch in (batch_x, batch_y): | |||||
for name, tensor_list in batch.items(): | |||||
if self.use_cuda: | |||||
batch[name] = torch.stack(tensor_list, dim=0).cuda() | |||||
else: | |||||
batch[name] = torch.stack(tensor_list, dim=0) | |||||
batch_x, batch_y = {}, {} | |||||
indices = self.idx_list[self.curidx:endidx] | |||||
for field_name, field in self.dataset.get_all_fields().items(): | |||||
if field.is_target or field.is_input: | |||||
batch = field.get(indices) | |||||
if not self.as_numpy: | |||||
batch = to_tensor(batch, field.dtype) | |||||
if field.is_target: | |||||
batch_y[field_name] = batch | |||||
if field.is_input: | |||||
batch_x[field_name] = batch | |||||
self.curidx = endidx | self.curidx = endidx | ||||
return batch_x, batch_y | return batch_x, batch_y | ||||
def __len__(self): | |||||
return self.num_batches | |||||
def to_tensor(batch, dtype): | |||||
if dtype in (int, np.int8, np.int16, np.int32, np.int64): | |||||
batch = torch.LongTensor(batch) | |||||
if dtype in (float, np.float32, np.float64): | |||||
batch = torch.FloatTensor(batch) | |||||
return batch |
@@ -1,160 +1,390 @@ | |||||
import random | |||||
import sys | |||||
from collections import defaultdict | |||||
from copy import deepcopy | |||||
import _pickle as pickle | |||||
from fastNLP.core.field import TextField, LabelField | |||||
import numpy as np | |||||
from fastNLP.core.fieldarray import FieldArray | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.utils import get_func_signature | |||||
from fastNLP.io.base_loader import DataLoaderRegister | |||||
_READERS = {} | |||||
class DataSet(list): | |||||
"""A DataSet object is a list of Instance objects. | |||||
class DataSet(object): | |||||
"""DataSet is the collection of examples. | |||||
DataSet provides instance-level interface. You can append and access an instance of the DataSet. | |||||
However, it stores data in a different way: Field-first, Instance-second. | |||||
""" | """ | ||||
def __init__(self, name="", instances=None): | |||||
def __init__(self, data=None): | |||||
""" | """ | ||||
:param name: str, the name of the dataset. (default: "") | |||||
:param instances: list of Instance objects. (default: None) | |||||
:param data: a dict or a list. | |||||
If `data` is a dict, the key is the name of a FieldArray and the value is the FieldArray. All values | |||||
must be of the same length. | |||||
If `data` is a list, it must be a list of Instance objects. | |||||
""" | """ | ||||
list.__init__([]) | |||||
self.name = name | |||||
self.origin_len = None | |||||
if instances is not None: | |||||
self.extend(instances) | |||||
self.field_arrays = {} | |||||
if data is not None: | |||||
if isinstance(data, dict): | |||||
length_set = set() | |||||
for key, value in data.items(): | |||||
length_set.add(len(value)) | |||||
assert len(length_set) == 1, "Arrays must all be same length." | |||||
for key, value in data.items(): | |||||
self.add_field(name=key, fields=value) | |||||
elif isinstance(data, list): | |||||
for ins in data: | |||||
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) | |||||
self.append(ins) | |||||
else: | |||||
raise ValueError("data only be dict or list type.") | |||||
def __contains__(self, item): | |||||
return item in self.field_arrays | |||||
def __iter__(self): | |||||
def iter_func(): | |||||
for idx in range(len(self)): | |||||
yield self[idx] | |||||
return iter_func() | |||||
def _inner_iter(self): | |||||
class Iter_ptr: | |||||
def __init__(self, dataset, idx): | |||||
self.dataset = dataset | |||||
self.idx = idx | |||||
def __getitem__(self, item): | |||||
assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[ | |||||
self.idx]) | |||||
assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) | |||||
return self.dataset.field_arrays[item][self.idx] | |||||
def __repr__(self): | |||||
return self.dataset[self.idx].__repr__() | |||||
def index_all(self, vocab): | |||||
for ins in self: | |||||
ins.index_all(vocab) | |||||
return self | |||||
def inner_iter_func(): | |||||
for idx in range(len(self)): | |||||
yield Iter_ptr(self, idx) | |||||
def index_field(self, field_name, vocab): | |||||
if isinstance(field_name, str): | |||||
field_list = [field_name] | |||||
vocab_list = [vocab] | |||||
return inner_iter_func() | |||||
def __getitem__(self, idx): | |||||
"""Fetch Instance(s) at the `idx` position(s) in the dataset. | |||||
Notice: This method returns a copy of the actual instance(s). Any change to the returned value would not modify | |||||
the origin instance(s) of the DataSet. | |||||
If you want to make in-place changes to all Instances, use `apply` method. | |||||
:param idx: can be int or slice. | |||||
:return: If `idx` is int, return an Instance object. | |||||
If `idx` is slice, return a DataSet object. | |||||
""" | |||||
if isinstance(idx, int): | |||||
return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays}) | |||||
elif isinstance(idx, slice): | |||||
if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)): | |||||
raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self)-1}") | |||||
data_set = DataSet() | |||||
for field in self.field_arrays.values(): | |||||
data_set.add_field(name=field.name, | |||||
fields=field.content[idx], | |||||
padding_val=field.padding_val, | |||||
is_input=field.is_input, | |||||
is_target=field.is_target) | |||||
return data_set | |||||
else: | else: | ||||
classes = (list, tuple) | |||||
assert isinstance(field_name, classes) and isinstance(vocab, classes) and len(field_name) == len(vocab) | |||||
field_list = field_name | |||||
vocab_list = vocab | |||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||||
def __getattr__(self, item): | |||||
# Not tested. Don't use !! | |||||
if item == "field_arrays": | |||||
raise AttributeError | |||||
if isinstance(item, str) and item in self.field_arrays: | |||||
return self.field_arrays[item] | |||||
try: | |||||
reader = DataLoaderRegister.get_reader(item) | |||||
return reader | |||||
except AttributeError: | |||||
raise | |||||
for name, vocabs in zip(field_list, vocab_list): | |||||
for ins in self: | |||||
ins.index_field(name, vocabs) | |||||
return self | |||||
def __setstate__(self, state): | |||||
self.__dict__ = state | |||||
def to_tensor(self, idx: int, padding_length: dict): | |||||
"""Convert an instance in a dataset to tensor. | |||||
def __getstate__(self): | |||||
return self.__dict__ | |||||
:param idx: int, the index of the instance in the dataset. | |||||
:param padding_length: int | |||||
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||||
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||||
def __len__(self): | |||||
"""Fetch the length of the dataset. | |||||
:return int length: | |||||
""" | """ | ||||
ins = self[idx] | |||||
return ins.to_tensor(padding_length, self.origin_len) | |||||
if len(self.field_arrays) == 0: | |||||
return 0 | |||||
field = iter(self.field_arrays.values()).__next__() | |||||
return len(field) | |||||
def get_length(self): | |||||
"""Fetch lengths of all fields in all instances in a dataset. | |||||
def __inner_repr__(self): | |||||
if len(self) < 20: | |||||
return ",\n".join([ins.__repr__() for ins in self]) | |||||
else: | |||||
return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__() | |||||
def __repr__(self): | |||||
return "DataSet(" + self.__inner_repr__() + ")" | |||||
:return lengths: dict of (str: list). The str is the field name. | |||||
The list contains lengths of this field in all instances. | |||||
def append(self, ins): | |||||
"""Add an instance to the DataSet. | |||||
If the DataSet is not empty, the instance must have the same field names as the other instances in the DataSet.
:param ins: an Instance object | |||||
""" | """ | ||||
lengths = defaultdict(list) | |||||
for ins in self: | |||||
for field_name, field_length in ins.get_length().items(): | |||||
lengths[field_name].append(field_length) | |||||
return lengths | |||||
if len(self.field_arrays) == 0: | |||||
# DataSet has no field yet | |||||
for name, field in ins.fields.items(): | |||||
self.field_arrays[name] = FieldArray(name, [field]) | |||||
else: | |||||
assert len(self.field_arrays) == len(ins.fields) | |||||
for name, field in ins.fields.items(): | |||||
assert name in self.field_arrays | |||||
self.field_arrays[name].append(field) | |||||
def shuffle(self): | |||||
random.shuffle(self) | |||||
return self | |||||
def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False): | |||||
"""Add a new field to the DataSet. | |||||
:param str name: the name of the field. | |||||
:param fields: a list of int, float, or other objects. | |||||
:param int padding_val: integer for padding. | |||||
:param bool is_input: whether this field is model input. | |||||
:param bool is_target: whether this field is label or target. | |||||
""" | |||||
if len(self.field_arrays) != 0: | |||||
if len(self) != len(fields): | |||||
raise RuntimeError(f"The field to append must have the same size as dataset. " | |||||
f"Dataset size {len(self)} != field size {len(fields)}") | |||||
self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target, | |||||
is_input=is_input) | |||||
def split(self, ratio, shuffle=True): | |||||
"""Train/dev splitting | |||||
def delete_field(self, name): | |||||
"""Delete a field based on the field name. | |||||
:param ratio: float, between 0 and 1. The ratio of development set in origin data set. | |||||
:param shuffle: bool, whether shuffle the data set before splitting. Default: True. | |||||
:return train_set: a DataSet object, representing the training set | |||||
dev_set: a DataSet object, representing the validation set | |||||
:param str name: the name of the field to be deleted. | |||||
""" | |||||
self.field_arrays.pop(name) | |||||
def get_field(self, field_name): | |||||
if field_name not in self.field_arrays: | |||||
raise KeyError("Field name {} not found in DataSet".format(field_name)) | |||||
return self.field_arrays[field_name] | |||||
def get_all_fields(self): | |||||
"""Return all the fields with their names. | |||||
:return dict field_arrays: the internal data structure of DataSet. | |||||
""" | """ | ||||
assert 0 < ratio < 1 | |||||
if shuffle: | |||||
self.shuffle() | |||||
split_idx = int(len(self) * ratio) | |||||
dev_set = deepcopy(self) | |||||
train_set = deepcopy(self) | |||||
del train_set[:split_idx] | |||||
del dev_set[split_idx:] | |||||
return train_set, dev_set | |||||
return self.field_arrays | |||||
def get_length(self): | |||||
"""Fetch the length of the dataset. | |||||
:return int length: | |||||
""" | |||||
return len(self) | |||||
def rename_field(self, old_name, new_name): | def rename_field(self, old_name, new_name): | ||||
"""rename a field | |||||
"""Rename a field. | |||||
:param str old_name: | |||||
:param str new_name: | |||||
""" | """ | ||||
for ins in self: | |||||
ins.rename_field(old_name, new_name) | |||||
return self | |||||
if old_name in self.field_arrays: | |||||
self.field_arrays[new_name] = self.field_arrays.pop(old_name) | |||||
self.field_arrays[new_name].name = new_name | |||||
else: | |||||
raise KeyError("DataSet has no field named {}.".format(old_name)) | |||||
def set_target(self, **fields): | |||||
"""Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged. | |||||
def set_target(self, *field_names, flag=True): | |||||
"""Change the target flag of these fields. | |||||
:param key-value pairs for field-name and `is_target` value(True, False or None). | |||||
:param field_names: a sequence of str, indicating field names | |||||
:param bool flag: Set these fields as target if True. Unset them if False. | |||||
""" | """ | ||||
for ins in self: | |||||
ins.set_target(**fields) | |||||
return self | |||||
for name in field_names: | |||||
if name in self.field_arrays: | |||||
self.field_arrays[name].is_target = flag | |||||
else: | |||||
raise KeyError("{} is not a valid field name.".format(name)) | |||||
def update_vocab(self, **name_vocab): | |||||
"""using certain field data to update vocabulary. | |||||
def set_input(self, *field_name, flag=True): | |||||
"""Set the input flag of these fields. | |||||
e.g. :: | |||||
:param field_name: a sequence of str, indicating field names. | |||||
:param bool flag: Set these fields as input if True. Unset them if False. | |||||
""" | |||||
for name in field_name: | |||||
if name in self.field_arrays: | |||||
self.field_arrays[name].is_input = flag | |||||
else: | |||||
raise KeyError("{} is not a valid field name.".format(name)) | |||||
# update word vocab and label vocab seperately | |||||
dataset.update_vocab(word_seq=word_vocab, label_seq=label_vocab) | |||||
def get_input_name(self): | |||||
"""Get all field names with `is_input` as True. | |||||
:return list field_names: a list of str | |||||
""" | """ | ||||
for field_name, vocab in name_vocab.items(): | |||||
for ins in self: | |||||
vocab.update(ins[field_name].contents()) | |||||
return self | |||||
return [name for name, field in self.field_arrays.items() if field.is_input] | |||||
def get_target_name(self): | |||||
"""Get all field names with `is_target` as True. | |||||
def set_origin_len(self, origin_field, origin_len_name=None): | |||||
"""make dataset tensor output contain origin_len field. | |||||
:return list field_names: a list of str | |||||
""" | |||||
return [name for name, field in self.field_arrays.items() if field.is_target] | |||||
e.g. :: | |||||
def apply(self, func, new_field_name=None, **kwargs): | |||||
"""Apply a function to every instance of the DataSet. | |||||
# output "word_seq_origin_len", lengths based on "word_seq" field | |||||
dataset.set_origin_len("word_seq") | |||||
:param func: a function that takes an instance as input. | |||||
:param str new_field_name: If not None, results of the function will be stored as a new field. | |||||
:param kwargs: Accepted keyword arguments are
(1) is_input: bool, ignored if new_field_name is None. If True, the new field will be used as input.
(2) is_target: bool, ignored if new_field_name is None. If True, the new field will be used as target.
:return results: if new_field_name is not passed, the return values of the function over all instances.
""" | """ | ||||
if origin_field is None: | |||||
self.origin_len = None | |||||
else: | |||||
self.origin_len = (origin_field + "_origin_len", origin_field) \ | |||||
if origin_len_name is None else (origin_len_name, origin_field) | |||||
return self | |||||
def __getattribute__(self, name): | |||||
if name in _READERS: | |||||
# add read_*data() support | |||||
def _read(*args, **kwargs): | |||||
data = _READERS[name]().load(*args, **kwargs) | |||||
self.extend(data) | |||||
return self | |||||
return _read | |||||
results = [func(ins) for ins in self._inner_iter()] | |||||
if len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||||
extra_param = {} | |||||
if 'is_input' in kwargs: | |||||
extra_param['is_input'] = kwargs['is_input'] | |||||
if 'is_target' in kwargs: | |||||
extra_param['is_target'] = kwargs['is_target'] | |||||
if new_field_name is not None: | |||||
if new_field_name in self.field_arrays: | |||||
# overwrite the field, keep same attributes | |||||
old_field = self.field_arrays[new_field_name] | |||||
if 'is_input' not in extra_param: | |||||
extra_param['is_input'] = old_field.is_input | |||||
if 'is_target' not in extra_param: | |||||
extra_param['is_target'] = old_field.is_target | |||||
self.add_field(name=new_field_name, | |||||
fields=results, | |||||
padding_val=old_field.padding_val, | |||||
**extra_param) | |||||
else: | |||||
self.add_field(name=new_field_name, fields=results, **extra_param) | |||||
else: | else: | ||||
return object.__getattribute__(self, name) | |||||
return results | |||||
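# --- Illustrative sketch of DataSet.apply (hypothetical field names and import path) ---
# Derive a new field from an existing one and mark it as model input in one call;
# without new_field_name, apply() just returns the per-instance results.
from fastNLP.core.dataset import DataSet

ds = DataSet({"word_seq": [[1, 2, 3], [4, 5]], "label": [0, 1]})
ds.apply(lambda ins: len(ins["word_seq"]), new_field_name="seq_len", is_input=True)
lengths = ds.apply(lambda ins: len(ins["word_seq"]))   # [3, 2]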
def drop(self, func): | |||||
"""Drop instances if a condition holds. | |||||
:param func: a function that takes an Instance object as input, and returns bool. | |||||
The instance will be dropped if the function returns True. | |||||
""" | |||||
results = [ins for ins in self._inner_iter() if not func(ins)] | |||||
for name, old_field in self.field_arrays.items(): | |||||
self.field_arrays[name].content = [ins[name] for ins in results] | |||||
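# --- Illustrative sketch of DataSet.drop (hypothetical placeholder label) ---
from fastNLP.core.dataset import DataSet

ds = DataSet({"word_seq": [[1, 2], [3, 4], [5, 6]], "label": [0, -1, 1]})
ds.drop(lambda ins: ins["label"] == -1)   # remove instances whose label is -1
print(len(ds))   # 2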
def split(self, dev_ratio): | |||||
"""Split the dataset into training and development(validation) set. | |||||
:param float dev_ratio: the ratio of the development set in all data.
:return DataSet train_set: the training set | |||||
DataSet dev_set: the development set | |||||
""" | |||||
assert isinstance(dev_ratio, float) | |||||
assert 0 < dev_ratio < 1 | |||||
all_indices = [_ for _ in range(len(self))] | |||||
np.random.shuffle(all_indices) | |||||
split = int(dev_ratio * len(self)) | |||||
dev_indices = all_indices[:split] | |||||
train_indices = all_indices[split:] | |||||
dev_set = DataSet() | |||||
train_set = DataSet() | |||||
for idx in dev_indices: | |||||
dev_set.append(self[idx]) | |||||
for idx in train_indices: | |||||
train_set.append(self[idx]) | |||||
for field_name in self.field_arrays: | |||||
train_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | |||||
train_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | |||||
dev_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | |||||
dev_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | |||||
return train_set, dev_set | |||||
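# --- Illustrative sketch of DataSet.split (toy field names and values) ---
from fastNLP.core.dataset import DataSet

ds = DataSet({"x": list(range(10)), "y": list(range(10))})
train_ds, dev_ds = ds.split(0.2)    # hold out 20% of the instances as the dev set
print(len(train_ds), len(dev_ds))   # 8 2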
@classmethod | @classmethod | ||||
def set_reader(cls, method_name): | |||||
"""decorator to add dataloader support | |||||
""" | |||||
assert isinstance(method_name, str) | |||||
def wrapper(read_cls): | |||||
_READERS[method_name] = read_cls | |||||
return read_cls | |||||
return wrapper | |||||
def read_csv(cls, csv_path, headers=None, sep=",", dropna=True): | |||||
"""Load data from a CSV file and return a DataSet object. | |||||
:param str csv_path: path to the CSV file | |||||
:param List[str] or Tuple[str] headers: headers of the CSV file | |||||
:param str sep: delimiter in CSV file. Default: "," | |||||
:param bool dropna: If True, drop rows that have fewer entries than headers.
:return DataSet dataset: | |||||
""" | |||||
with open(csv_path, "r") as f: | |||||
start_idx = 0 | |||||
if headers is None: | |||||
headers = f.readline().rstrip('\r\n') | |||||
headers = headers.split(sep) | |||||
start_idx += 1 | |||||
else: | |||||
assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format( | |||||
type(headers)) | |||||
_dict = {} | |||||
for col in headers: | |||||
_dict[col] = [] | |||||
for line_idx, line in enumerate(f, start_idx): | |||||
contents = line.rstrip('\r\n').split(sep) | |||||
if len(contents) != len(headers): | |||||
if dropna: | |||||
continue | |||||
else: | |||||
# TODO change error type | |||||
raise ValueError("Line {} has {} parts, while header has {} parts." \ | |||||
.format(line_idx, len(contents), len(headers))) | |||||
for header, content in zip(headers, contents): | |||||
_dict[header].append(content) | |||||
return cls(_dict) | |||||
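# --- Illustrative sketch of read_csv (hypothetical file path and headers; read_csv
# is used as a classmethod, as its cls parameter suggests) ---
from fastNLP.core.dataset import DataSet

ds = DataSet.read_csv("reviews.csv", headers=("raw_sentence", "label"), sep=",", dropna=True)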
# def read_pos(self): | |||||
# return DataLoaderRegister.get_reader('read_pos') | |||||
def save(self, path): | |||||
"""Save the DataSet object as pickle. | |||||
:param str path: the path to the pickle | |||||
""" | |||||
with open(path, 'wb') as f: | |||||
pickle.dump(self, f) | |||||
@staticmethod | |||||
def load(path): | |||||
"""Load a DataSet object from pickle. | |||||
:param str path: the path to the pickle | |||||
:return DataSet data_set: | |||||
""" | |||||
with open(path, 'rb') as f: | |||||
return pickle.load(f) | |||||
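# --- Illustrative sketch: round-trip a DataSet through pickle (toy values) ---
from fastNLP.core.dataset import DataSet

ds = DataSet({"x": [1, 2, 3]})
ds.save("dataset.pkl")
restored = DataSet.load("dataset.pkl")
assert len(restored) == len(ds)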
def construct_dataset(sentences): | |||||
"""Construct a data set from a list of sentences. | |||||
:param sentences: list of list of str | |||||
:return dataset: a DataSet object | |||||
""" | |||||
dataset = DataSet() | |||||
for sentence in sentences: | |||||
instance = Instance() | |||||
instance['raw_sentence'] = sentence | |||||
dataset.append(instance) | |||||
return dataset |
@@ -1,135 +0,0 @@ | |||||
import torch | |||||
class Field(object): | |||||
"""A field defines a data type. | |||||
""" | |||||
def __init__(self, is_target: bool): | |||||
self.is_target = is_target | |||||
def index(self, vocab): | |||||
raise NotImplementedError | |||||
def get_length(self): | |||||
raise NotImplementedError | |||||
def to_tensor(self, padding_length): | |||||
raise NotImplementedError | |||||
def contents(self): | |||||
raise NotImplementedError | |||||
class TextField(Field): | |||||
def __init__(self, text, is_target): | |||||
""" | |||||
:param text: list of strings | |||||
:param is_target: bool | |||||
""" | |||||
super(TextField, self).__init__(is_target) | |||||
self.text = text | |||||
self._index = None | |||||
def index(self, vocab): | |||||
if self._index is None: | |||||
self._index = [vocab[c] for c in self.text] | |||||
else: | |||||
raise RuntimeError("Replicate indexing of this field.") | |||||
return self._index | |||||
def get_length(self): | |||||
"""Fetch the length of the text field. | |||||
:return length: int, the length of the text. | |||||
""" | |||||
return len(self.text) | |||||
def to_tensor(self, padding_length: int): | |||||
"""Convert text field to tensor. | |||||
:param padding_length: int | |||||
:return tensor: torch.LongTensor, of shape [padding_length, ] | |||||
""" | |||||
pads = [] | |||||
if self._index is None: | |||||
raise RuntimeError("Indexing not done before to_tensor in TextField.") | |||||
if padding_length > self.get_length(): | |||||
pads = [0] * (padding_length - self.get_length()) | |||||
return torch.LongTensor(self._index + pads) | |||||
def contents(self): | |||||
return self.text.copy() | |||||
class LabelField(Field): | |||||
"""The Field representing a single label. Can be a string or integer. | |||||
""" | |||||
def __init__(self, label, is_target=True): | |||||
super(LabelField, self).__init__(is_target) | |||||
self.label = label | |||||
self._index = None | |||||
def get_length(self): | |||||
"""Fetch the length of the label field. | |||||
:return length: int, the length of the label, always 1. | |||||
""" | |||||
return 1 | |||||
def index(self, vocab): | |||||
if self._index is None: | |||||
if isinstance(self.label, str): | |||||
self._index = vocab[self.label] | |||||
return self._index | |||||
def to_tensor(self, padding_length): | |||||
if self._index is None: | |||||
if isinstance(self.label, int): | |||||
return torch.tensor(self.label) | |||||
elif isinstance(self.label, str): | |||||
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label)) | |||||
else: | |||||
raise RuntimeError( | |||||
"Not support type for LabelField. Expect str or int, got {}.".format(type(self.label))) | |||||
else: | |||||
return torch.LongTensor([self._index]) | |||||
def contents(self): | |||||
return [self.label] | |||||
class SeqLabelField(Field): | |||||
def __init__(self, label_seq, is_target=True): | |||||
super(SeqLabelField, self).__init__(is_target) | |||||
self.label_seq = label_seq | |||||
self._index = None | |||||
def get_length(self): | |||||
return len(self.label_seq) | |||||
def index(self, vocab): | |||||
if self._index is None: | |||||
self._index = [vocab[c] for c in self.label_seq] | |||||
return self._index | |||||
def to_tensor(self, padding_length): | |||||
pads = [0] * (padding_length - self.get_length()) | |||||
if self._index is None: | |||||
if self.get_length() == 0: | |||||
return torch.LongTensor(pads) | |||||
elif isinstance(self.label_seq[0], int): | |||||
return torch.LongTensor(self.label_seq + pads) | |||||
elif isinstance(self.label_seq[0], str): | |||||
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label)) | |||||
else: | |||||
raise RuntimeError(
"Unsupported type for SeqLabelField. Expected str or int, got {}.".format(type(self.label_seq[0])))
else: | |||||
return torch.LongTensor(self._index + pads) | |||||
def contents(self): | |||||
return self.label_seq.copy() | |||||
if __name__ == "__main__": | |||||
tf = TextField("test the code".split(), is_target=False) |
@@ -0,0 +1,190 @@ | |||||
import numpy as np | |||||
class FieldArray(object): | |||||
"""FieldArray is the collection of Instances of the same Field. | |||||
It is the basic element of DataSet class. | |||||
""" | |||||
def __init__(self, name, content, padding_val=0, is_target=None, is_input=None): | |||||
""" | |||||
:param str name: the name of the FieldArray | |||||
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. | |||||
:param int padding_val: the integer for padding. Default: 0. | |||||
:param bool is_target: If True, this FieldArray is used to compute loss. | |||||
:param bool is_input: If True, this FieldArray is used as model input.
""" | |||||
self.name = name | |||||
if isinstance(content, list): | |||||
content = content | |||||
elif isinstance(content, np.ndarray): | |||||
content = content.tolist() # convert np.ndarray into 2-D list | |||||
else: | |||||
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | |||||
self.content = content | |||||
self.padding_val = padding_val | |||||
self._is_target = None | |||||
self._is_input = None | |||||
self.BASIC_TYPES = (int, float, str, np.ndarray) | |||||
self.is_2d_list = False | |||||
self.pytype = None # int, float, str, or np.ndarray | |||||
self.dtype = None # np.int64, np.float64, np.str | |||||
if is_input is not None: | |||||
self.is_input = is_input | |||||
if is_target is not None: | |||||
self.is_target = is_target | |||||
@property | |||||
def is_input(self): | |||||
return self._is_input | |||||
@is_input.setter | |||||
def is_input(self, value): | |||||
if value is True: | |||||
self.pytype = self._type_detection(self.content) | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
self._is_input = value | |||||
@property | |||||
def is_target(self): | |||||
return self._is_target | |||||
@is_target.setter | |||||
def is_target(self, value): | |||||
if value is True: | |||||
self.pytype = self._type_detection(self.content) | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
self._is_target = value | |||||
def _type_detection(self, content): | |||||
""" | |||||
:param content: a list of int, float, str or np.ndarray, or a list of list of one. | |||||
:return type: one of int, float, str, np.ndarray | |||||
""" | |||||
if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): | |||||
# content is a 2-D list | |||||
if not all(isinstance(_, list) for _ in content): # strict check 2-D list | |||||
raise TypeError("Please provide 2-D list.") | |||||
type_set = set([self._type_detection(x) for x in content]) | |||||
if len(type_set) == 2 and int in type_set and float in type_set: | |||||
type_set = {float} | |||||
elif len(type_set) > 1: | |||||
raise TypeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) | |||||
self.is_2d_list = True | |||||
return type_set.pop() | |||||
elif isinstance(content, list): | |||||
# content is a 1-D list | |||||
if len(content) == 0: | |||||
# the old error is not informative enough. | |||||
raise RuntimeError("Cannot create FieldArray with an empty list. Or one element in the list is empty.") | |||||
type_set = set([type(item) for item in content]) | |||||
if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES: | |||||
return type_set.pop() | |||||
elif len(type_set) == 2 and float in type_set and int in type_set: | |||||
# up-cast int to float | |||||
return float | |||||
else: | |||||
raise TypeError("Cannot create FieldArray with type {}".format(*type_set)) | |||||
else: | |||||
raise TypeError("Cannot create FieldArray with type {}".format(type(content))) | |||||
@staticmethod | |||||
def _map_to_np_type(basic_type): | |||||
type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} | |||||
return type_mapping[basic_type] | |||||
def __repr__(self): | |||||
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | |||||
def append(self, val): | |||||
"""Add a new item to the tail of FieldArray. | |||||
:param val: int, float, str, or a list of one. | |||||
""" | |||||
if self.is_target is True or self.is_input is True: | |||||
# only check type when used as target or input | |||||
val_type = type(val) | |||||
if val_type == list: # shape check | |||||
if self.is_2d_list is False: | |||||
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") | |||||
if len(val) == 0: | |||||
raise RuntimeError("Cannot append an empty list.") | |||||
val_list_type = set([type(_) for _ in val]) # type check | |||||
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: | |||||
# up-cast int to float | |||||
val_type = float | |||||
elif len(val_list_type) == 1: | |||||
val_type = val_list_type.pop() | |||||
else: | |||||
raise TypeError("Cannot append a list of {}".format(val_list_type)) | |||||
else: | |||||
if self.is_2d_list is True: | |||||
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") | |||||
if val_type == float and self.pytype == int: | |||||
# up-cast | |||||
self.pytype = float | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
elif val_type == int and self.pytype == float: | |||||
pass | |||||
elif val_type == self.pytype: | |||||
pass | |||||
else: | |||||
raise TypeError("Cannot append type {} into type {}".format(val_type, self.pytype)) | |||||
self.content.append(val) | |||||
def __getitem__(self, indices): | |||||
return self.get(indices) | |||||
def __setitem__(self, idx, val): | |||||
assert isinstance(idx, int) | |||||
self.content[idx] = val | |||||
def get(self, indices): | |||||
"""Fetch instances based on indices. | |||||
:param indices: an int, or a list of int. | |||||
:return: | |||||
""" | |||||
if isinstance(indices, int): | |||||
return self.content[indices] | |||||
if self.is_input is False and self.is_target is False: | |||||
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) | |||||
batch_size = len(indices) | |||||
if not is_iterable(self.content[0]): | |||||
array = np.array([self.content[i] for i in indices], dtype=self.dtype) | |||||
elif self.dtype in (np.int64, np.float64): | |||||
max_len = max([len(self.content[i]) for i in indices]) | |||||
array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype) | |||||
for i, idx in enumerate(indices): | |||||
array[i][:len(self.content[idx])] = self.content[idx] | |||||
else: # should only be str | |||||
array = np.array([self.content[i] for i in indices]) | |||||
return array | |||||
def __len__(self): | |||||
"""Returns the size of FieldArray. | |||||
:return int length: | |||||
""" | |||||
return len(self.content) | |||||
def is_iterable(content): | |||||
try: | |||||
_ = (e for e in content) | |||||
except TypeError: | |||||
return False | |||||
return True |
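# --- Illustrative sketch of FieldArray on its own (toy values; the import path is
# assumed from the file layout of this PR) ---
from fastNLP.core.fieldarray import FieldArray

fa = FieldArray("word_seq", [[1, 2, 3], [4, 5]], padding_val=0, is_input=True)
fa.append([6, 7, 8, 9])
batch = fa.get([0, 2])    # rows are padded with padding_val up to the longest length
print(batch.shape)        # (2, 4) for this toy data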
@@ -1,33 +1,25 @@ | |||||
import torch | |||||
class Instance(object): | class Instance(object): | ||||
"""An instance which consists of Fields is an example in the DataSet. | |||||
"""An Instance is an example of data. It is the collection of Fields. | |||||
:: | |||||
Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) | |||||
""" | """ | ||||
def __init__(self, **fields): | def __init__(self, **fields): | ||||
""" | |||||
:param fields: a dict of (str: list). | |||||
""" | |||||
self.fields = fields | self.fields = fields | ||||
self.has_index = False | |||||
self.indexes = {} | |||||
def add_field(self, field_name, field): | def add_field(self, field_name, field): | ||||
self.fields[field_name] = field | |||||
return self | |||||
def rename_field(self, old_name, new_name): | |||||
if old_name in self.fields: | |||||
self.fields[new_name] = self.fields.pop(old_name) | |||||
if old_name in self.indexes: | |||||
self.indexes[new_name] = self.indexes.pop(old_name) | |||||
else: | |||||
raise KeyError("error, no such field: {}".format(old_name)) | |||||
return self | |||||
"""Add a new field to the instance. | |||||
def set_target(self, **fields): | |||||
for name, val in fields.items(): | |||||
if name in self.fields: | |||||
self.fields[name].is_target = val | |||||
return self | |||||
:param field_name: str, the name of the field. | |||||
:param field: the value of the field.
""" | |||||
self.fields[field_name] = field | |||||
def __getitem__(self, name): | def __getitem__(self, name): | ||||
if name in self.fields: | if name in self.fields: | ||||
@@ -35,50 +27,9 @@ class Instance(object): | |||||
else: | else: | ||||
raise KeyError("{} not found".format(name)) | raise KeyError("{} not found".format(name)) | ||||
def get_length(self): | |||||
"""Fetch the length of all fields in the instance. | |||||
def __setitem__(self, name, field): | |||||
return self.add_field(name, field) | |||||
:return length: dict of (str: int), which means (field name: field length). | |||||
""" | |||||
length = {name: field.get_length() for name, field in self.fields.items()} | |||||
return length | |||||
def index_field(self, field_name, vocab): | |||||
"""use `vocab` to index certain field | |||||
""" | |||||
self.indexes[field_name] = self.fields[field_name].index(vocab) | |||||
return self | |||||
def index_all(self, vocab): | |||||
"""use `vocab` to index all fields | |||||
""" | |||||
if self.has_index: | |||||
print("error") | |||||
return self.indexes | |||||
indexes = {name: field.index(vocab) for name, field in self.fields.items()} | |||||
self.indexes = indexes | |||||
return indexes | |||||
def to_tensor(self, padding_length: dict, origin_len=None): | |||||
"""Convert instance to tensor. | |||||
:param padding_length: dict of (str: int), which means (field name: padding_length of this field) | |||||
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||||
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||||
If is_target is False for all fields, tensor_y would be an empty dict. | |||||
""" | |||||
tensor_x = {} | |||||
tensor_y = {} | |||||
for name, field in self.fields.items(): | |||||
if field.is_target is True: | |||||
tensor_y[name] = field.to_tensor(padding_length[name]) | |||||
elif field.is_target is False: | |||||
tensor_x[name] = field.to_tensor(padding_length[name]) | |||||
else: | |||||
# is_target is None | |||||
continue | |||||
if origin_len is not None: | |||||
name, field_name = origin_len | |||||
tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()]) | |||||
return tensor_x, tensor_y | |||||
def __repr__(self): | |||||
return "{" + ",\n".join( | |||||
"\'" + field_name + "\': " + str(self.fields[field_name]) for field_name in self.fields) + "}" |
@@ -1,58 +0,0 @@ | |||||
import torch | |||||
class Loss(object): | |||||
"""Loss function of the algorithm, | |||||
either a wrapper of a loss function from the framework, or a user-defined loss (requires PyTorch autograd support)
""" | |||||
def __init__(self, args): | |||||
""" | |||||
:param args: None or str, the name of a loss function. | |||||
""" | |||||
if args is None: | |||||
# this is useful when Trainer.__init__ performs type check | |||||
self._loss = None | |||||
elif isinstance(args, str): | |||||
self._loss = self._borrow_from_pytorch(args) | |||||
else: | |||||
raise NotImplementedError | |||||
def get(self): | |||||
""" | |||||
:return self._loss: the loss function | |||||
""" | |||||
return self._loss | |||||
@staticmethod | |||||
def _borrow_from_pytorch(loss_name): | |||||
"""Given a name of a loss function, return it from PyTorch. | |||||
:param loss_name: str, the name of a loss function | |||||
- cross_entropy: combines log softmax and nll loss in a single function. | |||||
- nll: negative log likelihood | |||||
:return loss: a PyTorch loss | |||||
""" | |||||
class InnerCrossEntropy: | |||||
"""A simple wrapper to guarantee input shapes.""" | |||||
def __init__(self): | |||||
self.f = torch.nn.CrossEntropyLoss() | |||||
def __call__(self, predict, truth): | |||||
truth = truth.view(-1, ) | |||||
return self.f(predict, truth) | |||||
if loss_name == "cross_entropy": | |||||
return InnerCrossEntropy() | |||||
elif loss_name == 'nll': | |||||
return torch.nn.NLLLoss() | |||||
else: | |||||
raise NotImplementedError |
@@ -0,0 +1,358 @@ | |||||
import inspect | |||||
from collections import defaultdict | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.utils import CheckRes | |||||
from fastNLP.core.utils import _build_args | |||||
from fastNLP.core.utils import _check_arg_dict_list | |||||
from fastNLP.core.utils import _check_function_or_method | |||||
from fastNLP.core.utils import get_func_signature | |||||
class LossBase(object): | |||||
def __init__(self): | |||||
self.param_map = {} | |||||
self._checked = False | |||||
def get_loss(self, *args, **kwargs): | |||||
raise NotImplementedError | |||||
def _init_param_map(self, key_map=None, **kwargs): | |||||
"""Check the validity of key_map and other param map. Add these into self.param_map | |||||
:param key_map: dict | |||||
:param kwargs: | |||||
:return: None | |||||
""" | |||||
value_counter = defaultdict(set) | |||||
if key_map is not None: | |||||
if not isinstance(key_map, dict): | |||||
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) | |||||
for key, value in key_map.items(): | |||||
if value is None: | |||||
self.param_map[key] = key | |||||
continue | |||||
if not isinstance(key, str): | |||||
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") | |||||
if not isinstance(value, str): | |||||
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") | |||||
self.param_map[key] = value | |||||
value_counter[value].add(key) | |||||
for key, value in kwargs.items(): | |||||
if value is None: | |||||
self.param_map[key] = key | |||||
continue | |||||
if not isinstance(value, str): | |||||
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | |||||
self.param_map[key] = value | |||||
value_counter[value].add(key) | |||||
for value, key_set in value_counter.items(): | |||||
if len(key_set) > 1: | |||||
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | |||||
# check consistency between signature and param_map
func_spect = inspect.getfullargspec(self.get_loss) | |||||
func_args = [arg for arg in func_spect.args if arg != 'self'] | |||||
for func_param, input_param in self.param_map.items(): | |||||
if func_param not in func_args: | |||||
raise NameError( | |||||
f"Parameter `{func_param}` is not in {get_func_signature(self.get_loss)}. Please check the " | |||||
f"initialization parameters, or change its signature.") | |||||
# get_loss should not have varargs.
if func_spect.varargs: | |||||
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " | |||||
f"positional argument.).") | |||||
def _fast_param_map(self, pred_dict, target_dict): | |||||
""" | |||||
Only used as an inner function. When the mapping between pred_dict and target_dict is unambiguous,
e.g. each of them has exactly one element, users do not need to pass a key_map.
:param pred_dict: | |||||
:param target_dict: | |||||
:return: dict. If it is not empty, pass it to self.get_loss. Otherwise fall back to the full mapping.
""" | |||||
fast_param = {} | |||||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||||
fast_param['pred'] = list(pred_dict.values())[0] | |||||
fast_param['target'] = list(target_dict.values())[0] | |||||
return fast_param | |||||
return fast_param | |||||
def __call__(self, pred_dict, target_dict, check=False): | |||||
""" | |||||
:param pred_dict: A dict from forward function of the network. | |||||
:param target_dict: A dict from DataSet.batch_y. | |||||
:param check: bool. If True, force checking of the parameter mapping on every call.
:return: | |||||
""" | |||||
fast_param = self._fast_param_map(pred_dict, target_dict) | |||||
if fast_param: | |||||
loss = self.get_loss(**fast_param) | |||||
return loss | |||||
if not self._checked: | |||||
# 1. check consistency between signature and param_map
func_spect = inspect.getfullargspec(self.get_loss) | |||||
func_args = set([arg for arg in func_spect.args if arg != 'self']) | |||||
for func_arg, input_arg in self.param_map.items(): | |||||
if func_arg not in func_args: | |||||
raise NameError(f"`{func_arg}` not in {get_func_signature(self.get_loss)}.") | |||||
# 2. only part of the param_map is passed; fill in the rest
for arg in func_args: | |||||
if arg not in self.param_map: | |||||
self.param_map[arg] = arg # This param does not need mapping. | |||||
self._evaluate_args = func_args | |||||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} | |||||
# need to wrap inputs in dict. | |||||
mapped_pred_dict = {} | |||||
mapped_target_dict = {} | |||||
duplicated = [] | |||||
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): | |||||
not_duplicate_flag = 0 | |||||
if input_arg in self._reverse_param_map: | |||||
mapped_arg = self._reverse_param_map[input_arg] | |||||
not_duplicate_flag += 1 | |||||
else: | |||||
mapped_arg = input_arg | |||||
if input_arg in pred_dict: | |||||
mapped_pred_dict[mapped_arg] = pred_dict[input_arg] | |||||
not_duplicate_flag += 1 | |||||
if input_arg in target_dict: | |||||
mapped_target_dict[mapped_arg] = target_dict[input_arg] | |||||
not_duplicate_flag += 1 | |||||
if not_duplicate_flag == 3: | |||||
duplicated.append(input_arg) | |||||
# missing | |||||
if not self._checked: | |||||
check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict]) | |||||
# replace missing. | |||||
missing = check_res.missing | |||||
replaced_missing = list(missing) | |||||
for idx, func_arg in enumerate(missing): | |||||
# Do not delete or add backticks in this message.
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||||
f"in `{self.__class__.__name__}`)" | |||||
check_res = CheckRes(missing=replaced_missing, | |||||
unused=check_res.unused, | |||||
duplicated=duplicated, | |||||
required=check_res.required, | |||||
all_needed=check_res.all_needed, | |||||
varargs=check_res.varargs) | |||||
if check_res.missing or check_res.duplicated or check_res.varargs: | |||||
raise CheckError(check_res=check_res, | |||||
func_signature=get_func_signature(self.get_loss)) | |||||
refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict) | |||||
loss = self.get_loss(**refined_args) | |||||
self._checked = True | |||||
return loss | |||||
class LossFunc(LossBase): | |||||
"""A wrapper of user-provided loss function. | |||||
""" | |||||
def __init__(self, func, key_map=None, **kwargs): | |||||
""" | |||||
:param func: a callable object, such as a function. | |||||
:param dict key_map: | |||||
:param kwargs: | |||||
""" | |||||
super(LossFunc, self).__init__() | |||||
_check_function_or_method(func) | |||||
if key_map is not None: | |||||
if not isinstance(key_map, dict): | |||||
raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}") | |||||
self.param_map = key_map | |||||
if len(kwargs) > 0: | |||||
for key, val in kwargs.items(): | |||||
self.param_map.update({key: val}) | |||||
self.get_loss = func | |||||
class CrossEntropyLoss(LossBase): | |||||
def __init__(self, pred=None, target=None, padding_idx=-100): | |||||
# TODO: some checks are needed here. When F.cross_entropy is computed, if pred has shape (16, 10, 4),
# TODO: target should in principle have shape (16, 10), but in practice (16, 4) is required.
super(CrossEntropyLoss, self).__init__() | |||||
self._init_param_map(pred=pred, target=target) | |||||
self.padding_idx = padding_idx | |||||
def get_loss(self, pred, target): | |||||
return F.cross_entropy(input=pred, target=target, | |||||
ignore_index=self.padding_idx) | |||||
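# --- Illustrative sketch (hypothetical dict keys; the import path is assumed from
# this PR's file layout): the constructor arguments tell the loss which keys in the
# model output and in the target fields correspond to its (pred, target) parameters ---
import torch
from fastNLP.core.losses import CrossEntropyLoss

loss_fn = CrossEntropyLoss(pred="output", target="label_seq")
pred_dict = {"output": torch.randn(4, 10)}               # e.g. returned by model.forward()
target_dict = {"label_seq": torch.randint(0, 10, (4,))}  # e.g. a target field from the DataSet
loss = loss_fn(pred_dict, target_dict)                   # a scalar torch.Tensor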
class L1Loss(LossBase): | |||||
def __init__(self, pred=None, target=None): | |||||
super(L1Loss, self).__init__() | |||||
self._init_param_map(pred=pred, target=target) | |||||
def get_loss(self, pred, target): | |||||
return F.l1_loss(input=pred, target=target) | |||||
class BCELoss(LossBase): | |||||
def __init__(self, pred=None, target=None): | |||||
super(BCELoss, self).__init__() | |||||
self._init_param_map(pred=pred, target=target) | |||||
def get_loss(self, pred, target): | |||||
return F.binary_cross_entropy(input=pred, target=target) | |||||
class NLLLoss(LossBase): | |||||
def __init__(self, pred=None, target=None): | |||||
super(NLLLoss, self).__init__() | |||||
self._init_param_map(pred=pred, target=target) | |||||
def get_loss(self, pred, target): | |||||
return F.nll_loss(input=pred, target=target) | |||||
class LossInForward(LossBase): | |||||
def __init__(self, loss_key='loss'): | |||||
super().__init__() | |||||
if not isinstance(loss_key, str): | |||||
raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.") | |||||
self.loss_key = loss_key | |||||
def get_loss(self, **kwargs): | |||||
if self.loss_key not in kwargs: | |||||
check_res = CheckRes(missing=[self.loss_key + f"(assign to `{self.loss_key}` " \
f"in `{self.__class__.__name__}`)"],
unused=[], | |||||
duplicated=[], | |||||
required=[], | |||||
all_needed=[], | |||||
varargs=[]) | |||||
raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) | |||||
return kwargs[self.loss_key] | |||||
def __call__(self, pred_dict, target_dict, check=False): | |||||
loss = self.get_loss(**pred_dict) | |||||
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | |||||
if not isinstance(loss, torch.Tensor): | |||||
raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}") | |||||
raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") | |||||
return loss | |||||
def _prepare_losser(losser): | |||||
if losser is None: | |||||
losser = LossInForward() | |||||
return losser | |||||
elif isinstance(losser, LossBase): | |||||
return losser | |||||
else: | |||||
raise TypeError(f"Type of loss should be `fastNLP.LossBase`, got {type(losser)}") | |||||
def squash(predict, truth, **kwargs): | |||||
"""To reshape tensors in order to fit loss functions in pytorch | |||||
:param predict : Tensor, model output | |||||
:param truth : Tensor, truth from dataset | |||||
:param **kwargs : extra arguments | |||||
:return predict , truth: predict & truth after processing | |||||
""" | |||||
return predict.view(-1, predict.size()[-1]), truth.view(-1, ) | |||||
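# --- Illustrative sketch of squash() with toy shapes (assumes squash from this
# module is in scope) ---
import torch

pred = torch.randn(2, 5, 7)             # [batch_size, max_len, tag_size]
gold = torch.randint(0, 7, (2, 5))      # [batch_size, max_len]
flat_pred, flat_gold = squash(pred, gold)
print(flat_pred.shape, flat_gold.shape)   # torch.Size([10, 7]) torch.Size([10])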
def unpad(predict, truth, **kwargs): | |||||
"""To process padded sequence output to get true loss | |||||
Using pack_padded_sequence() method | |||||
This method contains squash() | |||||
:param predict : Tensor, [batch_size , max_len , tag_size] | |||||
:param truth : Tensor, [batch_size , max_len] | |||||
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist | |||||
kwargs["lens"] : list or LongTensor, [batch_size] | |||||
the i-th element is true lengths of i-th sequence | |||||
:return predict , truth: predict & truth after processing | |||||
""" | |||||
if kwargs.get("lens") is None: | |||||
return predict, truth | |||||
lens = torch.LongTensor(kwargs["lens"]) | |||||
lens, idx = torch.sort(lens, descending=True) | |||||
predict = torch.nn.utils.rnn.pack_padded_sequence(predict[idx], lens, batch_first=True).data | |||||
truth = torch.nn.utils.rnn.pack_padded_sequence(truth[idx], lens, batch_first=True).data | |||||
return predict, truth | |||||
def unpad_mask(predict, truth, **kwargs): | |||||
"""To process padded sequence output to get true loss | |||||
Using mask() method | |||||
This method contains squash() | |||||
:param predict : Tensor, [batch_size , max_len , tag_size] | |||||
:param truth : Tensor, [batch_size , max_len] | |||||
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist | |||||
kwargs["lens"] : list or LongTensor, [batch_size] | |||||
the i-th element is true lengths of i-th sequence | |||||
:return predict , truth: predict & truth after processing | |||||
""" | |||||
if kwargs.get("lens") is None: | |||||
return predict, truth | |||||
mas = make_mask(kwargs["lens"], truth.size()[1]) | |||||
return mask(predict, truth, mask=mas) | |||||
def mask(predict, truth, **kwargs): | |||||
"""To select specific elements from Tensor | |||||
This method contains squash() | |||||
:param predict : Tensor, [batch_size , max_len , tag_size] | |||||
:param truth : Tensor, [batch_size , max_len] | |||||
:param **kwargs : extra arguments, kwargs["mask"] is expected to be exsist | |||||
kwargs["mask"] : ByteTensor, [batch_size , max_len] | |||||
the mask Tensor , the position that is 1 will be selected | |||||
:return predict , truth: predict & truth after processing | |||||
""" | |||||
if kwargs.get("mask") is None: | |||||
return predict, truth | |||||
mask = kwargs["mask"] | |||||
predict, truth = squash(predict, truth) | |||||
mask = mask.view(-1, ) | |||||
predict = torch.masked_select(predict.permute(1, 0), mask).view(predict.size()[-1], -1).permute(1, 0) | |||||
truth = torch.masked_select(truth, mask) | |||||
return predict, truth | |||||
def make_mask(lens, tar_len): | |||||
"""to generate a mask that select [:lens[i]] for i-th element | |||||
embezzle from fastNLP.models.sequence_modeling.seq_mask | |||||
:param lens : list or LongTensor, [batch_size] | |||||
:param tar_len : int | |||||
:return mask : ByteTensor | |||||
""" | |||||
lens = torch.LongTensor(lens) | |||||
mask = [torch.ge(lens, i + 1) for i in range(tar_len)] | |||||
mask = torch.stack(mask, 1) | |||||
return mask | |||||
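# --- Illustrative sketch of make_mask() + mask() with toy shapes (assumes both
# functions from this module are in scope) ---
import torch

pred = torch.randn(2, 4, 3)             # [batch_size, max_len, tag_size]
gold = torch.randint(0, 3, (2, 4))      # [batch_size, max_len]
m = make_mask([4, 2], 4)                # row i selects positions [:lens[i]]
kept_pred, kept_gold = mask(pred, gold, mask=m)
print(kept_pred.shape, kept_gold.shape)   # torch.Size([6, 3]) torch.Size([6])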
@@ -1,243 +1,310 @@ | |||||
import warnings | |||||
import inspect | |||||
from collections import defaultdict | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
class Evaluator(object): | |||||
def __init__(self): | |||||
pass | |||||
def __call__(self, predict, truth): | |||||
""" | |||||
:param predict: list of tensors, the network outputs from all batches. | |||||
:param truth: list of dict, the ground truths from all batch_y. | |||||
:return: | |||||
""" | |||||
raise NotImplementedError | |||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.utils import CheckRes | |||||
from fastNLP.core.utils import _build_args | |||||
from fastNLP.core.utils import _check_arg_dict_list | |||||
from fastNLP.core.utils import get_func_signature | |||||
from fastNLP.core.utils import seq_lens_to_masks | |||||
class ClassifyEvaluator(Evaluator): | |||||
class MetricBase(object): | |||||
def __init__(self): | def __init__(self): | ||||
super(ClassifyEvaluator, self).__init__() | |||||
def __call__(self, predict, truth): | |||||
y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict] | |||||
y_prob = torch.cat(y_prob, dim=0) | |||||
y_pred = torch.argmax(y_prob, dim=-1) | |||||
y_true = torch.cat(truth, dim=0) | |||||
acc = float(torch.sum(y_pred == y_true)) / len(y_true) | |||||
return {"accuracy": acc} | |||||
self.param_map = {} # key is param in function, value is input param. | |||||
self._checked = False | |||||
def evaluate(self, *args, **kwargs): | |||||
raise NotImplementedError | |||||
class SeqLabelEvaluator(Evaluator): | |||||
def __init__(self): | |||||
super(SeqLabelEvaluator, self).__init__() | |||||
def _init_param_map(self, key_map=None, **kwargs): | |||||
"""Check the validity of key_map and other param map. Add these into self.param_map | |||||
def __call__(self, predict, truth): | |||||
:param key_map: dict | |||||
:param kwargs: | |||||
:return: None | |||||
""" | |||||
value_counter = defaultdict(set) | |||||
if key_map is not None: | |||||
if not isinstance(key_map, dict): | |||||
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) | |||||
for key, value in key_map.items(): | |||||
if value is None: | |||||
self.param_map[key] = key | |||||
continue | |||||
if not isinstance(key, str): | |||||
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") | |||||
if not isinstance(value, str): | |||||
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") | |||||
self.param_map[key] = value | |||||
value_counter[value].add(key) | |||||
for key, value in kwargs.items(): | |||||
if value is None: | |||||
self.param_map[key] = key | |||||
continue | |||||
if not isinstance(value, str): | |||||
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | |||||
self.param_map[key] = value | |||||
value_counter[value].add(key) | |||||
for value, key_set in value_counter.items(): | |||||
if len(key_set) > 1: | |||||
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | |||||
# check consistency between signature and param_map
func_spect = inspect.getfullargspec(self.evaluate) | |||||
func_args = [arg for arg in func_spect.args if arg != 'self'] | |||||
for func_param, input_param in self.param_map.items(): | |||||
if func_param not in func_args: | |||||
raise NameError( | |||||
f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " | |||||
f"initialization parameters, or change its signature.") | |||||
# evaluate should not have varargs. | |||||
if func_spect.varargs: | |||||
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " | |||||
f"positional argument.).") | |||||
def get_metric(self, reset=True): | |||||
raise NotImplementedError
def _fast_param_map(self, pred_dict, target_dict): | |||||
""" | """ | ||||
:param predict: list of List, the network outputs from all batches. | |||||
:param truth: list of dict, the ground truths from all batch_y. | |||||
:return accuracy: | |||||
Only used as an inner function. When the mapping between pred_dict and target_dict is unambiguous,
e.g. each of them has exactly one element, users do not need to pass a key_map.
:param pred_dict: | |||||
:param target_dict: | |||||
:return: dict. If it is not empty, pass it to self.evaluate. Otherwise fall back to the full mapping.
""" | |||||
fast_param = {} | |||||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||||
fast_param['pred'] = list(pred_dict.values())[0] | |||||
fast_param['target'] = list(target_dict.values())[0]
return fast_param | |||||
return fast_param | |||||
def __call__(self, pred_dict, target_dict): | |||||
""" | """ | ||||
truth = [item["truth"] for item in truth] | |||||
total_correct, total_count= 0., 0. | |||||
for x, y in zip(predict, truth): | |||||
x = torch.Tensor(x) | |||||
y = y.to(x) # make sure they are in the same device | |||||
mask = x.ge(1).float() | |||||
# correct = torch.sum(x * mask.float() == (y * mask.long()).float()) | |||||
correct = torch.sum(x * mask == y * mask) | |||||
correct -= torch.sum(x.le(0)) | |||||
total_correct += float(correct) | |||||
total_count += float(torch.sum(mask)) | |||||
accuracy = total_correct / total_count | |||||
return {"accuracy": float(accuracy)} | |||||
class SNLIEvaluator(Evaluator): | |||||
def __init__(self): | |||||
super(SNLIEvaluator, self).__init__() | |||||
def __call__(self, predict, truth): | |||||
y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict] | |||||
y_prob = torch.cat(y_prob, dim=0) | |||||
y_pred = torch.argmax(y_prob, dim=-1) | |||||
truth = [t['truth'] for t in truth] | |||||
y_true = torch.cat(truth, dim=0).view(-1) | |||||
acc = float(torch.sum(y_pred == y_true)) / y_true.size(0) | |||||
return {"accuracy": acc} | |||||
This method will call the self.evaluate method.
Before calling self.evaluate, it first checks the validity of pred_dict and target_dict:
(1) whether self.evaluate has varargs, which is not supported;
(2) whether any param needed by self.evaluate is missing from pred_dict and target_dict;
(3) whether any param needed by self.evaluate is duplicated in pred_dict and target_dict;
(4) whether any param in pred_dict or target_dict is unused by self.evaluate (which might cause a warning).
Besides, before passing params into self.evaluate, this function filters out params from pred_dict and
target_dict which are not used by self.evaluate (but if **kwargs is present in self.evaluate, no filtering
is conducted).
This function also supports _fast_param_map.
:param pred_dict: usually the output of the forward or prediction function
:param target_dict: usually the fields set as target.
:return: | |||||
""" | |||||
if not callable(self.evaluate): | |||||
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") | |||||
fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict) | |||||
if fast_param: | |||||
self.evaluate(**fast_param) | |||||
return | |||||
if not self._checked: | |||||
# 1. check consistency between signature and param_map
func_spect = inspect.getfullargspec(self.evaluate) | |||||
func_args = set([arg for arg in func_spect.args if arg != 'self']) | |||||
for func_arg, input_arg in self.param_map.items(): | |||||
if func_arg not in func_args: | |||||
raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.") | |||||
# 2. only part of the param_map is passed; fill in the rest
for arg in func_args: | |||||
if arg not in self.param_map: | |||||
self.param_map[arg] = arg # This param does not need mapping. | |||||
self._evaluate_args = func_args | |||||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} | |||||
# need to wrap inputs in dict. | |||||
mapped_pred_dict = {} | |||||
mapped_target_dict = {} | |||||
duplicated = [] | |||||
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): | |||||
not_duplicate_flag = 0 | |||||
if input_arg in self._reverse_param_map: | |||||
mapped_arg = self._reverse_param_map[input_arg] | |||||
not_duplicate_flag += 1 | |||||
else: | |||||
mapped_arg = input_arg | |||||
if input_arg in pred_dict: | |||||
mapped_pred_dict[mapped_arg] = pred_dict[input_arg] | |||||
not_duplicate_flag += 1 | |||||
if input_arg in target_dict: | |||||
mapped_target_dict[mapped_arg] = target_dict[input_arg] | |||||
not_duplicate_flag += 1 | |||||
if not_duplicate_flag == 3: | |||||
duplicated.append(input_arg) | |||||
# missing | |||||
if not self._checked: | |||||
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) | |||||
# only check missing. | |||||
# replace missing. | |||||
missing = check_res.missing | |||||
replaced_missing = list(missing) | |||||
for idx, func_arg in enumerate(missing): | |||||
# Do not delete or add backticks in this message.
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||||
f"in `{self.__class__.__name__}`)" | |||||
check_res = CheckRes(missing=replaced_missing, | |||||
unused=check_res.unused, | |||||
duplicated=duplicated, | |||||
required=check_res.required, | |||||
all_needed=check_res.all_needed, | |||||
varargs=check_res.varargs) | |||||
if check_res.missing or check_res.duplicated or check_res.varargs: | |||||
raise CheckError(check_res=check_res, | |||||
func_signature=get_func_signature(self.evaluate)) | |||||
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) | |||||
self.evaluate(**refined_args) | |||||
self._checked = True | |||||
return | |||||
class AccuracyMetric(MetricBase): | |||||
def __init__(self, pred=None, target=None, seq_lens=None): | |||||
super().__init__() | |||||
self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) | |||||
self.total = 0 | |||||
self.acc_count = 0 | |||||
def _fast_param_map(self, pred_dict, target_dict): | |||||
""" | |||||
Only used as an inner function. When the mapping between pred_dict and target_dict is unambiguous,
e.g. each of them has exactly one element, users do not need to pass a key_map.
:param pred_dict: | |||||
:param target_dict: | |||||
:return: dict. If it is not empty, pass it to self.evaluate. Otherwise fall back to the full mapping.
""" | |||||
fast_param = {} | |||||
targets = list(target_dict.values()) | |||||
if len(targets) == 1 and isinstance(targets[0], torch.Tensor): | |||||
if len(pred_dict) == 1: | |||||
pred = list(pred_dict.values())[0] | |||||
fast_param['pred'] = pred | |||||
elif len(pred_dict) == 2: | |||||
pred1 = list(pred_dict.values())[0] | |||||
pred2 = list(pred_dict.values())[1] | |||||
if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)): | |||||
return fast_param | |||||
if len(pred1.size()) < len(pred2.size()) and len(pred1.size()) == 1: | |||||
seq_lens = pred1 | |||||
pred = pred2 | |||||
elif len(pred1.size()) > len(pred2.size()) and len(pred2.size()) == 1: | |||||
seq_lens = pred2 | |||||
pred = pred1 | |||||
else: | |||||
return fast_param | |||||
fast_param['pred'] = pred | |||||
fast_param['seq_lens'] = seq_lens | |||||
else: | |||||
return fast_param | |||||
fast_param['target'] = targets[0] | |||||
# TODO need to make sure they all have same batch_size | |||||
return fast_param | |||||
def evaluate(self, pred, target, seq_lens=None): | |||||
""" | |||||
def _conver_numpy(x): | |||||
"""convert input data to numpy array | |||||
:param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: | |||||
torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes]) | |||||
:param target: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
:param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
None, None, torch.Size([B]), torch.Size([B]). Ignored if masks are provided.
:return: dict({'acc': float}) | |||||
""" | |||||
# TODO: this error message needs improving, because users do not know what `pred` refers to; report the actual field name/value instead.
if not isinstance(pred, torch.Tensor): | |||||
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(pred)}.") | |||||
if not isinstance(target, torch.Tensor): | |||||
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(target)}.") | |||||
if seq_lens is not None and not isinstance(seq_lens, torch.Tensor): | |||||
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(seq_lens)}.") | |||||
if seq_lens is not None: | |||||
masks = seq_lens_to_masks(seq_lens=seq_lens, float=True) | |||||
else: | |||||
masks = None | |||||
""" | |||||
if isinstance(x, np.ndarray): | |||||
return x | |||||
elif isinstance(x, torch.Tensor): | |||||
return x.numpy() | |||||
elif isinstance(x, list): | |||||
return np.array(x) | |||||
raise TypeError('cannot accept object: {}'.format(x)) | |||||
if pred.size() == target.size(): | |||||
pass | |||||
elif len(pred.size()) == len(target.size()) + 1: | |||||
pred = pred.argmax(dim=-1) | |||||
else: | |||||
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " | |||||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||||
f"{pred.size()[:-1]}, got {target.size()}.") | |||||
pred = pred.float() | |||||
target = target.float() | |||||
def _check_same_len(*arrays, axis=0): | |||||
"""check if input array list has same length for one dimension | |||||
if masks is not None: | |||||
self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item() | |||||
self.total += torch.sum(masks.float()).item() | |||||
else: | |||||
self.acc_count += torch.sum(torch.eq(pred, target).float()).item() | |||||
self.total += np.prod(list(pred.size())) | |||||
""" | |||||
lens = set([x.shape[axis] for x in arrays if x is not None]) | |||||
return len(lens) == 1 | |||||
def get_metric(self, reset=True): | |||||
evaluate_result = {'acc': round(self.acc_count / self.total, 6)} | |||||
if reset: | |||||
self.acc_count = 0 | |||||
self.total = 0 | |||||
return evaluate_result | |||||
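# --- Illustrative sketch of AccuracyMetric (hypothetical keys and toy shapes; the
# import path is assumed from this PR's file layout) ---
import torch
from fastNLP.core.metrics import AccuracyMetric

metric = AccuracyMetric(pred="output", target="label")
for _ in range(2):                                       # accumulate over two batches
    pred_dict = {"output": torch.randn(4, 3)}            # logits; argmax is taken inside evaluate()
    target_dict = {"label": torch.randint(0, 3, (4,))}
    metric(pred_dict, target_dict)
print(metric.get_metric())   # a dict like {'acc': 0.375}; counters are reset by default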
def _label_types(y): | |||||
"""Determine the type | |||||
- "binary" | |||||
- "multiclass" | |||||
- "multiclass-multioutput" | |||||
- "multilabel" | |||||
- "unknown" | |||||
def _prepare_metrics(metrics): | |||||
""" | """ | ||||
# never squeeze the first dimension | |||||
y = y.squeeze() if y.shape[0] > 1 else y.resize(1, -1) | |||||
shape = y.shape | |||||
if len(shape) < 1: | |||||
raise ValueError('cannot accept data: {}'.format(y)) | |||||
if len(shape) == 1: | |||||
return 'multiclass' if np.unique(y).shape[0] > 2 else 'binary', y | |||||
if len(shape) == 2: | |||||
return 'multiclass-multioutput' if np.unique(y).shape[0] > 2 else 'multilabel', y | |||||
return 'unknown', y | |||||
def _check_data(y_true, y_pred): | |||||
"""Check if y_true and y_pred is same type of data e.g both binary or multiclass | |||||
Prepare list of Metric based on input | |||||
:param metrics: | |||||
:return: List[fastNLP.MetricBase] | |||||
""" | """ | ||||
y_true, y_pred = _conver_numpy(y_true), _conver_numpy(y_pred) | |||||
if not _check_same_len(y_true, y_pred): | |||||
raise ValueError('cannot accept data with different shape {0}, {1}'.format(y_true, y_pred)) | |||||
type_true, y_true = _label_types(y_true) | |||||
type_pred, y_pred = _label_types(y_pred) | |||||
type_set = set(['binary', 'multiclass']) | |||||
if type_true in type_set and type_pred in type_set: | |||||
return type_true if type_true == type_pred else 'multiclass', y_true, y_pred | |||||
type_set = set(['multiclass-multioutput', 'multilabel']) | |||||
if type_true in type_set and type_pred in type_set: | |||||
return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred | |||||
raise ValueError('cannot accept data mixed of {0} and {1} target'.format(type_true, type_pred)) | |||||
def _weight_sum(y, normalize=True, sample_weight=None): | |||||
if normalize: | |||||
return np.average(y, weights=sample_weight) | |||||
if sample_weight is None: | |||||
return y.sum() | |||||
else: | |||||
return np.dot(y, sample_weight) | |||||
def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None): | |||||
y_type, y_true, y_pred = _check_data(y_true, y_pred) | |||||
if y_type == 'multiclass-multioutput': | |||||
raise ValueError('cannot accept data type {0}'.format(y_type)) | |||||
if y_type == 'multilabel': | |||||
equal = (y_true == y_pred).sum(1)
count = equal == y_true.shape[1]
else: | |||||
count = y_true == y_pred | |||||
return _weight_sum(count, normalize=normalize, sample_weight=sample_weight) | |||||
def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): | |||||
y_type, y_true, y_pred = _check_data(y_true, y_pred) | |||||
if average == 'binary': | |||||
if y_type != 'binary': | |||||
raise ValueError("data type is {} but use average type {}".format(y_type, average)) | |||||
else: | |||||
pos = (y_true == pos_label) | |||||
tp = np.logical_and((y_true == y_pred), pos).sum() | |||||
pos_sum = pos.sum() | |||||
return tp / pos_sum if pos_sum > 0 else 0 | |||||
elif average is None:
y_labels = set(list(np.unique(y_true))) | |||||
if labels is None: | |||||
labels = list(y_labels) | |||||
else: | |||||
for i in labels: | |||||
if (i not in y_labels and y_type != 'multilabel') or (y_type == 'multilabel' and i >= y_true.shape[1]): | |||||
warnings.warn('label {} is not contained in data'.format(i), UserWarning) | |||||
if y_type in ['binary', 'multiclass']: | |||||
y_pred_right = y_true == y_pred | |||||
pos_list = [y_true == i for i in labels] | |||||
pos_sum_list = [pos_i.sum() for pos_i in pos_list] | |||||
return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \ | |||||
for pos_i, sum_i in zip(pos_list, pos_sum_list)]) | |||||
elif y_type == 'multilabel': | |||||
y_pred_right = y_true == y_pred | |||||
pos = (y_true == pos_label) | |||||
tp = np.logical_and(y_pred_right, pos).sum(0) | |||||
pos_sum = pos.sum(0) | |||||
return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels]) | |||||
else: | |||||
raise ValueError('unsupported target type {}'.format(y_type))
raise ValueError('unsupported average type {}'.format(average))
def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): | |||||
y_type, y_true, y_pred = _check_data(y_true, y_pred) | |||||
if average == 'binary': | |||||
if y_type != 'binary': | |||||
raise ValueError("data type is {} but use average type {}".format(y_type, average)) | |||||
else: | |||||
pos = (y_true == pos_label) | |||||
tp = np.logical_and((y_true == y_pred), pos).sum() | |||||
pos_pred = (y_pred == pos_label).sum() | |||||
return tp / pos_pred if pos_pred > 0 else 0 | |||||
elif average is None:
y_labels = set(list(np.unique(y_true))) | |||||
if labels is None: | |||||
labels = list(y_labels) | |||||
_metrics = [] | |||||
if metrics: | |||||
if isinstance(metrics, list): | |||||
for metric in metrics: | |||||
if isinstance(metric, type): | |||||
metric = metric() | |||||
if isinstance(metric, MetricBase): | |||||
metric_name = metric.__class__.__name__ | |||||
if not callable(metric.evaluate): | |||||
raise TypeError(f"{metric_name}.evaluate must be callable, got {type(metric.evaluate)}.") | |||||
if not callable(metric.get_metric): | |||||
raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.") | |||||
_metrics.append(metric) | |||||
else: | |||||
raise TypeError( | |||||
f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.") | |||||
elif isinstance(metrics, MetricBase): | |||||
_metrics = [metrics] | |||||
else: | else: | ||||
for i in labels: | |||||
if (i not in y_labels and y_type != 'multilabel') or (y_type == 'multilabel' and i >= y_true.shape[1]): | |||||
warnings.warn('label {} is not contained in data'.format(i), UserWarning) | |||||
if y_type in ['binary', 'multiclass']: | |||||
y_pred_right = y_true == y_pred | |||||
pos_list = [y_true == i for i in labels] | |||||
pos_sum_list = [(y_pred == i).sum() for i in labels] | |||||
return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \ | |||||
for pos_i, sum_i in zip(pos_list, pos_sum_list)]) | |||||
elif y_type == 'multilabel': | |||||
y_pred_right = y_true == y_pred | |||||
pos = (y_true == pos_label) | |||||
tp = np.logical_and(y_pred_right, pos).sum(0) | |||||
pos_sum = (y_pred == pos_label).sum(0) | |||||
return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels]) | |||||
else: | |||||
raise ValueError('not support targets type {}'.format(y_type)) | |||||
raise ValueError('not support for average type {}'.format(average)) | |||||
def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): | |||||
precision = precision_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average) | |||||
recall = recall_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average) | |||||
if isinstance(precision, np.ndarray): | |||||
res = 2 * precision * recall / (precision + recall) | |||||
res[(precision + recall) <= 0] = 0 | |||||
return res | |||||
return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 | |||||
def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2): | |||||
raise NotImplementedError | |||||
raise TypeError(f"The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, " | |||||
f"got {type(metrics)}.") | |||||
return _metrics | |||||
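Similarly, a minimal sketch (illustration only, not part of the diff) of the binary-average path of the precision/recall/F1 helpers above, under the same import-path assumption:

import numpy as np
from fastNLP.core.metrics import precision_score, recall_score, f1_score  # assumed import path

y_true = np.array([1, 0, 1, 1, 0, 1])
y_pred = np.array([1, 0, 0, 1, 0, 1])

p = precision_score(y_true, y_pred, average='binary')  # true positives / predicted positives
r = recall_score(y_true, y_pred, average='binary')     # true positives / actual positives
f = f1_score(y_true, y_pred, average='binary')         # harmonic mean of p and r
print(p, r, f)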
def accuracy_topk(y_true, y_prob, k=1): | def accuracy_topk(y_true, y_prob, k=1): | ||||
@@ -275,8 +342,3 @@ def pred_topk(y_prob, k=1): | |||||
(1, k)) | (1, k)) | ||||
y_prob_topk = y_prob[x_axis_index, y_pred_topk] | y_prob_topk = y_prob[x_axis_index, y_pred_topk] | ||||
return y_pred_topk, y_prob_topk | return y_pred_topk, y_prob_topk | ||||
if __name__ == '__main__': | |||||
y = np.array([1, 0, 1, 0, 1, 1]) | |||||
print(_label_types(y)) |
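For the top-k helpers in the same file, a hedged sketch of pred_topk (its signature is taken from the hunk header above; the import path is again an assumption):

import numpy as np
from fastNLP.core.metrics import pred_topk  # assumed import path

y_prob = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2]])
top_idx, top_prob = pred_topk(y_prob, k=2)  # per row: indices and probabilities of the 2 most likely classes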
@@ -2,61 +2,48 @@ import torch | |||||
class Optimizer(object): | class Optimizer(object): | ||||
"""Wrapper of optimizer from framework | |||||
def __init__(self, model_params, **kwargs): | |||||
if model_params is not None and not hasattr(model_params, "__next__"): | |||||
raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params))) | |||||
self.model_params = model_params | |||||
self.settings = kwargs | |||||
1. Adam: lr (float), weight_decay (float) | |||||
2. AdaGrad | |||||
3. RMSProp | |||||
4. SGD: lr (float), momentum (float) | |||||
""" | |||||
def __init__(self, optimizer_name, **kwargs): | |||||
class SGD(Optimizer): | |||||
def __init__(self, lr=0.01, momentum=0, model_params=None): | |||||
""" | """ | ||||
:param optimizer_name: str, the name of the optimizer | |||||
:param kwargs: the arguments | |||||
:param float lr: learning rate. Default: 0.01 | |||||
:param float momentum: momentum. Default: 0 | |||||
:param model_params: a generator. E.g. model.parameters() for PyTorch models. | |||||
""" | """ | ||||
self.optim_name = optimizer_name | |||||
self.kwargs = kwargs | |||||
if not isinstance(lr, float): | |||||
raise TypeError("learning rate has to be float.") | |||||
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | |||||
@property | |||||
def name(self): | |||||
"""The name of the optimizer. | |||||
def construct_from_pytorch(self, model_params): | |||||
if self.model_params is None: | |||||
# careful! generator cannot be assigned. | |||||
return torch.optim.SGD(model_params, **self.settings) | |||||
else: | |||||
return torch.optim.SGD(self.model_params, **self.settings) | |||||
:return: str | |||||
""" | |||||
return self.optim_name | |||||
@property | |||||
def params(self): | |||||
"""The arguments used to create the optimizer. | |||||
class Adam(Optimizer): | |||||
def __init__(self, lr=0.01, weight_decay=0, model_params=None): | |||||
""" | |||||
:return: dict of (str, *) | |||||
:param float lr: learning rate | |||||
:param float weight_decay: | |||||
:param model_params: a generator. E.g. model.parameters() for PyTorch models. | |||||
""" | """ | ||||
return self.kwargs | |||||
if not isinstance(lr, float): | |||||
raise TypeError("learning rate has to be float.") | |||||
super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) | |||||
def construct_from_pytorch(self, model_params): | def construct_from_pytorch(self, model_params): | ||||
"""Construct a optimizer from framework over given model parameters.""" | |||||
if self.optim_name in ["SGD", "sgd"]: | |||||
if "lr" in self.kwargs: | |||||
if "momentum" not in self.kwargs: | |||||
self.kwargs["momentum"] = 0 | |||||
optimizer = torch.optim.SGD(model_params, lr=self.kwargs["lr"], momentum=self.kwargs["momentum"]) | |||||
else: | |||||
raise ValueError("requires learning rate for SGD optimizer") | |||||
elif self.optim_name in ["adam", "Adam"]: | |||||
if "lr" in self.kwargs: | |||||
if "weight_decay" not in self.kwargs: | |||||
self.kwargs["weight_decay"] = 0 | |||||
optimizer = torch.optim.Adam(model_params, lr=self.kwargs["lr"], | |||||
weight_decay=self.kwargs["weight_decay"]) | |||||
else: | |||||
raise ValueError("requires learning rate for Adam optimizer") | |||||
if self.model_params is None: | |||||
# careful! generator cannot be assigned. | |||||
return torch.optim.Adam(model_params, **self.settings) | |||||
else: | else: | ||||
raise NotImplementedError | |||||
return optimizer | |||||
return torch.optim.Adam(self.model_params, **self.settings) |
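The new optimizer wrappers can either receive model parameters up front or defer them until construct_from_pytorch is called (typically by the Trainer). A minimal sketch, not part of the diff; nn.Linear is only a stand-in model and the import path is assumed:

import torch.nn as nn
from fastNLP.core.optimizer import Adam, SGD  # assumed import path

model = nn.Linear(10, 2)  # stand-in model

opt = Adam(lr=0.001, weight_decay=0)                        # settings only, parameters deferred
torch_opt = opt.construct_from_pytorch(model.parameters())  # -> torch.optim.Adam

sgd = SGD(lr=0.01, momentum=0.9, model_params=model.parameters())
torch_sgd = sgd.construct_from_pytorch(model.parameters())  # the stored generator takes precedence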
@@ -1,10 +1,7 @@ | |||||
import numpy as np | |||||
import torch | import torch | ||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.preprocess import load_pickle | |||||
from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq2tag_dataset, convert_seq_dataset | |||||
class Predictor(object): | class Predictor(object): | ||||
@@ -16,42 +13,29 @@ class Predictor(object): | |||||
Currently, Predictor does not support GPU. | Currently, Predictor does not support GPU. | ||||
""" | """ | ||||
def __init__(self, pickle_path, post_processor): | |||||
""" | |||||
:param pickle_path: str, the path to the pickle files. | |||||
:param post_processor: a function or callable object, that takes list of batch outputs as input | |||||
""" | |||||
def __init__(self): | |||||
self.batch_size = 1 | self.batch_size = 1 | ||||
self.batch_output = [] | self.batch_output = [] | ||||
self.pickle_path = pickle_path | |||||
self._post_processor = post_processor | |||||
self.label_vocab = load_pickle(self.pickle_path, "label2id.pkl") | |||||
self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl") | |||||
def predict(self, network, data): | def predict(self, network, data): | ||||
"""Perform inference using the trained model. | """Perform inference using the trained model. | ||||
:param network: a PyTorch model (cpu) | :param network: a PyTorch model (cpu) | ||||
:param data: a DataSet object. | :param data: a DataSet object. | ||||
:return: list of list of strings, [num_examples, tag_seq_length] | |||||
:return: list of batch outputs | |||||
""" | """ | ||||
# transform strings into DataSet object | |||||
# data = self.prepare_input(data) | |||||
# turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
self.mode(network, test=True) | self.mode(network, test=True) | ||||
batch_output = [] | batch_output = [] | ||||
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False) | |||||
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||||
for batch_x, _ in data_iterator: | for batch_x, _ in data_iterator: | ||||
with torch.no_grad(): | with torch.no_grad(): | ||||
prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
batch_output.append(prediction) | batch_output.append(prediction) | ||||
return self._post_processor(batch_output, self.label_vocab) | |||||
return batch_output | |||||
def mode(self, network, test=True): | def mode(self, network, test=True): | ||||
if test: | if test: | ||||
@@ -63,51 +47,3 @@ class Predictor(object): | |||||
"""Forward through network.""" | """Forward through network.""" | ||||
y = network(**x) | y = network(**x) | ||||
return y | return y | ||||
def prepare_input(self, data): | |||||
"""Transform two-level list of strings into an DataSet object. | |||||
In the training pipeline, this is done by Preprocessor. But in inference time, we do not call Preprocessor. | |||||
:param data: list of list of strings. | |||||
:: | |||||
[ | |||||
[word_11, word_12, ...], | |||||
[word_21, word_22, ...], | |||||
... | |||||
] | |||||
:return data_set: a DataSet instance. | |||||
""" | |||||
assert isinstance(data, list) | |||||
data = convert_seq_dataset(data) | |||||
data.index_field("word_seq", self.word_vocab) | |||||
class SeqLabelInfer(Predictor): | |||||
def __init__(self, pickle_path): | |||||
print( | |||||
"[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor directly.") | |||||
super(SeqLabelInfer, self).__init__(pickle_path, seq_label_post_processor) | |||||
class ClassificationInfer(Predictor): | |||||
def __init__(self, pickle_path): | |||||
print( | |||||
"[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor directly.") | |||||
super(ClassificationInfer, self).__init__(pickle_path, text_classify_post_processor) | |||||
def seq_label_post_processor(batch_outputs, label_vocab): | |||||
results = [] | |||||
for batch in batch_outputs: | |||||
for example in np.array(batch): | |||||
results.append([label_vocab.to_word(int(x)) for x in example]) | |||||
return results | |||||
def text_classify_post_processor(batch_outputs, label_vocab): | |||||
results = [] | |||||
for batch_out in batch_outputs: | |||||
idx = np.argmax(batch_out.detach().numpy(), axis=-1) | |||||
results.extend([label_vocab.to_word(i) for i in idx]) | |||||
return results |
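With the vocabulary and post-processing plumbing removed, the stripped-down Predictor is called directly on a DataSet. A hedged sketch, assuming `model` is a trained fastNLP model and `dataset` is a DataSet whose input fields match model.forward; the import path is an assumption:

from fastNLP.core.predictor import Predictor  # assumed import path

predictor = Predictor()
batch_outputs = predictor.predict(model, dataset)  # list with one forward-pass output per batch (batch_size=1)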
@@ -1,48 +0,0 @@ | |||||
import _pickle | |||||
import os | |||||
# the first vocab in dict with the index = 5 | |||||
def save_pickle(obj, pickle_path, file_name): | |||||
"""Save an object into a pickle file. | |||||
:param obj: an object | |||||
:param pickle_path: str, the directory where the pickle file is to be saved | |||||
:param file_name: str, the name of the pickle file. In general, it should be ended by "pkl". | |||||
""" | |||||
if not os.path.exists(pickle_path): | |||||
os.mkdir(pickle_path) | |||||
print("make dir {} before saving pickle file".format(pickle_path)) | |||||
with open(os.path.join(pickle_path, file_name), "wb") as f: | |||||
_pickle.dump(obj, f) | |||||
print("{} saved in {}".format(file_name, pickle_path)) | |||||
def load_pickle(pickle_path, file_name): | |||||
"""Load an object from a given pickle file. | |||||
:param pickle_path: str, the directory where the pickle file is. | |||||
:param file_name: str, the name of the pickle file. | |||||
:return obj: an object stored in the pickle | |||||
""" | |||||
with open(os.path.join(pickle_path, file_name), "rb") as f: | |||||
obj = _pickle.load(f) | |||||
print("{} loaded from {}".format(file_name, pickle_path)) | |||||
return obj | |||||
def pickle_exist(pickle_path, pickle_name): | |||||
"""Check if a given pickle file exists in the directory. | |||||
:param pickle_path: the directory of target pickle file | |||||
:param pickle_name: the filename of target pickle file | |||||
:return: True if file exists else False | |||||
""" | |||||
if not os.path.exists(pickle_path): | |||||
os.makedirs(pickle_path) | |||||
file_name = os.path.join(pickle_path, pickle_name) | |||||
if os.path.exists(file_name): | |||||
return True | |||||
else: | |||||
return False |
@@ -1,3 +1,5 @@ | |||||
from itertools import chain | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
@@ -44,12 +46,52 @@ class RandomSampler(BaseSampler): | |||||
return list(np.random.permutation(len(data_set))) | return list(np.random.permutation(len(data_set))) | ||||
class BucketSampler(BaseSampler): | |||||
def __init__(self, num_buckets=10, batch_size=32, seq_lens_field_name='seq_lens'): | |||||
self.num_buckets = num_buckets | |||||
self.batch_size = batch_size | |||||
self.seq_lens_field_name = seq_lens_field_name | |||||
def __call__(self, data_set): | |||||
seq_lens = data_set.get_all_fields()[self.seq_lens_field_name].content | |||||
total_sample_num = len(seq_lens) | |||||
bucket_indexes = [] | |||||
num_sample_per_bucket = total_sample_num // self.num_buckets | |||||
for i in range(self.num_buckets): | |||||
bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)]) | |||||
bucket_indexes[-1][1] = total_sample_num | |||||
sorted_seq_lens = list(sorted([(idx, seq_len) for | |||||
idx, seq_len in zip(range(total_sample_num), seq_lens)], | |||||
key=lambda x: x[1])) | |||||
batches = []
left_init_indexes = [] | |||||
for b_idx in range(self.num_buckets): | |||||
start_idx = bucket_indexes[b_idx][0] | |||||
end_idx = bucket_indexes[b_idx][1] | |||||
sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx] | |||||
left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens]) | |||||
num_batch_per_bucket = len(left_init_indexes) // self.batch_size | |||||
np.random.shuffle(left_init_indexes) | |||||
for i in range(num_batch_per_bucket): | |||||
batches.append(left_init_indexes[i * self.batch_size:(i + 1) * self.batch_size])
left_init_indexes = left_init_indexes[num_batch_per_bucket * self.batch_size:] | |||||
if len(left_init_indexes) != 0:
batches.append(left_init_indexes)
np.random.shuffle(batches)
return list(chain(*batches))
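The rewritten BucketSampler groups indices into length-similar chunks of batch_size and can be passed wherever a sampler is accepted. A hedged sketch, assuming `dataset` is a DataSet that carries a 'seq_lens' field and that the import path matches:

from fastNLP.core.sampler import BucketSampler  # assumed import path

sampler = BucketSampler(num_buckets=10, batch_size=32, seq_lens_field_name='seq_lens')
indices = sampler(dataset)  # flat list of indices; consecutive chunks of 32 have similar sequence lengths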
def simple_sort_bucketing(lengths): | def simple_sort_bucketing(lengths): | ||||
""" | """ | ||||
:param lengths: list of int, the lengths of all examples. | :param lengths: list of int, the lengths of all examples. | ||||
:param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length | |||||
threshold for each bucket (This is usually None.). | |||||
:return data: 2-level list | :return data: 2-level list | ||||
:: | :: | ||||
@@ -65,6 +107,7 @@ def simple_sort_bucketing(lengths): | |||||
# TODO: need to return buckets | # TODO: need to return buckets | ||||
return [idx for idx, _ in sorted_lengths] | return [idx for idx, _ in sorted_lengths] | ||||
def k_means_1d(x, k, max_iter=100): | def k_means_1d(x, k, max_iter=100): | ||||
"""Perform k-means on 1-D data. | """Perform k-means on 1-D data. | ||||
@@ -75,6 +118,7 @@ def k_means_1d(x, k, max_iter=100): | |||||
assignment: numpy array, 1-D, the bucket id assigned to each example. | assignment: numpy array, 1-D, the bucket id assigned to each example. | ||||
""" | """ | ||||
sorted_x = sorted(list(set(x))) | sorted_x = sorted(list(set(x))) | ||||
x = np.array(x) | |||||
if len(sorted_x) < k: | if len(sorted_x) < k: | ||||
raise ValueError("too few buckets") | raise ValueError("too few buckets") | ||||
gap = len(sorted_x) / k | gap = len(sorted_x) / k | ||||
@@ -118,35 +162,3 @@ def k_means_bucketing(lengths, buckets): | |||||
if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]: | if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]: | ||||
bucket_data[bucket_id].append(idx) | bucket_data[bucket_id].append(idx) | ||||
return bucket_data | return bucket_data | ||||
class BucketSampler(BaseSampler): | |||||
"""Partition all samples into multiple buckets, each of which contains sentences of approximately the same length. | |||||
In sampling, first randomly choose a bucket, then sample data from it.
The number of buckets is decided dynamically by the variance of sentence lengths. | |||||
""" | |||||
def __call__(self, data_set, batch_size, num_buckets): | |||||
return self._process(data_set, batch_size, num_buckets) | |||||
def _process(self, data_set, batch_size, num_buckets, use_kmeans=False): | |||||
""" | |||||
:param data_set: a DataSet object | |||||
:param batch_size: int | |||||
:param num_buckets: int, number of buckets for grouping these sequences. | |||||
:param use_kmeans: bool, whether to use k-means to create buckets. | |||||
""" | |||||
buckets = ([None] * num_buckets) | |||||
if use_kmeans is True: | |||||
buckets = k_means_bucketing(data_set, buckets) | |||||
else: | |||||
buckets = simple_sort_bucketing(data_set) | |||||
index_list = [] | |||||
for _ in range(len(data_set) // batch_size): | |||||
chosen_bucket = buckets[np.random.randint(0, len(buckets))] | |||||
np.random.shuffle(chosen_bucket) | |||||
index_list += [idx for idx in chosen_bucket[:batch_size]] | |||||
return index_list |
@@ -1,91 +1,88 @@ | |||||
from collections import defaultdict | |||||
import torch | import torch | ||||
from torch import nn | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.metrics import Evaluator | |||||
from fastNLP.core.sampler import RandomSampler | |||||
from fastNLP.saver.logger import create_logger | |||||
logger = create_logger(__name__, "./train_test.log") | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.metrics import _prepare_metrics | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.utils import _build_args | |||||
from fastNLP.core.utils import _check_loss_evaluate | |||||
from fastNLP.core.utils import _move_dict_value_to_device | |||||
from fastNLP.core.utils import get_func_signature | |||||
class Tester(object): | class Tester(object): | ||||
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | ||||
def __init__(self, **kwargs): | |||||
""" | |||||
:param kwargs: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | |||||
""" | |||||
def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=1): | |||||
super(Tester, self).__init__() | super(Tester, self).__init__() | ||||
""" | |||||
"default_args" provides default value for important settings. | |||||
The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
"kwargs" must have the same type as "default_args" on corresponding keys. | |||||
Otherwise, error will raise. | |||||
""" | |||||
default_args = {"batch_size": 8, | |||||
"use_cuda": False, | |||||
"pickle_path": "./save/", | |||||
"model_name": "dev_best_model.pkl", | |||||
"evaluator": Evaluator() | |||||
} | |||||
""" | |||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
This is used to warn users of essential settings in the training. | |||||
Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | |||||
""" | |||||
required_args = {} | |||||
for req_key in required_args: | |||||
if req_key not in kwargs: | |||||
logger.error("Tester lacks argument {}".format(req_key)) | |||||
raise ValueError("Tester lacks argument {}".format(req_key)) | |||||
for key in default_args: | |||||
if key in kwargs: | |||||
if isinstance(kwargs[key], type(default_args[key])): | |||||
default_args[key] = kwargs[key] | |||||
else: | |||||
msg = "Argument %s type mismatch: expected %s while get %s" % ( | |||||
key, type(default_args[key]), type(kwargs[key])) | |||||
logger.error(msg) | |||||
raise ValueError(msg) | |||||
else: | |||||
# Tester doesn't care about extra arguments | |||||
pass | |||||
print(default_args) | |||||
self.batch_size = default_args["batch_size"] | |||||
self.pickle_path = default_args["pickle_path"] | |||||
self.use_cuda = default_args["use_cuda"] | |||||
self._evaluator = default_args["evaluator"] | |||||
self._model = None | |||||
self.eval_history = [] # evaluation results of all batches | |||||
def test(self, network, dev_data): | |||||
if not isinstance(data, DataSet): | |||||
raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.") | |||||
if not isinstance(model, nn.Module): | |||||
raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.") | |||||
self.metrics = _prepare_metrics(metrics) | |||||
self.data = data | |||||
self.use_cuda = use_cuda | |||||
self.batch_size = batch_size | |||||
self.verbose = verbose | |||||
self._model_device = model.parameters().__next__().device | |||||
if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
self._model = network.cuda() | |||||
self._model = model.cuda() | |||||
else: | else: | ||||
self._model = network | |||||
self._model = model | |||||
# check predict | |||||
if hasattr(self._model, 'predict'): | |||||
self._predict_func = self._model.predict | |||||
if not callable(self._predict_func): | |||||
_model_name = model.__class__.__name__ | |||||
raise TypeError(f"`{_model_name}.predict` must be callable to be used " | |||||
f"for evaluation, not `{type(self._predict_func)}`.") | |||||
else: | |||||
self._predict_func = self._model.forward | |||||
def test(self): | |||||
# turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
self.mode(network, is_test=True) | |||||
self.eval_history.clear() | |||||
output_list = [] | |||||
truth_list = [] | |||||
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) | |||||
for batch_x, batch_y in data_iterator: | |||||
network = self._model | |||||
self._mode(network, is_test=True) | |||||
data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||||
eval_results = {} | |||||
try: | |||||
with torch.no_grad(): | with torch.no_grad(): | ||||
prediction = self.data_forward(network, batch_x) | |||||
output_list.append(prediction) | |||||
truth_list.append(batch_y) | |||||
eval_results = self.evaluate(output_list, truth_list) | |||||
print("[tester] {}".format(self.print_eval_results(eval_results))) | |||||
logger.info("[tester] {}".format(self.print_eval_results(eval_results))) | |||||
def mode(self, model, is_test=False): | |||||
for batch_x, batch_y in data_iterator: | |||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||||
pred_dict = self._data_forward(self._predict_func, batch_x) | |||||
if not isinstance(pred_dict, dict): | |||||
raise TypeError(f"The return value of {get_func_signature(self._predict_func)} " | |||||
f"must be `dict`, got {type(pred_dict)}.") | |||||
for metric in self.metrics: | |||||
metric(pred_dict, batch_y) | |||||
for metric in self.metrics: | |||||
eval_result = metric.get_metric() | |||||
if not isinstance(eval_result, dict): | |||||
raise TypeError(f"The return value of {get_func_signature(metric.get_metric)} must be " | |||||
f"`dict`, got {type(eval_result)}") | |||||
metric_name = metric.__class__.__name__ | |||||
eval_results[metric_name] = eval_result | |||||
except CheckError as e: | |||||
prev_func_signature = get_func_signature(self._predict_func) | |||||
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, | |||||
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, | |||||
dataset=self.data, check_level=0) | |||||
if self.verbose >= 1: | |||||
print("[tester] \n{}".format(self._format_eval_results(eval_results))) | |||||
self._mode(network, is_test=False) | |||||
return eval_results | |||||
def _mode(self, model, is_test=False): | |||||
"""Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
:param model: a PyTorch model | :param model: a PyTorch model | ||||
@@ -97,45 +94,21 @@ class Tester(object): | |||||
else: | else: | ||||
model.train() | model.train() | ||||
def data_forward(self, network, x): | |||||
def _data_forward(self, func, x): | |||||
"""A forward pass of the model. """ | """A forward pass of the model. """ | ||||
y = network(**x) | |||||
x = _build_args(func, **x) | |||||
y = func(**x) | |||||
return y | return y | ||||
def evaluate(self, predict, truth): | |||||
"""Compute evaluation metrics. | |||||
:param predict: list of Tensor | |||||
:param truth: list of dict | |||||
:return eval_results: can be anything. It will be stored in self.eval_history | |||||
""" | |||||
return self._evaluator(predict, truth) | |||||
def print_eval_results(self, results): | |||||
def _format_eval_results(self, results): | |||||
"""Override this method to support more print formats. | """Override this method to support more print formats. | ||||
:param results: dict, (str: float) is (metrics name: value) | :param results: dict, (str: float) is (metrics name: value) | ||||
""" | """ | ||||
return ", ".join([str(key) + "=" + str(value) for key, value in results.items()]) | |||||
class SeqLabelTester(Tester): | |||||
def __init__(self, **test_args): | |||||
print( | |||||
"[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester directly.") | |||||
super(SeqLabelTester, self).__init__(**test_args) | |||||
class ClassificationTester(Tester): | |||||
def __init__(self, **test_args): | |||||
print( | |||||
"[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.") | |||||
super(ClassificationTester, self).__init__(**test_args) | |||||
class SNLITester(Tester): | |||||
def __init__(self, **test_args): | |||||
print( | |||||
"[FastNLP Warning] SNLITester will be deprecated. Please use Tester directly.") | |||||
super(SNLITester, self).__init__(**test_args) | |||||
_str = '' | |||||
for metric_name, metric_result in results.items(): | |||||
_str += metric_name + ': ' | |||||
_str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()]) | |||||
_str += '\n' | |||||
return _str[:-1] |
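The refactored Tester is constructed once with data, model and metrics, and test() returns a nested dict of results. A hedged sketch; AccuracyMetric is assumed to be a MetricBase subclass available in fastNLP.core.metrics, and `dev_data` a DataSet with its target fields marked:

from fastNLP.core.metrics import AccuracyMetric  # assumed metric class
from fastNLP.core.tester import Tester           # assumed import path

tester = Tester(data=dev_data, model=model, metrics=AccuracyMetric(), batch_size=16, use_cuda=False)
eval_results = tester.test()                     # e.g. {'AccuracyMetric': {'acc': 0.9}}
print(tester._format_eval_results(eval_results))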
@@ -1,178 +1,275 @@ | |||||
import os | import os | ||||
import time | import time | ||||
from datetime import datetime | |||||
from datetime import timedelta | from datetime import timedelta | ||||
import torch | import torch | ||||
from tensorboardX import SummaryWriter | from tensorboardX import SummaryWriter | ||||
from torch import nn | |||||
from tqdm.autonotebook import tqdm | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.loss import Loss | |||||
from fastNLP.core.metrics import Evaluator | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.losses import _prepare_losser | |||||
from fastNLP.core.metrics import _prepare_metrics | |||||
from fastNLP.core.optimizer import Adam | |||||
from fastNLP.core.sampler import BaseSampler | |||||
from fastNLP.core.sampler import RandomSampler | from fastNLP.core.sampler import RandomSampler | ||||
from fastNLP.core.tester import SeqLabelTester, ClassificationTester, SNLITester | |||||
from fastNLP.saver.logger import create_logger | |||||
from fastNLP.saver.model_saver import ModelSaver | |||||
logger = create_logger(__name__, "./train_test.log") | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.tester import Tester | |||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.utils import _build_args | |||||
from fastNLP.core.utils import _check_forward_error | |||||
from fastNLP.core.utils import _check_loss_evaluate | |||||
from fastNLP.core.utils import _move_dict_value_to_device | |||||
from fastNLP.core.utils import get_func_signature | |||||
class Trainer(object): | class Trainer(object): | ||||
"""Operations of training a model, including data loading, gradient descent, and validation. | |||||
"""Main Training Loop | |||||
""" | """ | ||||
def __init__(self, **kwargs): | |||||
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | |||||
validate_every=-1, dev_data=None, use_cuda=False, save_path=None, | |||||
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, | |||||
metric_key=None, sampler=RandomSampler(), use_tqdm=True): | |||||
""" | """ | ||||
:param kwargs: dict of (key, value), or dict-like object. key is str. | |||||
The base trainer requires the following keys: | |||||
- epochs: int, the number of epochs in training | |||||
- validate: bool, whether or not to validate on dev set | |||||
- batch_size: int | |||||
- pickle_path: str, the path to pickle files for pre-processing | |||||
:param DataSet train_data: the training data | |||||
:param torch.nn.modules.module model: a PyTorch model | |||||
:param LossBase loss: a loss object | |||||
:param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics | |||||
:param int n_epochs: the number of training epochs | |||||
:param int batch_size: batch size for training and validation | |||||
:param int print_every: step interval to print the next training information. Default: -1 (no print).
:param int validate_every: step interval to do the next validation. Default: -1 (validate every epoch).
:param DataSet dev_data: the validation data | |||||
:param bool use_cuda: whether to run the model on GPU when CUDA is available.
:param save_path: file path to save models | |||||
:param Optimizer optimizer: an optimizer object | |||||
:param int check_code_level: level of the FastNLP code checker. -1: don't check; 0: ignore; 1: warning; 2: strict.
`ignore` will not check unused fields; `warning` will warn if some fields are not used; `strict` will
raise an error if some fields are not used.
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one | |||||
of the keys returned by the FIRST metric in `metrics`. If the overall result gets better when the indicator gets
smaller, add a `-` character in front of the string. For example | |||||
:: | |||||
metric_key="-PPL" # language model gets better as perplexity gets smaller | |||||
:param sampler: method used to generate batch data. | |||||
:param use_tqdm: boolean, use tqdm to show train progress. | |||||
""" | """ | ||||
super(Trainer, self).__init__() | super(Trainer, self).__init__() | ||||
if not isinstance(train_data, DataSet): | |||||
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.") | |||||
if not isinstance(model, nn.Module): | |||||
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") | |||||
# check metrics and dev_data | |||||
if (not metrics) and dev_data is not None: | |||||
raise ValueError("No metric for dev_data evaluation.") | |||||
if metrics and (dev_data is None): | |||||
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") | |||||
# check save_path | |||||
if not (save_path is None or isinstance(save_path, str)): | |||||
raise ValueError("save_path can only be None or `str`.") | |||||
# prepare evaluate | |||||
metrics = _prepare_metrics(metrics) | |||||
# parse metric_key | |||||
# increase_better is True. It means the exp result gets better if the indicator increases. | |||||
# It is true by default. | |||||
self.increase_better = True | |||||
if metric_key is not None: | |||||
self.increase_better = False if metric_key[0] == "-" else True | |||||
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | |||||
elif len(metrics) > 0: | |||||
self.metric_key = metrics[0].__class__.__name__.lower().strip('metric') | |||||
# prepare loss | |||||
losser = _prepare_losser(loss) | |||||
# sampler check | |||||
if not isinstance(sampler, BaseSampler): | |||||
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) | |||||
if check_code_level > -1: | |||||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | |||||
metric_key=metric_key, check_level=check_code_level) | |||||
self.train_data = train_data | |||||
self.dev_data = dev_data # If None, No validation. | |||||
self.model = model | |||||
self.losser = losser | |||||
self.metrics = metrics | |||||
self.n_epochs = int(n_epochs) | |||||
self.batch_size = int(batch_size) | |||||
self.use_cuda = bool(use_cuda) | |||||
self.save_path = save_path | |||||
self.print_every = int(print_every) | |||||
self.validate_every = int(validate_every) | |||||
self.best_metric_indicator = None | |||||
self.sampler = sampler | |||||
self._model_device = model.parameters().__next__().device | |||||
if isinstance(optimizer, torch.optim.Optimizer): | |||||
self.optimizer = optimizer | |||||
else: | |||||
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) | |||||
self.use_tqdm = use_tqdm | |||||
if self.use_tqdm: | |||||
tester_verbose = 0 | |||||
else: | |||||
tester_verbose = 1 | |||||
if self.dev_data is not None: | |||||
self.tester = Tester(model=self.model, | |||||
data=self.dev_data, | |||||
metrics=self.metrics, | |||||
batch_size=self.batch_size, | |||||
use_cuda=self.use_cuda, | |||||
verbose=tester_verbose) | |||||
self.step = 0 | |||||
self.start_time = None # start timestamp | |||||
def train(self): | |||||
"""Start Training. | |||||
""" | """ | ||||
"default_args" provides default value for important settings. | |||||
The initialization arguments "kwargs" with the same key (name) will override the default value. | |||||
"kwargs" must have the same type as "default_args" on corresponding keys. | |||||
Otherwise, error will raise. | |||||
""" | |||||
default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | |||||
"save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, | |||||
"loss": Loss(None), # used to pass type check | |||||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | |||||
"evaluator": Evaluator() | |||||
} | |||||
""" | |||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||||
This is used to warn users of essential settings in the training. | |||||
Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | |||||
""" | |||||
required_args = {} | |||||
try: | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
self.model = self.model.cuda() | |||||
for req_key in required_args: | |||||
if req_key not in kwargs: | |||||
logger.error("Trainer lacks argument {}".format(req_key)) | |||||
raise ValueError("Trainer lacks argument {}".format(req_key)) | |||||
self._mode(self.model, is_test=False) | |||||
for key in default_args: | |||||
if key in kwargs: | |||||
if isinstance(kwargs[key], type(default_args[key])): | |||||
default_args[key] = kwargs[key] | |||||
else: | |||||
msg = "Argument %s type mismatch: expected %s while get %s" % ( | |||||
key, type(default_args[key]), type(kwargs[key])) | |||||
logger.error(msg) | |||||
raise ValueError(msg) | |||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S')) | |||||
print("training epochs started " + self.start_time, flush=True) | |||||
if self.save_path is None: | |||||
class pseudoSW:
def __getattr__(self, item): | |||||
def pass_func(*args, **kwargs): | |||||
pass | |||||
return pass_func | |||||
self._summary_writer = pseudoSW()
else: | else: | ||||
# Trainer doesn't care about extra arguments | |||||
pass | |||||
print(default_args) | |||||
self.n_epochs = default_args["epochs"] | |||||
self.batch_size = default_args["batch_size"] | |||||
self.pickle_path = default_args["pickle_path"] | |||||
self.validate = default_args["validate"] | |||||
self.save_best_dev = default_args["save_best_dev"] | |||||
self.use_cuda = default_args["use_cuda"] | |||||
self.model_name = default_args["model_name"] | |||||
self.print_every_step = default_args["print_every_step"] | |||||
self._model = None | |||||
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None | |||||
self._optimizer = None | |||||
self._optimizer_proto = default_args["optimizer"] | |||||
self._evaluator = default_args["evaluator"] | |||||
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') | |||||
self._graph_summaried = False | |||||
self._best_accuracy = 0.0 | |||||
def train(self, network, train_data, dev_data=None): | |||||
"""General Training Procedure | |||||
:param network: a model | |||||
:param train_data: a DataSet instance, the training data | |||||
:param dev_data: a DataSet instance, the validation data (optional) | |||||
""" | |||||
# transfer model to gpu if available | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
self._model = network.cuda() | |||||
# self._model is used to access model-specific loss | |||||
else: | |||||
self._model = network | |||||
# define Tester over dev data | |||||
if self.validate: | |||||
default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path, | |||||
"use_cuda": self.use_cuda, "evaluator": self._evaluator} | |||||
validator = self._create_validator(default_valid_args) | |||||
logger.info("validator defined as {}".format(str(validator))) | |||||
# optimizer and loss | |||||
self.define_optimizer() | |||||
logger.info("optimizer defined as {}".format(str(self._optimizer))) | |||||
self.define_loss() | |||||
logger.info("loss function defined as {}".format(str(self._loss_func))) | |||||
# main training procedure | |||||
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) | |||||
self._summary_writer = SummaryWriter(path) | |||||
if self.use_tqdm: | |||||
self._tqdm_train() | |||||
else: | |||||
self._print_train() | |||||
finally: | |||||
self._summary_writer.close() | |||||
del self._summary_writer | |||||
def _tqdm_train(self): | |||||
self.step = 0 | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, | |||||
as_numpy=False) | |||||
total_steps = data_iterator.num_batches * self.n_epochs
epoch = 1 | |||||
with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||||
avg_loss = 0
for epoch in range(1, self.n_epochs+1): | |||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||||
for batch_x, batch_y in data_iterator: | |||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||||
prediction = self._data_forward(self.model, batch_x) | |||||
loss = self._compute_loss(prediction, batch_y) | |||||
avg_loss += loss.item()
self._grad_backward(loss) | |||||
self._update() | |||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | |||||
for name, param in self.model.named_parameters(): | |||||
if param.requires_grad: | |||||
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) | |||||
if (self.step+1) % self.print_every == 0: | |||||
pbar.set_postfix_str("loss:{0:<6.5f}".format(avg_loss / self.print_every))
avg_loss = 0
pbar.update(1) | |||||
self.step += 1 | |||||
if self.validate_every > 0 and self.step % self.validate_every == 0 \ | |||||
and self.dev_data is not None: | |||||
eval_res = self._do_validation() | |||||
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ | |||||
self.tester._format_eval_results(eval_res) | |||||
pbar.write(eval_str) | |||||
if self.validate_every < 0 and self.dev_data: | |||||
eval_res = self._do_validation() | |||||
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ | |||||
self.tester._format_eval_results(eval_res) | |||||
pbar.write(eval_str) | |||||
if epoch != self.n_epochs:
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, | |||||
as_numpy=False) | |||||
pbar.close() | |||||
def _print_train(self): | |||||
epoch = 1 | |||||
start = time.time() | start = time.time() | ||||
logger.info("training epochs started") | |||||
for epoch in range(1, self.n_epochs + 1): | |||||
logger.info("training epoch {}".format(epoch)) | |||||
# turn on network training mode | |||||
self.mode(network, is_test=False) | |||||
# prepare mini-batch iterator | |||||
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(), | |||||
use_cuda=self.use_cuda) | |||||
logger.info("prepared data iterator") | |||||
# one forward and backward pass | |||||
self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch) | |||||
# validation | |||||
if self.validate: | |||||
if dev_data is None: | |||||
raise RuntimeError( | |||||
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.") | |||||
logger.info("validation started") | |||||
validator.test(network, dev_data) | |||||
def _train_step(self, data_iterator, network, **kwargs): | |||||
"""Training process in one epoch. | |||||
kwargs should contain: | |||||
- n_print: int, print training information every n steps. | |||||
- start: time.time(), the starting time of this step. | |||||
- epoch: int, | |||||
""" | |||||
step = 0 | |||||
for batch_x, batch_y in data_iterator: | |||||
prediction = self.data_forward(network, batch_x) | |||||
loss = self.get_loss(prediction, batch_y) | |||||
self.grad_backward(loss) | |||||
self.update() | |||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=step) | |||||
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: | |||||
end = time.time() | |||||
diff = timedelta(seconds=round(end - kwargs["start"])) | |||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||||
kwargs["epoch"], step, loss.data, diff) | |||||
print(print_output) | |||||
logger.info(print_output) | |||||
step += 1 | |||||
def mode(self, model, is_test=False): | |||||
while epoch <= self.n_epochs: | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, | |||||
as_numpy=False) | |||||
for batch_x, batch_y in data_iterator: | |||||
# TODO: this may cause problems if the user changes the device of the prediction inside the model
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||||
prediction = self._data_forward(self.model, batch_x) | |||||
loss = self._compute_loss(prediction, batch_y) | |||||
self._grad_backward(loss) | |||||
self._update() | |||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | |||||
for name, param in self.model.named_parameters(): | |||||
if param.requires_grad: | |||||
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) | |||||
if self.print_every > 0 and self.step % self.print_every == 0: | |||||
end = time.time() | |||||
diff = timedelta(seconds=round(end - start)) | |||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||||
epoch, self.step, loss.data, diff) | |||||
print(print_output) | |||||
if (self.validate_every > 0 and self.step % self.validate_every == 0 and | |||||
self.dev_data is not None): | |||||
self._do_validation() | |||||
self.step += 1 | |||||
# a positive validate_every overrides validation at the end of each epoch
if self.dev_data and self.validate_every <= 0: | |||||
self._do_validation() | |||||
epoch += 1 | |||||
def _do_validation(self): | |||||
res = self.tester.test() | |||||
for name, metric in res.items(): | |||||
for metric_key, metric_val in metric.items(): | |||||
self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val, | |||||
global_step=self.step) | |||||
if self.save_path is not None and self._better_eval_result(res): | |||||
metric_key = self.metric_key if self.metric_key is not None else "" | |||||
self._save_model(self.model, | |||||
"best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time])) | |||||
return res | |||||
def _mode(self, model, is_test=False): | |||||
"""Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
:param model: a PyTorch model | :param model: a PyTorch model | ||||
:param is_test: bool, whether in test mode or not. | |||||
:param bool is_test: whether in test mode or not. | |||||
""" | """ | ||||
if is_test: | if is_test: | ||||
@@ -180,127 +277,153 @@ class Trainer(object): | |||||
else: | else: | ||||
model.train() | model.train() | ||||
def define_optimizer(self): | |||||
"""Define framework-specific optimizer specified by the models. | |||||
""" | |||||
self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters()) | |||||
def update(self): | |||||
def _update(self): | |||||
"""Perform weight update on a model. | """Perform weight update on a model. | ||||
For PyTorch, just call optimizer to update. | |||||
""" | """ | ||||
self._optimizer.step() | |||||
self.optimizer.step() | |||||
def data_forward(self, network, x): | |||||
def _data_forward(self, network, x): | |||||
x = _build_args(network.forward, **x) | |||||
y = network(**x) | y = network(**x) | ||||
if not self._graph_summaried: | |||||
# self._summary_writer.add_graph(network, x, verbose=False) | |||||
self._graph_summaried = True | |||||
if not isinstance(y, dict): | |||||
raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.") | |||||
return y | return y | ||||
def grad_backward(self, loss): | |||||
def _grad_backward(self, loss): | |||||
"""Compute gradient with link rules. | """Compute gradient with link rules. | ||||
:param loss: a scalar where back-prop starts | :param loss: a scalar where back-prop starts | ||||
For PyTorch, just do "loss.backward()" | For PyTorch, just do "loss.backward()" | ||||
""" | """ | ||||
self._model.zero_grad() | |||||
self.model.zero_grad() | |||||
loss.backward() | loss.backward() | ||||
def get_loss(self, predict, truth): | |||||
def _compute_loss(self, predict, truth): | |||||
"""Compute loss given prediction and ground truth. | """Compute loss given prediction and ground truth. | ||||
:param predict: prediction label vector | |||||
:param truth: ground truth label vector | |||||
:param predict: prediction dict, produced by model.forward | |||||
:param truth: ground truth dict, produced by batch_y | |||||
:return: a scalar | :return: a scalar | ||||
""" | """ | ||||
if len(truth) > 1: | |||||
raise NotImplementedError("Not ready to handle multi-labels.") | |||||
truth = list(truth.values())[0] if len(truth) > 0 else None | |||||
return self._loss_func(predict, truth) | |||||
def define_loss(self): | |||||
"""Define a loss for the trainer. | |||||
If the model defines a loss, use model's loss. | |||||
Otherwise, Trainer must have a loss argument and will use it as the loss.
These two losses cannot be defined at the same time. | |||||
Trainer does not handle loss definition or choose default losses. | |||||
""" | |||||
# if hasattr(self._model, "loss") and self._loss_func is not None: | |||||
# raise ValueError("Both the model and Trainer define loss. Please take out your loss.") | |||||
return self.losser(predict, truth) | |||||
if hasattr(self._model, "loss"): | |||||
self._loss_func = self._model.loss | |||||
logger.info("The model has a loss function, use it.") | |||||
else: | |||||
if self._loss_func is None: | |||||
raise ValueError("Please specify a loss function.") | |||||
logger.info("The model didn't define loss, use Trainer's loss.") | |||||
def _save_model(self, model, model_name, only_param=False): | |||||
if self.save_path is not None: | |||||
model_name = os.path.join(self.save_path, model_name) | |||||
if only_param: | |||||
torch.save(model.state_dict(), model_name) | |||||
else: | |||||
torch.save(model, model_name) | |||||
def best_eval_result(self, validator): | |||||
def _better_eval_result(self, metrics): | |||||
"""Check if the current epoch yields better validation results. | """Check if the current epoch yields better validation results. | ||||
:param validator: a Tester instance | |||||
:return: bool, True means current results on dev set is the best. | |||||
:return bool value: True means the current results on the dev set are the best so far.
""" | """ | ||||
loss, accuracy = validator.metrics | |||||
if accuracy > self._best_accuracy: | |||||
self._best_accuracy = accuracy | |||||
return True | |||||
indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics) | |||||
is_better = True | |||||
if self.best_metric_indicator is None: | |||||
# first-time validation | |||||
self.best_metric_indicator = indicator_val | |||||
else: | else: | ||||
return False | |||||
def save_model(self, network, model_name): | |||||
"""Save this model with such a name. | |||||
This method may be called multiple times by Trainer to overwrite the saved file when a better model is found.
:param network: the PyTorch model | |||||
:param model_name: str | |||||
""" | |||||
if model_name[-4:] != ".pkl": | |||||
model_name += ".pkl" | |||||
ModelSaver(os.path.join(self.pickle_path, model_name)).save_pytorch(network) | |||||
def _create_validator(self, valid_args): | |||||
raise NotImplementedError | |||||
class SeqLabelTrainer(Trainer): | |||||
"""Trainer for Sequence Labeling | |||||
""" | |||||
def __init__(self, **kwargs): | |||||
print( | |||||
"[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer directly.") | |||||
super(SeqLabelTrainer, self).__init__(**kwargs) | |||||
def _create_validator(self, valid_args): | |||||
return SeqLabelTester(**valid_args) | |||||
class ClassificationTrainer(Trainer): | |||||
"""Trainer for text classification.""" | |||||
def __init__(self, **train_args): | |||||
print( | |||||
"[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer directly.") | |||||
super(ClassificationTrainer, self).__init__(**train_args) | |||||
def _create_validator(self, valid_args): | |||||
return ClassificationTester(**valid_args) | |||||
class SNLITrainer(Trainer): | |||||
"""Trainer for text SNLI.""" | |||||
def __init__(self, **train_args): | |||||
print( | |||||
"[FastNLP Warning] SNLITrainer will be deprecated. Please use Trainer directly.") | |||||
super(SNLITrainer, self).__init__(**train_args) | |||||
def _create_validator(self, valid_args): | |||||
return SNLITester(**valid_args) | |||||
if self.increase_better is True: | |||||
if indicator_val > self.best_metric_indicator: | |||||
self.best_metric_indicator = indicator_val | |||||
else: | |||||
is_better = False | |||||
else: | |||||
if indicator_val < self.best_metric_indicator: | |||||
self.best_metric_indicator = indicator_val | |||||
else: | |||||
is_better = False | |||||
return is_better | |||||
DEFAULT_CHECK_BATCH_SIZE = 2 | |||||
DEFAULT_CHECK_NUM_BATCH = 2 | |||||
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, | |||||
dev_data=None, metric_key=None, | |||||
check_level=0): | |||||
# check the get_loss method
model_device = model.parameters().__next__().device
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch):
_move_dict_value_to_device(batch_x, batch_y, device=model_device)
# forward check | |||||
if batch_count==0: | |||||
_check_forward_error(forward_func=model.forward, dataset=dataset, | |||||
batch_x=batch_x, check_level=check_level) | |||||
refined_batch_x = _build_args(model.forward, **batch_x) | |||||
pred_dict = model(**refined_batch_x) | |||||
func_signature = get_func_signature(model.forward) | |||||
if not isinstance(pred_dict, dict): | |||||
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") | |||||
# loss check | |||||
try: | |||||
loss = losser(pred_dict, batch_y) | |||||
# check loss output | |||||
if batch_count == 0: | |||||
if not isinstance(loss, torch.Tensor): | |||||
raise TypeError( | |||||
f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, " | |||||
f"but got `{type(loss)}`.") | |||||
if len(loss.size()) != 0: | |||||
raise ValueError( | |||||
f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, " | |||||
f"should be torch.size([])") | |||||
loss.backward() | |||||
except CheckError as e: | |||||
# TODO: another error raised if CheckError caught | |||||
pre_func_signature = get_func_signature(model.forward) | |||||
_check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature, | |||||
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, | |||||
dataset=dataset, check_level=check_level) | |||||
model.zero_grad() | |||||
if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH: | |||||
break | |||||
if dev_data is not None: | |||||
tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, | |||||
batch_size=batch_size, verbose=-1) | |||||
evaluate_results = tester.test() | |||||
_check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics) | |||||
def _check_eval_results(metrics, metric_key, metric_list): | |||||
# metrics: the evaluation results returned by the tester
# metric_key: a single indicator used for model selection, set at Trainer initialization
# metric_list: the metrics used for evaluation, set at Trainer initialization
if isinstance(metrics, tuple): | |||||
loss, metrics = metrics | |||||
if isinstance(metrics, dict): | |||||
if len(metrics) == 1: | |||||
# only single metric, just use it | |||||
metric_dict = list(metrics.values())[0] | |||||
metrics_name = list(metrics.keys())[0] | |||||
else: | |||||
metrics_name = metric_list[0].__class__.__name__ | |||||
if metrics_name not in metrics: | |||||
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}") | |||||
metric_dict = metrics[metrics_name] | |||||
if len(metric_dict) == 1: | |||||
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] | |||||
elif len(metric_dict) > 1 and metric_key is None: | |||||
raise RuntimeError( | |||||
f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?") | |||||
else: | |||||
# metric_key is set | |||||
if metric_key not in metric_dict: | |||||
raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}") | |||||
indicator_val = metric_dict[metric_key] | |||||
else: | |||||
raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics))) | |||||
return indicator_val |
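Putting the pieces together, a hedged sketch of the new Trainer entry point; the loss/metric class names and import paths are assumptions based on the imports above, not guaranteed by this diff:

from fastNLP.core.metrics import AccuracyMetric  # assumed metric class
from fastNLP.core.trainer import Trainer         # assumed import path

trainer = Trainer(train_data=train_data, model=model,
                  loss=loss,                     # a LossBase object, prepared internally via _prepare_losser
                  metrics=AccuracyMetric(),      # evaluated on dev_data
                  dev_data=dev_data,
                  n_epochs=3, batch_size=32,
                  metric_key="acc",              # indicator that decides the "best" model (assumed key name)
                  save_path="./save",
                  use_tqdm=True)
trainer.train()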
@@ -0,0 +1,432 @@ | |||||
import _pickle | |||||
import inspect | |||||
import os | |||||
import warnings | |||||
from collections import Counter | |||||
from collections import namedtuple | |||||
import numpy as np | |||||
import torch | |||||
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | |||||
'varargs'], verbose=False) | |||||
def save_pickle(obj, pickle_path, file_name): | |||||
"""Save an object into a pickle file. | |||||
:param obj: an object | |||||
:param pickle_path: str, the directory where the pickle file is to be saved | |||||
:param file_name: str, the name of the pickle file. In general, it should be ended by "pkl". | |||||
""" | |||||
if not os.path.exists(pickle_path): | |||||
os.mkdir(pickle_path) | |||||
print("make dir {} before saving pickle file".format(pickle_path)) | |||||
with open(os.path.join(pickle_path, file_name), "wb") as f: | |||||
_pickle.dump(obj, f) | |||||
print("{} saved in {}".format(file_name, pickle_path)) | |||||
def load_pickle(pickle_path, file_name): | |||||
"""Load an object from a given pickle file. | |||||
:param pickle_path: str, the directory where the pickle file is. | |||||
:param file_name: str, the name of the pickle file. | |||||
:return obj: an object stored in the pickle | |||||
""" | |||||
with open(os.path.join(pickle_path, file_name), "rb") as f: | |||||
obj = _pickle.load(f) | |||||
print("{} loaded from {}".format(file_name, pickle_path)) | |||||
return obj | |||||
def pickle_exist(pickle_path, pickle_name): | |||||
"""Check if a given pickle file exists in the directory. | |||||
:param pickle_path: the directory of target pickle file | |||||
:param pickle_name: the filename of target pickle file | |||||
:return: True if file exists else False | |||||
""" | |||||
if not os.path.exists(pickle_path): | |||||
os.makedirs(pickle_path) | |||||
file_name = os.path.join(pickle_path, pickle_name) | |||||
if os.path.exists(file_name): | |||||
return True | |||||
else: | |||||
return False | |||||
def _build_args(func, **kwargs): | |||||
spect = inspect.getfullargspec(func) | |||||
if spect.varkw is not None: | |||||
return kwargs | |||||
needed_args = set(spect.args) | |||||
defaults = [] | |||||
if spect.defaults is not None: | |||||
defaults = [arg for arg in spect.defaults] | |||||
start_idx = len(spect.args) - len(defaults) | |||||
output = {name: default for name, default in zip(spect.args[start_idx:], defaults)} | |||||
output.update({name: val for name, val in kwargs.items() if name in needed_args}) | |||||
return output | |||||
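# Illustrative sketch (not part of this PR): _build_args keeps only the kwargs
# a callable accepts, so a batch dict with extra fields can be passed straight
# into model.forward. The forward signature and field names are hypothetical.
def _demo_build_args():
    def forward(word_seq, seq_len, dropout=0.1):
        return word_seq, seq_len, dropout

    batch = {"word_seq": [[1, 2, 3]], "seq_len": [3], "label": [0]}
    refined = _build_args(forward, **batch)
    # 'label' is dropped because forward() does not accept it;
    # 'dropout' falls back to its default value
    assert refined == {"word_seq": [[1, 2, 3]], "seq_len": [3], "dropout": 0.1}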
def _map_args(maps: dict, **kwargs): | |||||
# maps: key=old name, value= new name | |||||
output = {} | |||||
for name, val in kwargs.items(): | |||||
if name in maps: | |||||
assert isinstance(maps[name], str) | |||||
output.update({maps[name]: val}) | |||||
else: | |||||
output.update({name: val}) | |||||
for keys in maps.keys(): | |||||
if keys not in output.keys(): | |||||
# TODO: add UNUSED warning. | |||||
pass | |||||
return output | |||||
def _get_arg_list(func): | |||||
assert callable(func) | |||||
spect = inspect.getfullargspec(func) | |||||
if spect.defaults is not None: | |||||
args = spect.args[: -len(spect.defaults)] | |||||
defaults = spect.args[-len(spect.defaults):] | |||||
defaults_val = spect.defaults | |||||
else: | |||||
args = spect.args | |||||
defaults = None | |||||
defaults_val = None | |||||
varargs = spect.varargs | |||||
kwargs = spect.varkw | |||||
return args, defaults, defaults_val, varargs, kwargs | |||||
# check args | |||||
def _check_arg_dict_list(func, args): | |||||
if isinstance(args, dict): | |||||
arg_dict_list = [args] | |||||
else: | |||||
arg_dict_list = args | |||||
assert callable(func) and isinstance(arg_dict_list, (list, tuple)) | |||||
assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict) | |||||
spect = inspect.getfullargspec(func) | |||||
all_args = set([arg for arg in spect.args if arg != 'self']) | |||||
defaults = [] | |||||
if spect.defaults is not None: | |||||
defaults = [arg for arg in spect.defaults] | |||||
start_idx = len(spect.args) - len(defaults) | |||||
default_args = set(spect.args[start_idx:]) | |||||
require_args = all_args - default_args | |||||
input_arg_count = Counter() | |||||
for arg_dict in arg_dict_list: | |||||
input_arg_count.update(arg_dict.keys()) | |||||
duplicated = [name for name, val in input_arg_count.items() if val > 1] | |||||
input_args = set(input_arg_count.keys()) | |||||
missing = list(require_args - input_args) | |||||
unused = list(input_args - all_args) | |||||
varargs = [] if not spect.varargs else [arg for arg in spect.varargs] | |||||
return CheckRes(missing=missing, | |||||
unused=unused, | |||||
duplicated=duplicated, | |||||
required=list(require_args), | |||||
all_needed=list(all_args), | |||||
varargs=varargs) | |||||
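# Illustrative sketch (not part of this PR): what CheckRes reports when a
# batch dict is checked against a forward signature. Names are hypothetical.
def _demo_check_arg_dict_list():
    def forward(self, word_seq, seq_len):
        pass

    res = _check_arg_dict_list(forward, {"word_seq": None, "extra_field": None})
    # 'seq_len' is required by forward but missing from the dict;
    # 'extra_field' is provided but never used
    assert set(res.missing) == {"seq_len"}
    assert set(res.unused) == {"extra_field"}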
def get_func_signature(func): | |||||
""" | |||||
Given a function or method, return its signature. | |||||
For example: | |||||
(1) function | |||||
def func(a, b='a', *args): | |||||
xxxx | |||||
get_func_signature(func) # 'func(a, b='a', *args)' | |||||
(2) method | |||||
class Demo: | |||||
def __init__(self): | |||||
xxx | |||||
def forward(self, a, b='a', **args) | |||||
demo = Demo() | |||||
get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)' | |||||
:param func: a function or a method | |||||
:return: str or None | |||||
""" | |||||
if inspect.ismethod(func): | |||||
class_name = func.__self__.__class__.__name__ | |||||
signature = inspect.signature(func) | |||||
signature_str = str(signature) | |||||
if len(signature_str) > 2: | |||||
_self = '(self, ' | |||||
else: | |||||
_self = '(self' | |||||
signature_str = class_name + '.' + func.__name__ + _self + signature_str[1:] | |||||
return signature_str | |||||
elif inspect.isfunction(func): | |||||
signature = inspect.signature(func) | |||||
signature_str = str(signature) | |||||
signature_str = func.__name__ + signature_str | |||||
return signature_str | |||||
def _is_function_or_method(func): | |||||
""" | |||||
:param func: | |||||
:return: | |||||
""" | |||||
if not inspect.ismethod(func) and not inspect.isfunction(func): | |||||
return False | |||||
return True | |||||
def _check_function_or_method(func): | |||||
if not _is_function_or_method(func): | |||||
raise TypeError(f"{type(func)} is not a method or function.") | |||||
def _move_dict_value_to_device(*args, device: torch.device): | |||||
""" | |||||
Move data to the model's device. Each element in *args should be a dict. This is an in-place change.
:param device: torch.device | |||||
:param args: | |||||
:return: | |||||
""" | |||||
if not isinstance(device, torch.device): | |||||
raise TypeError(f"device must be `torch.device`, got `{type(device)}`") | |||||
for arg in args: | |||||
if isinstance(arg, dict): | |||||
for key, value in arg.items(): | |||||
if isinstance(value, torch.Tensor): | |||||
arg[key] = value.to(device) | |||||
else: | |||||
raise TypeError("Only support `dict` type right now.") | |||||
class CheckError(Exception): | |||||
""" | |||||
CheckError. Used in losses.LossBase, metrics.MetricBase. | |||||
""" | |||||
def __init__(self, check_res: CheckRes, func_signature: str): | |||||
errs = [f'Problems occurred when calling `{func_signature}`'] | |||||
if check_res.varargs: | |||||
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") | |||||
if check_res.missing: | |||||
errs.append(f"\tmissing param: {check_res.missing}") | |||||
if check_res.duplicated: | |||||
errs.append(f"\tduplicated param: {check_res.duplicated}") | |||||
if check_res.unused: | |||||
errs.append(f"\tunused param: {check_res.unused}") | |||||
Exception.__init__(self, '\n'.join(errs)) | |||||
self.check_res = check_res | |||||
self.func_signature = func_signature | |||||
IGNORE_CHECK_LEVEL = 0 | |||||
WARNING_CHECK_LEVEL = 1 | |||||
STRICT_CHECK_LEVEL = 2 | |||||
def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes, | |||||
pred_dict: dict, target_dict: dict, dataset, check_level=0): | |||||
errs = [] | |||||
unuseds = [] | |||||
_unused_field = [] | |||||
_unused_param = [] | |||||
suggestions = [] | |||||
if check_res.varargs: | |||||
errs.append(f"\tvarargs: *{check_res.varargs}") | |||||
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||||
if check_res.unused: | |||||
for _unused in check_res.unused: | |||||
if _unused in target_dict: | |||||
_unused_field.append(_unused) | |||||
else: | |||||
_unused_param.append(_unused) | |||||
if _unused_field: | |||||
unuseds.append(f"\tunused field: {_unused_field}") | |||||
if _unused_param: | |||||
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward | |||||
module_name = func_signature.split('.')[0] | |||||
if check_res.missing: | |||||
errs.append(f"\tmissing param: {check_res.missing}") | |||||
import re | |||||
mapped_missing = [] | |||||
unmapped_missing = [] | |||||
input_func_map = {} | |||||
for _miss in check_res.missing: | |||||
if '(' in _miss: | |||||
# if they are like 'SomeParam(assign to xxx)' | |||||
_miss = _miss.split('(')[0] | |||||
matches = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss) | |||||
if len(matches) == 2: | |||||
fun_arg, module_name = matches | |||||
input_func_map[_miss] = fun_arg | |||||
if fun_arg == _miss: | |||||
unmapped_missing.append(_miss) | |||||
else: | |||||
mapped_missing.append(_miss) | |||||
else: | |||||
unmapped_missing.append(_miss) | |||||
for _miss in mapped_missing: | |||||
if _miss in dataset: | |||||
suggestions.append(f"Set {_miss} as target.") | |||||
else: | |||||
_tmp = '' | |||||
if check_res.unused: | |||||
_tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." | |||||
if _tmp: | |||||
_tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.' | |||||
else: | |||||
_tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.' | |||||
suggestions.append(_tmp) | |||||
for _miss in unmapped_missing: | |||||
if _miss in dataset: | |||||
suggestions.append(f"Set {_miss} as target.") | |||||
else: | |||||
_tmp = '' | |||||
if check_res.unused: | |||||
_tmp = f"Specify your assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." | |||||
if _tmp: | |||||
_tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.' | |||||
else: | |||||
_tmp = f'Provide {_miss} in output of {prev_func_signature} or DataSet.' | |||||
suggestions.append(_tmp) | |||||
if check_res.duplicated: | |||||
errs.append(f"\tduplicated param: {check_res.duplicated}.") | |||||
suggestions.append(f"Delete {check_res.duplicated} in the output of " | |||||
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") | |||||
if len(errs)>0: | |||||
errs.extend(unuseds) | |||||
elif check_level == STRICT_CHECK_LEVEL: | |||||
errs.extend(unuseds) | |||||
if len(errs) > 0: | |||||
errs.insert(0, f'Problems occurred when calling {func_signature}') | |||||
sugg_str = "" | |||||
if len(suggestions) > 1: | |||||
for idx, sugg in enumerate(suggestions): | |||||
if idx>0: | |||||
sugg_str += '\t\t\t' | |||||
sugg_str += f'({idx+1}). {sugg}\n' | |||||
sugg_str = sugg_str[:-1] | |||||
else: | |||||
sugg_str += suggestions[0] | |||||
errs.append(f'\ttarget field: {list(target_dict.keys())}') | |||||
errs.append(f'\tparam from {prev_func_signature}: {list(pred_dict.keys())}') | |||||
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str | |||||
raise NameError(err_str) | |||||
if check_res.unused: | |||||
if check_level == WARNING_CHECK_LEVEL: | |||||
if not module_name: | |||||
module_name = func_signature.split('.')[0] | |||||
_unused_warn = f'{check_res.unused} is not used by {module_name}.' | |||||
warnings.warn(message=_unused_warn) | |||||
def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||||
check_res = _check_arg_dict_list(forward_func, batch_x) | |||||
func_signature = get_func_signature(forward_func) | |||||
errs = [] | |||||
suggestions = [] | |||||
_unused = [] | |||||
if check_res.varargs: | |||||
errs.append(f"\tvarargs: {check_res.varargs}") | |||||
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||||
if check_res.missing: | |||||
errs.append(f"\tmissing param: {check_res.missing}") | |||||
_miss_in_dataset = [] | |||||
_miss_out_dataset = [] | |||||
for _miss in check_res.missing: | |||||
if _miss in dataset: | |||||
_miss_in_dataset.append(_miss) | |||||
else: | |||||
_miss_out_dataset.append(_miss) | |||||
if _miss_in_dataset: | |||||
suggestions.append(f"You might need to set {_miss_in_dataset} as input. ") | |||||
if _miss_out_dataset: | |||||
_tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. " | |||||
# if check_res.unused: | |||||
# _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ | |||||
# f"rename the field in `unused field:`." | |||||
suggestions.append(_tmp) | |||||
if check_res.unused: | |||||
_unused = [f"\tunused field: {check_res.unused}"] | |||||
if len(errs)>0: | |||||
errs.extend(_unused) | |||||
elif check_level == STRICT_CHECK_LEVEL: | |||||
errs.extend(_unused) | |||||
if len(errs) > 0: | |||||
errs.insert(0, f'Problems occurred when calling {func_signature}') | |||||
sugg_str = "" | |||||
if len(suggestions) > 1: | |||||
for idx, sugg in enumerate(suggestions): | |||||
sugg_str += f'({idx+1}). {sugg}' | |||||
else: | |||||
sugg_str += suggestions[0] | |||||
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str | |||||
raise NameError(err_str) | |||||
if _unused: | |||||
if check_level == WARNING_CHECK_LEVEL: | |||||
_unused_warn = _unused[0] + f' in {func_signature}.' | |||||
warnings.warn(message=_unused_warn) | |||||
def seq_lens_to_masks(seq_lens, float=False): | |||||
""" | |||||
Convert seq_lens to masks. | |||||
:param seq_lens: list, np.ndarray, or torch.LongTensor, shape should all be (B,) | |||||
:param float: if True, the returned masks are of float type, otherwise byte.
:return: list, np.ndarray or torch.Tensor, shape will be (B, max_length) | |||||
""" | |||||
if isinstance(seq_lens, np.ndarray): | |||||
assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}." | |||||
assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}." | |||||
raise NotImplementedError
elif isinstance(seq_lens, torch.LongTensor): | |||||
assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())}."
batch_size = seq_lens.size(0) | |||||
max_len = seq_lens.max() | |||||
indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device) | |||||
masks = indexes.lt(seq_lens.unsqueeze(1)) | |||||
if float: | |||||
masks = masks.float() | |||||
return masks | |||||
elif isinstance(seq_lens, list): | |||||
raise NotImplementedError
else: | |||||
raise NotImplementedError
def seq_mask(seq_len, max_len): | |||||
"""Create sequence mask. | |||||
:param seq_len: list or torch.Tensor, the lengths of sequences in a batch. | |||||
:param max_len: int, the maximum sequence length in a batch. | |||||
:return mask: torch.LongTensor, [batch_size, max_len] | |||||
""" | |||||
if not isinstance(seq_len, torch.Tensor): | |||||
seq_len = torch.LongTensor(seq_len) | |||||
seq_len = seq_len.view(-1, 1).long() # [batch_size, 1] | |||||
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len] | |||||
return torch.gt(seq_len, seq_range) # [batch_size, max_len] |
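# Illustrative sketch (not part of this PR): the mask seq_mask produces for a
# tiny batch with lengths [3, 1] and max_len 4.
def _demo_seq_mask():
    mask = seq_mask([3, 1], max_len=4)  # shape (2, 4)
    # positions before each length are kept, the rest are padded out
    assert mask.tolist() == [[1, 1, 1, 0], [1, 0, 0, 0]]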
@@ -1,32 +1,33 @@ | |||||
from copy import deepcopy | |||||
from collections import Counter | |||||
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||||
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||||
DEFAULT_RESERVED_LABEL = ['<reserved-2>', | |||||
'<reserved-3>', | |||||
'<reserved-4>'] # dict index = 2~4 | |||||
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, | |||||
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, | |||||
DEFAULT_RESERVED_LABEL[2]: 4} | |||||
def check_build_vocab(func): | |||||
"""A decorator to make sure the indexing is built before used. | |||||
""" | |||||
def _wrapper(self, *args, **kwargs): | |||||
if self.word2idx is None or self.rebuild is True: | |||||
self.build_vocab() | |||||
return func(self, *args, **kwargs) | |||||
return _wrapper | |||||
def isiterable(p_object): | |||||
try: | |||||
it = iter(p_object) | |||||
except TypeError: | |||||
return False | |||||
return True | |||||
def check_build_status(func): | |||||
"""A decorator to check whether the vocabulary updates after the last build. | |||||
""" | |||||
def check_build_vocab(func): | |||||
def _wrapper(self, *args, **kwargs): | def _wrapper(self, *args, **kwargs): | ||||
if self.word2idx is None: | |||||
self.build_vocab() | |||||
self.build_reverse_vocab() | |||||
elif self.idx2word is None: | |||||
self.build_reverse_vocab() | |||||
if self.rebuild is False: | |||||
self.rebuild = True | |||||
if self.max_size is not None and len(self.word_count) >= self.max_size: | |||||
print("[Warning] Vocabulary has reached the max size {} when calling {} method. " | |||||
"Adding more words may cause unexpected behaviour of Vocabulary. ".format( | |||||
self.max_size, func.__name__)) | |||||
return func(self, *args, **kwargs) | return func(self, *args, **kwargs) | ||||
return _wrapper | return _wrapper | ||||
@@ -41,69 +42,95 @@ class Vocabulary(object): | |||||
vocab["word"] | vocab["word"] | ||||
vocab.to_word(5) | vocab.to_word(5) | ||||
""" | """ | ||||
def __init__(self, need_default=True, max_size=None, min_freq=None): | |||||
def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'): | |||||
""" | """ | ||||
:param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | |||||
:param int max_size: set the max number of words in Vocabulary. Default: None | :param int max_size: set the max number of words in Vocabulary. Default: None | ||||
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None | :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None | ||||
""" | """ | ||||
self.max_size = max_size | self.max_size = max_size | ||||
self.min_freq = min_freq | self.min_freq = min_freq | ||||
self.word_count = {} | |||||
self.has_default = need_default | |||||
self.word_count = Counter() | |||||
self.unknown = unknown | |||||
self.padding = padding | |||||
self.word2idx = None | self.word2idx = None | ||||
self.idx2word = None | self.idx2word = None | ||||
self.rebuild = True | |||||
def update(self, word): | |||||
"""add word or list of words into Vocabulary | |||||
@check_build_status | |||||
def update(self, word_lst): | |||||
"""Add a list of words into the vocabulary. | |||||
:param word: a list of string or a single string | |||||
:param list word_lst: a list of strings | |||||
""" | """ | ||||
if not isinstance(word, str) and isiterable(word): | |||||
# it's a nested list | |||||
for w in word: | |||||
self.update(w) | |||||
else: | |||||
# it's a word to be added | |||||
if word not in self.word_count: | |||||
self.word_count[word] = 1 | |||||
else: | |||||
self.word_count[word] += 1 | |||||
self.word2idx = None | |||||
return self | |||||
self.word_count.update(word_lst) | |||||
def build_vocab(self): | |||||
"""build 'word to index' dict, and filter the word using `max_size` and `min_freq` | |||||
@check_build_status | |||||
def add(self, word): | |||||
"""Add a single word into the vocabulary. | |||||
:param str word: a word or token. | |||||
""" | """ | ||||
if self.has_default: | |||||
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | |||||
self.padding_label = DEFAULT_PADDING_LABEL | |||||
self.unknown_label = DEFAULT_UNKNOWN_LABEL | |||||
else: | |||||
self.word2idx = {} | |||||
self.padding_label = None | |||||
self.unknown_label = None | |||||
self.word_count[word] += 1 | |||||
@check_build_status | |||||
def add_word(self, word): | |||||
"""Add a single word into the vocabulary. | |||||
:param str word: a word or token. | |||||
""" | |||||
self.add(word) | |||||
@check_build_status | |||||
def add_word_lst(self, word_lst): | |||||
"""Add a list of words into the vocabulary. | |||||
:param list word_lst: a list of strings | |||||
""" | |||||
self.update(word_lst) | |||||
def build_vocab(self): | |||||
"""Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. | |||||
words = sorted(self.word_count.items(), key=lambda kv: kv[1], reverse=True) | |||||
""" | |||||
self.word2idx = {} | |||||
if self.padding is not None: | |||||
self.word2idx[self.padding] = 0 | |||||
if self.unknown is not None: | |||||
self.word2idx[self.unknown] = 1 | |||||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | |||||
words = self.word_count.most_common(max_size) | |||||
if self.min_freq is not None: | if self.min_freq is not None: | ||||
words = list(filter(lambda kv: kv[1] >= self.min_freq, words)) | |||||
if self.max_size is not None and len(words) > self.max_size: | |||||
words = words[:self.max_size] | |||||
for w, _ in words: | |||||
self.word2idx[w] = len(self.word2idx) | |||||
words = filter(lambda kv: kv[1] >= self.min_freq, words) | |||||
if self.word2idx is not None: | |||||
words = filter(lambda kv: kv[0] not in self.word2idx, words) | |||||
start_idx = len(self.word2idx) | |||||
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||||
self.build_reverse_vocab() | |||||
self.rebuild = False | |||||
def build_reverse_vocab(self): | def build_reverse_vocab(self): | ||||
"""build 'index to word' dict based on 'word to index' dict | |||||
"""Build 'index to word' dict based on 'word to index' dict. | |||||
""" | """ | ||||
self.idx2word = {self.word2idx[w] : w for w in self.word2idx} | |||||
self.idx2word = {i: w for w, i in self.word2idx.items()} | |||||
@check_build_vocab | @check_build_vocab | ||||
def __len__(self): | def __len__(self): | ||||
return len(self.word2idx) | return len(self.word2idx) | ||||
@check_build_vocab | @check_build_vocab | ||||
def __contains__(self, item): | |||||
"""Check if a word in vocabulary. | |||||
:param item: the word | |||||
:return: True or False | |||||
""" | |||||
return item in self.word2idx | |||||
def has_word(self, w): | def has_word(self, w): | ||||
return w in self.word2idx | |||||
return self.__contains__(w) | |||||
@check_build_vocab | @check_build_vocab | ||||
def __getitem__(self, w): | def __getitem__(self, w): | ||||
@@ -113,46 +140,45 @@ class Vocabulary(object): | |||||
""" | """ | ||||
if w in self.word2idx: | if w in self.word2idx: | ||||
return self.word2idx[w] | return self.word2idx[w] | ||||
elif self.has_default: | |||||
return self.word2idx[DEFAULT_UNKNOWN_LABEL] | |||||
if self.unknown is not None: | |||||
return self.word2idx[self.unknown] | |||||
else: | else: | ||||
raise ValueError("word {} not in vocabulary".format(w)) | raise ValueError("word {} not in vocabulary".format(w)) | ||||
@check_build_vocab | |||||
def to_index(self, w): | def to_index(self, w): | ||||
""" like to_index(w) function, turn a word to the index | |||||
if w is not in Vocabulary, return the unknown label | |||||
""" Turn a word to an index. | |||||
If w is not in Vocabulary, return the unknown label. | |||||
:param str w: | :param str w: | ||||
""" | """ | ||||
return self[w] | |||||
return self.__getitem__(w) | |||||
@property | @property | ||||
@check_build_vocab | @check_build_vocab | ||||
def unknown_idx(self): | def unknown_idx(self): | ||||
if self.unknown_label is None: | |||||
if self.unknown is None: | |||||
return None | return None | ||||
return self.word2idx[self.unknown_label] | |||||
return self.word2idx[self.unknown] | |||||
@property | @property | ||||
@check_build_vocab | @check_build_vocab | ||||
def padding_idx(self): | def padding_idx(self): | ||||
if self.padding_label is None: | |||||
if self.padding is None: | |||||
return None | return None | ||||
return self.word2idx[self.padding_label] | |||||
return self.word2idx[self.padding] | |||||
@check_build_vocab | @check_build_vocab | ||||
def to_word(self, idx): | def to_word(self, idx): | ||||
"""given a word's index, return the word itself | """given a word's index, return the word itself | ||||
:param int idx: | |||||
:param int idx: the index | |||||
:return str word: the indexed word | |||||
""" | """ | ||||
if self.idx2word is None: | |||||
self.build_reverse_vocab() | |||||
return self.idx2word[idx] | return self.idx2word[idx] | ||||
def __getstate__(self): | def __getstate__(self): | ||||
"""use to prepare data for pickle | |||||
"""Use to prepare data for pickle. | |||||
""" | """ | ||||
state = self.__dict__.copy() | state = self.__dict__.copy() | ||||
# no need to pickle idx2word as it can be constructed from word2idx | # no need to pickle idx2word as it can be constructed from word2idx | ||||
@@ -160,15 +186,8 @@ class Vocabulary(object): | |||||
return state | return state | ||||
def __setstate__(self, state): | def __setstate__(self, state): | ||||
"""use to restore state from pickle | |||||
""" | |||||
self.__dict__.update(state) | |||||
self.idx2word = None | |||||
"""Use to restore state from pickle. | |||||
def __contains__(self, item): | |||||
"""Check if a word in vocabulary. | |||||
:param item: the word | |||||
:return: True or False | |||||
""" | """ | ||||
return self.has_word(item) | |||||
self.__dict__.update(state) | |||||
self.build_reverse_vocab() |
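# Illustrative sketch (not part of this PR): the Vocabulary API after this
# change. The toy corpus is made up; <pad> maps to 0 and <unk> to 1 by default.
def _demo_vocabulary():
    from fastNLP.core.vocabulary import Vocabulary
    vocab = Vocabulary(max_size=None, min_freq=None)
    vocab.add_word_lst(["the", "cat", "sat", "on", "the", "mat"])
    vocab.add("dog")
    assert len(vocab) == 8                                # <pad>, <unk> plus 6 distinct words
    assert vocab.to_index("tiger") == vocab.unknown_idx   # unseen word falls back to <unk>
    assert vocab.padding_idx == 0 and vocab.unknown_idx == 1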
@@ -1,343 +0,0 @@ | |||||
import os | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.loader.dataset_loader import convert_seq_dataset | |||||
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer | |||||
from fastNLP.core.preprocess import load_pickle | |||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.loader.model_loader import ModelLoader | |||||
""" | |||||
mapping from model name to [URL, file_name.class_name, model_pickle_name] | |||||
Notice that the class of the model should be in "models" directory. | |||||
Example: | |||||
"seq_label_model": { | |||||
"url": "www.fudan.edu.cn", | |||||
"class": "sequence_modeling.SeqLabeling", # file_name.class_name in models/ | |||||
"pickle": "seq_label_model.pkl", | |||||
"type": "seq_label", | |||||
"config_file_name": "config", # the name of the config file which stores model initialization parameters | |||||
"config_section_name": "text_class_model" # the name of the section in the config file which stores model init params | |||||
}, | |||||
"text_class_model": { | |||||
"url": "www.fudan.edu.cn", | |||||
"class": "cnn_text_classification.CNNText", | |||||
"pickle": "text_class_model.pkl", | |||||
"type": "text_class" | |||||
} | |||||
""" | |||||
FastNLP_MODEL_COLLECTION = { | |||||
"cws_basic_model": { | |||||
"url": "", | |||||
"class": "sequence_modeling.AdvSeqLabel", | |||||
"pickle": "cws_basic_model_v_0.pkl", | |||||
"type": "seq_label", | |||||
"config_file_name": "cws.cfg", | |||||
"config_section_name": "text_class_model" | |||||
}, | |||||
"pos_tag_model": { | |||||
"url": "", | |||||
"class": "sequence_modeling.AdvSeqLabel", | |||||
"pickle": "pos_tag_model_v_0.pkl", | |||||
"type": "seq_label", | |||||
"config_file_name": "pos_tag.cfg", | |||||
"config_section_name": "pos_tag_model" | |||||
}, | |||||
"text_classify_model": { | |||||
"url": "", | |||||
"class": "cnn_text_classification.CNNText", | |||||
"pickle": "text_class_model_v0.pkl", | |||||
"type": "text_class", | |||||
"config_file_name": "text_classify.cfg", | |||||
"config_section_name": "model" | |||||
} | |||||
} | |||||
class FastNLP(object): | |||||
""" | |||||
High-level interface for direct model inference. | |||||
Example Usage | |||||
:: | |||||
fastnlp = FastNLP() | |||||
fastnlp.load("zh_pos_tag_model") | |||||
text = "这是最好的基于深度学习的中文分词系统。" | |||||
result = fastnlp.run(text) | |||||
print(result) # ["这", "是", "最好", "的", "基于", "深度学习", "的", "中文", "分词", "系统", "。"] | |||||
""" | |||||
def __init__(self, model_dir="./"): | |||||
""" | |||||
:param model_dir: this directory should contain the following files: | |||||
1. a trained model | |||||
2. a config file, which is a fastNLP's configuration. | |||||
3. two Vocab files, which are pickle objects of Vocab instances, representing feature and label vocabs. | |||||
""" | |||||
self.model_dir = model_dir | |||||
self.model = None | |||||
self.infer_type = None # "seq_label"/"text_class" | |||||
self.word_vocab = None | |||||
self.label_vocab = None | |||||
def load(self, model_name, config_file="config", section_name="model"): | |||||
""" | |||||
Load a pre-trained FastNLP model together with additional data. | |||||
:param model_name: str, the name of a FastNLP model. | |||||
:param config_file: str, the name of the config file which stores the initialization information of the model. | |||||
(default: "config") | |||||
:param section_name: str, the name of the corresponding section in the config file. (default: model) | |||||
""" | |||||
assert type(model_name) is str | |||||
if model_name not in FastNLP_MODEL_COLLECTION: | |||||
raise ValueError("No FastNLP model named {}.".format(model_name)) | |||||
if not self.model_exist(model_dir=self.model_dir): | |||||
self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"]) | |||||
model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"]) | |||||
print("Restore model class {}".format(str(model_class))) | |||||
model_args = ConfigSection() | |||||
ConfigLoader.load_config(os.path.join(self.model_dir, config_file), {section_name: model_args}) | |||||
print("Restore model hyper-parameters {}".format(str(model_args.data))) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
self.word_vocab = load_pickle(self.model_dir, "word2id.pkl") | |||||
model_args["vocab_size"] = len(self.word_vocab) | |||||
self.label_vocab = load_pickle(self.model_dir, "label2id.pkl") | |||||
model_args["num_classes"] = len(self.label_vocab) | |||||
# Construct the model | |||||
model = model_class(model_args) | |||||
print("Model constructed.") | |||||
# To do: framework independent | |||||
ModelLoader.load_pytorch(model, os.path.join(self.model_dir, FastNLP_MODEL_COLLECTION[model_name]["pickle"])) | |||||
print("Model weights loaded.") | |||||
self.model = model | |||||
self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"] | |||||
print("Inference ready.") | |||||
def run(self, raw_input): | |||||
""" | |||||
Perform inference over given input using the loaded model. | |||||
:param raw_input: list of string. Each list is an input query. | |||||
:return results: | |||||
""" | |||||
infer = self._create_inference(self.model_dir) | |||||
# tokenize: list of string ---> 2-D list of string | |||||
infer_input = self.tokenize(raw_input, language="zh") | |||||
# create DataSet: 2-D list of strings ----> DataSet | |||||
infer_data = self._create_data_set(infer_input) | |||||
# DataSet ---> 2-D list of tags | |||||
results = infer.predict(self.model, infer_data) | |||||
# 2-D list of tags ---> list of final answers | |||||
outputs = self._make_output(results, infer_input) | |||||
return outputs | |||||
@staticmethod | |||||
def _get_model_class(file_class_name): | |||||
""" | |||||
Fetch the class specified by <file_class_name>
:param file_class_name: str, contains the name of the Python module followed by the name of the class. | |||||
Example: "sequence_modeling.SeqLabeling" | |||||
:return module: the model class | |||||
""" | |||||
import_prefix = "fastNLP.models." | |||||
parts = (import_prefix + file_class_name).split(".") | |||||
from_module = ".".join(parts[:-1]) | |||||
module = __import__(from_module) | |||||
for sub in parts[1:]: | |||||
module = getattr(module, sub) | |||||
return module | |||||
def _create_inference(self, model_dir): | |||||
"""Specify which task to perform. | |||||
:param model_dir: | |||||
:return: | |||||
""" | |||||
if self.infer_type == "seq_label": | |||||
return SeqLabelInfer(model_dir) | |||||
elif self.infer_type == "text_class": | |||||
return ClassificationInfer(model_dir) | |||||
else: | |||||
raise ValueError("fail to create inference instance") | |||||
def _create_data_set(self, infer_input): | |||||
"""Create a DataSet object given the raw inputs. | |||||
:param infer_input: 2-D lists of strings | |||||
:return data_set: a DataSet object | |||||
""" | |||||
if self.infer_type in ["seq_label", "text_class"]: | |||||
data_set = convert_seq_dataset(infer_input) | |||||
data_set.index_field("word_seq", self.word_vocab) | |||||
if self.infer_type == "seq_label": | |||||
data_set.set_origin_len("word_seq") | |||||
return data_set | |||||
else: | |||||
raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type)) | |||||
def _load(self, model_dir, model_name): | |||||
return 0 | |||||
def _download(self, model_name, url): | |||||
""" | |||||
Download the model weights from <url> and save in <self.model_dir>. | |||||
:param model_name: | |||||
:param url: | |||||
""" | |||||
print("Downloading {} from {}".format(model_name, url)) | |||||
# TODO: download model via url | |||||
def model_exist(self, model_dir): | |||||
""" | |||||
Check whether the desired model is already in the directory. | |||||
:param model_dir: | |||||
""" | |||||
return True | |||||
def tokenize(self, text, language): | |||||
"""Extract tokens from strings. | |||||
For English, extract words separated by space. | |||||
For Chinese, extract characters. | |||||
TODO: more complex tokenization methods | |||||
:param text: list of string | |||||
:param language: str, one of ('zh', 'en'), Chinese or English. | |||||
:return data: list of list of string, each string is a token. | |||||
""" | |||||
assert language in ("zh", "en") | |||||
data = [] | |||||
for sent in text: | |||||
if language == "en": | |||||
tokens = sent.strip().split() | |||||
elif language == "zh": | |||||
tokens = [char for char in sent] | |||||
else: | |||||
raise RuntimeError("Unknown language {}".format(language)) | |||||
data.append(tokens) | |||||
return data | |||||
def _make_output(self, results, infer_input): | |||||
"""Transform the infer output into user-friendly output. | |||||
:param results: 1 or 2-D list of strings. | |||||
If self.infer_type == "seq_label", it is of shape [num_examples, tag_seq_length] | |||||
If self.infer_type == "text_class", it is of shape [num_examples] | |||||
:param infer_input: 2-D list of string, the input query before inference. | |||||
:return outputs: list. Each entry is a prediction. | |||||
""" | |||||
if self.infer_type == "seq_label": | |||||
outputs = make_seq_label_output(results, infer_input) | |||||
elif self.infer_type == "text_class": | |||||
outputs = make_class_output(results, infer_input) | |||||
else: | |||||
raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type)) | |||||
return outputs | |||||
def make_seq_label_output(result, infer_input): | |||||
"""Transform model output into user-friendly contents. | |||||
:param result: 2-D list of strings. (model output) | |||||
:param infer_input: 2-D list of string (model input) | |||||
:return ret: list of list of tuples | |||||
[ | |||||
[(word_11, label_11), (word_12, label_12), ...], | |||||
[(word_21, label_21), (word_22, label_22), ...], | |||||
... | |||||
] | |||||
""" | |||||
ret = [] | |||||
for example_x, example_y in zip(infer_input, result): | |||||
ret.append([(x, y) for x, y in zip(example_x, example_y)]) | |||||
return ret | |||||
def make_class_output(result, infer_input): | |||||
"""Transform model output into user-friendly contents. | |||||
:param result: 2-D list of strings. (model output) | |||||
:param infer_input: 1-D list of string (model input) | |||||
:return ret: the same as result, [label_1, label_2, ...] | |||||
""" | |||||
return result | |||||
def interpret_word_seg_results(char_seq, label_seq): | |||||
"""Transform model output into user-friendly contents. | |||||
Example: In CWS, convert <BMES> labeling into segmented text. | |||||
:param char_seq: list of string, | |||||
:param label_seq: list of string, the same length as char_seq | |||||
Each entry is one of ('B', 'M', 'E', 'S'). | |||||
:return output: list of words | |||||
""" | |||||
words = [] | |||||
word = "" | |||||
for char, label in zip(char_seq, label_seq): | |||||
if label[0] == "B": | |||||
if word != "": | |||||
words.append(word) | |||||
word = char | |||||
elif label[0] == "M": | |||||
word += char | |||||
elif label[0] == "E": | |||||
word += char | |||||
words.append(word) | |||||
word = "" | |||||
elif label[0] == "S": | |||||
if word != "": | |||||
words.append(word) | |||||
word = "" | |||||
words.append(char) | |||||
else: | |||||
raise ValueError("invalid label {}".format(label[0])) | |||||
return words | |||||
def interpret_cws_pos_results(char_seq, label_seq): | |||||
"""Transform model output into user-friendly contents. | |||||
:param char_seq: list of string | |||||
:param label_seq: list of string, the same length as char_seq. | |||||
:return outputs: list of tuple (words, pos_tag): | |||||
""" | |||||
def pos_tag_check(seq): | |||||
"""check whether all entries are the same """ | |||||
return len(set(seq)) <= 1 | |||||
word = [] | |||||
word_pos = [] | |||||
outputs = [] | |||||
for char, label in zip(char_seq, label_seq): | |||||
tmp = label.split("-") | |||||
cws_label, pos_tag = tmp[0], tmp[1] | |||||
if cws_label == "B" or cws_label == "M": | |||||
word.append(char) | |||||
word_pos.append(pos_tag) | |||||
elif cws_label == "E": | |||||
word.append(char) | |||||
word_pos.append(pos_tag) | |||||
if not pos_tag_check(word_pos): | |||||
raise RuntimeError("character-wise pos tags inconsistent. ") | |||||
outputs.append(("".join(word), word_pos[0])) | |||||
word.clear() | |||||
word_pos.clear() | |||||
elif cws_label == "S": | |||||
outputs.append((char, pos_tag)) | |||||
return outputs |
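# Illustrative sketch (not part of this PR): what the removed helper
# interpret_word_seg_results did, turning character-level BMES labels back
# into segmented words.
def _demo_interpret_word_seg_results():
    chars = list("这是最好的")
    labels = ["S", "S", "B", "E", "S"]
    assert interpret_word_seg_results(chars, labels) == ["这", "是", "最好", "的"]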
@@ -0,0 +1,51 @@ | |||||
import _pickle as pickle | |||||
import os | |||||
class BaseLoader(object): | |||||
def __init__(self): | |||||
super(BaseLoader, self).__init__() | |||||
@staticmethod | |||||
def load_lines(data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f:
text = f.readlines() | |||||
return [line.strip() for line in text] | |||||
@classmethod | |||||
def load(cls, data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f: | |||||
text = f.readlines() | |||||
return [[word for word in sent.strip()] for sent in text] | |||||
@classmethod | |||||
def load_with_cache(cls, data_path, cache_path): | |||||
if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path): | |||||
with open(cache_path, 'rb') as f: | |||||
return pickle.load(f) | |||||
else: | |||||
obj = cls.load(data_path) | |||||
with open(cache_path, 'wb') as f: | |||||
pickle.dump(obj, f) | |||||
return obj | |||||
class DataLoaderRegister: | |||||
""""register for data sets""" | |||||
_readers = {} | |||||
@classmethod | |||||
def set_reader(cls, reader_cls, read_fn_name): | |||||
# def wrapper(reader_cls): | |||||
if read_fn_name in cls._readers: | |||||
raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, read_fn_name)) | |||||
if hasattr(reader_cls, 'load'): | |||||
cls._readers[read_fn_name] = reader_cls().load | |||||
return reader_cls | |||||
@classmethod | |||||
def get_reader(cls, read_fn_name): | |||||
if read_fn_name in cls._readers: | |||||
return cls._readers[read_fn_name] | |||||
raise AttributeError('no read function: {}'.format(read_fn_name)) |
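# Illustrative sketch (not part of this PR): registering a loader class and
# fetching its bound load function by name. The reader name is hypothetical.
def _demo_data_loader_register():
    class TxtLoader(BaseLoader):
        pass  # inherits BaseLoader.load, which splits each line into characters

    DataLoaderRegister.set_reader(TxtLoader, 'read_txt_demo')
    read_fn = DataLoaderRegister.get_reader('read_txt_demo')
    # read_fn is TxtLoader().load and can now be called with a file path,
    # e.g. data = read_fn('some_corpus.txt')
    assert callable(read_fn)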
@@ -0,0 +1,295 @@ | |||||
import configparser | |||||
import json | |||||
import os | |||||
from fastNLP.io.base_loader import BaseLoader | |||||
class ConfigLoader(BaseLoader): | |||||
"""loader for configuration files""" | |||||
def __init__(self, data_path=None): | |||||
super(ConfigLoader, self).__init__() | |||||
if data_path is not None: | |||||
self.config = self.parse(super(ConfigLoader, self).load(data_path)) | |||||
@staticmethod | |||||
def parse(string): | |||||
raise NotImplementedError | |||||
@staticmethod | |||||
def load_config(file_path, sections): | |||||
""" | |||||
:param file_path: the path of config file | |||||
:param sections: the dict of {section_name(string): Section instance} | |||||
Example: | |||||
test_args = ConfigSection() | |||||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
:return: returns nothing, but the attribute values are saved into the given sections
""" | |||||
assert isinstance(sections, dict) | |||||
cfg = configparser.ConfigParser() | |||||
if not os.path.exists(file_path): | |||||
raise FileNotFoundError("config file {} not found. ".format(file_path)) | |||||
cfg.read(file_path) | |||||
for s in sections: | |||||
attr_list = [i for i in sections[s].__dict__.keys() if | |||||
not callable(getattr(sections[s], i)) and not i.startswith("__")] | |||||
if s not in cfg: | |||||
print('section %s not found in config file' % (s)) | |||||
continue | |||||
gen_sec = cfg[s] | |||||
for attr in gen_sec.keys(): | |||||
try: | |||||
val = json.loads(gen_sec[attr]) | |||||
# print(s, attr, val, type(val)) | |||||
if attr in attr_list: | |||||
assert type(val) == type(getattr(sections[s], attr)), \ | |||||
'type mismatch, expected %s but got %s' % \
(type(getattr(sections[s], attr)), type(val)) | |||||
""" | |||||
if attr in attr_list then check its type and | |||||
update its value. | |||||
else add a new attr in sections[s] | |||||
""" | |||||
setattr(sections[s], attr, val) | |||||
except Exception as e: | |||||
print("cannot load attribute %s in section %s" | |||||
% (attr, s)) | |||||
pass | |||||
class ConfigSection(object): | |||||
def __init__(self): | |||||
pass | |||||
def __getitem__(self, key): | |||||
""" | |||||
:param key: str, the name of the attribute | |||||
:return attr: the value of this attribute | |||||
if key in self.__dict__.keys():
return getattr(self, key)
else:
raise AttributeError
""" | |||||
if key in self.__dict__.keys(): | |||||
return getattr(self, key) | |||||
raise AttributeError("do NOT have attribute %s" % key) | |||||
def __setitem__(self, key, value): | |||||
""" | |||||
:param key: str, the name of the attribute | |||||
:param value: the value of this attribute | |||||
if key not in self.__dict__.keys(): | |||||
self[key] will be added | |||||
else: | |||||
self[key] will be updated | |||||
""" | |||||
if key in self.__dict__.keys(): | |||||
if not isinstance(value, type(getattr(self, key))): | |||||
raise AttributeError("attr %s expected %s but got %s" %
(key, str(type(getattr(self, key))), str(type(value)))) | |||||
setattr(self, key, value) | |||||
def __contains__(self, item): | |||||
""" | |||||
:param item: The key of item. | |||||
:return: True if the key in self.__dict__.keys() else False. | |||||
""" | |||||
return item in self.__dict__.keys() | |||||
def __eq__(self, other): | |||||
"""Overwrite the == operator | |||||
:param other: Another ConfigSection() object which to be compared. | |||||
:return: True if value of each key in each ConfigSection() object are equal to the other, else False. | |||||
""" | |||||
for k in self.__dict__.keys(): | |||||
if k not in other.__dict__.keys(): | |||||
return False | |||||
if getattr(self, k) != getattr(other, k):
return False | |||||
for k in other.__dict__.keys(): | |||||
if k not in self.__dict__.keys(): | |||||
return False | |||||
if getattr(self, k) != getattr(other, k):
return False | |||||
return True | |||||
def __ne__(self, other): | |||||
"""Overwrite the != operator | |||||
:param other: | |||||
:return: | |||||
""" | |||||
return not self.__eq__(other) | |||||
@property | |||||
def data(self): | |||||
return self.__dict__ | |||||
if __name__ == "__main__": | |||||
config = ConfigLoader('there is no data') | |||||
section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} | |||||
""" | |||||
General and My can be found in config file, so the attr and | |||||
value will be updated | |||||
A cannot be found in config file, so nothing will be done | |||||
""" | |||||
config.load_config("../../test/data_for_tests/config", section) | |||||
for s in section: | |||||
print(s) | |||||
for attr in section[s].__dict__.keys(): | |||||
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) | |||||
class ConfigSaver(object): | |||||
def __init__(self, file_path): | |||||
self.file_path = file_path | |||||
if not os.path.exists(self.file_path): | |||||
raise FileNotFoundError("file {} NOT found!".format(self.file_path))
def _get_section(self, sect_name): | |||||
"""This is the function to get the section with the section name. | |||||
:param sect_name: The name of the section to load.
:return: The section. | |||||
""" | |||||
sect = ConfigSection() | |||||
ConfigLoader().load_config(self.file_path, {sect_name: sect}) | |||||
return sect | |||||
def _read_section(self): | |||||
"""This is the function to read sections from the config file. | |||||
:return: sect_list, sect_key_list | |||||
sect_list: A dict mapping each section name to its (values, keys) pair.
sect_key_list: A list of names in sect_list. | |||||
""" | |||||
sect_name = None | |||||
sect_list = {} | |||||
sect_key_list = [] | |||||
single_section = {} | |||||
single_section_key = [] | |||||
with open(self.file_path, 'r') as f: | |||||
lines = f.readlines() | |||||
for line in lines: | |||||
if line.startswith('[') and line.endswith(']\n'): | |||||
if sect_name is None: | |||||
pass | |||||
else: | |||||
sect_list[sect_name] = single_section, single_section_key | |||||
single_section = {} | |||||
single_section_key = [] | |||||
sect_key_list.append(sect_name) | |||||
sect_name = line[1: -2] | |||||
continue | |||||
if line.startswith('#'): | |||||
single_section[line] = '#' | |||||
single_section_key.append(line) | |||||
continue | |||||
if line.startswith('\n'): | |||||
single_section_key.append('\n') | |||||
continue | |||||
if '=' not in line: | |||||
# log = create_logger(__name__, './config_saver.log') | |||||
# log.error("can NOT load config file [%s]" % self.file_path) | |||||
raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | |||||
key = line.split('=', maxsplit=1)[0].strip() | |||||
value = line.split('=', maxsplit=1)[1].strip() + '\n' | |||||
single_section[key] = value | |||||
single_section_key.append(key) | |||||
if sect_name is not None: | |||||
sect_list[sect_name] = single_section, single_section_key | |||||
sect_key_list.append(sect_name) | |||||
return sect_list, sect_key_list | |||||
def _write_section(self, sect_list, sect_key_list): | |||||
"""This is the function to write config file with section list and name list. | |||||
:param sect_list: A dict of sections (as returned by _read_section) that need to be written into the file.
:param sect_key_list: A list of the section names in sect_list.
:return: | |||||
""" | |||||
with open(self.file_path, 'w') as f: | |||||
for sect_key in sect_key_list: | |||||
single_section, single_section_key = sect_list[sect_key] | |||||
f.write('[' + sect_key + ']\n') | |||||
for key in single_section_key: | |||||
if key == '\n': | |||||
f.write('\n') | |||||
continue | |||||
if single_section[key] == '#': | |||||
f.write(key) | |||||
continue | |||||
f.write(key + ' = ' + single_section[key]) | |||||
f.write('\n') | |||||
def save_config_file(self, section_name, section): | |||||
"""This is the function to be called to change the config file with a single section and its name. | |||||
:param section_name: The name of the section that needs to be changed and saved.
:param section: The section with keys and values that needs to be changed and saved.
:return: | |||||
""" | |||||
section_file = self._get_section(section_name) | |||||
if len(section_file.__dict__.keys()) == 0: # the section not in the file before | |||||
# append this section to config file | |||||
with open(self.file_path, 'a') as f: | |||||
f.write('[' + section_name + ']\n') | |||||
for k in section.__dict__.keys(): | |||||
f.write(k + ' = ') | |||||
if isinstance(section[k], str): | |||||
f.write('\"' + str(section[k]) + '\"\n\n') | |||||
else: | |||||
f.write(str(section[k]) + '\n\n') | |||||
else: | |||||
# the section exists | |||||
change_file = False | |||||
for k in section.__dict__.keys(): | |||||
if k not in section_file: | |||||
# find a new key in this section | |||||
change_file = True | |||||
break | |||||
if section_file[k] != section[k]: | |||||
# logger = create_logger(__name__, "./config_loader.log") | |||||
# logger.warning("section [%s] in config file [%s] has been changed" % ( | |||||
# section_name, self.file_path | |||||
# )) | |||||
change_file = True | |||||
break | |||||
if not change_file: | |||||
return | |||||
sect_list, sect_key_list = self._read_section() | |||||
if section_name not in sect_key_list: | |||||
raise AttributeError() | |||||
sect, sect_key = sect_list[section_name] | |||||
for k in section.__dict__.keys(): | |||||
if k not in sect_key: | |||||
if sect_key[-1] != '\n': | |||||
sect_key.append('\n') | |||||
sect_key.append(k) | |||||
sect[k] = str(section[k]) | |||||
if isinstance(section[k], str): | |||||
sect[k] = "\"" + sect[k] + "\"" | |||||
sect[k] = sect[k] + "\n" | |||||
sect_list[section_name] = sect, sect_key | |||||
self._write_section(sect_list, sect_key_list) |
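# Illustrative sketch (not part of this PR): round-tripping one section through
# ConfigSaver and ConfigLoader. The file name and keys are hypothetical, and
# the config file must already exist on disk.
def _demo_config_saver():
    section = ConfigSection()
    section["batch_size"] = 32
    section["lr"] = 0.001
    saver = ConfigSaver("demo.cfg")            # raises FileNotFoundError if demo.cfg is missing
    saver.save_config_file("train", section)   # create or update the [train] section

    loaded = ConfigSection()
    ConfigLoader.load_config("demo.cfg", {"train": loaded})
    assert loaded["batch_size"] == 32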
@@ -1,9 +1,8 @@ | |||||
import os | import os | ||||
from fastNLP.loader.base_loader import BaseLoader | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.field import * | |||||
from fastNLP.io.base_loader import DataLoaderRegister | |||||
def convert_seq_dataset(data): | def convert_seq_dataset(data): | ||||
@@ -20,8 +19,7 @@ def convert_seq_dataset(data): | |||||
""" | """ | ||||
dataset = DataSet() | dataset = DataSet() | ||||
for word_seq in data: | for word_seq in data: | ||||
x = TextField(word_seq, is_target=False) | |||||
dataset.append(Instance(word_seq=x)) | |||||
dataset.append(Instance(word_seq=word_seq)) | |||||
return dataset | return dataset | ||||
@@ -40,11 +38,7 @@ def convert_seq2tag_dataset(data): | |||||
""" | """ | ||||
dataset = DataSet() | dataset = DataSet() | ||||
for sample in data: | for sample in data: | ||||
word_seq, label = sample[0], sample[1] | |||||
ins = Instance() | |||||
ins.add_field("word_seq", TextField(word_seq, is_target=False)) \ | |||||
.add_field("label", LabelField(label, is_target=True)) | |||||
dataset.append(ins) | |||||
dataset.append(Instance(word_seq=sample[0], label=sample[1])) | |||||
return dataset | return dataset | ||||
@@ -63,20 +57,13 @@ def convert_seq2seq_dataset(data): | |||||
""" | """ | ||||
dataset = DataSet() | dataset = DataSet() | ||||
for sample in data: | for sample in data: | ||||
word_seq, label_seq = sample[0], sample[1] | |||||
ins = Instance() | |||||
ins.add_field("word_seq", TextField(word_seq, is_target=False)) \ | |||||
.add_field("label_seq", TextField(label_seq, is_target=True)) | |||||
dataset.append(ins) | |||||
dataset.append(Instance(word_seq=sample[0], label_seq=sample[1])) | |||||
return dataset | return dataset | ||||
class DataSetLoader(BaseLoader): | |||||
class DataSetLoader: | |||||
""""loader for data sets""" | """"loader for data sets""" | ||||
def __init__(self): | |||||
super(DataSetLoader, self).__init__() | |||||
def load(self, path): | def load(self, path): | ||||
""" load data in `path` into a dataset | """ load data in `path` into a dataset | ||||
""" | """ | ||||
@@ -88,7 +75,20 @@ class DataSetLoader(BaseLoader): | |||||
raise NotImplementedError | raise NotImplementedError | ||||
@DataSet.set_reader('read_raw') | |||||
class NativeDataSetLoader(DataSetLoader): | |||||
def __init__(self): | |||||
super(NativeDataSetLoader, self).__init__() | |||||
def load(self, path): | |||||
ds = DataSet.read_csv(path, headers=("raw_sentence", "label"), sep="\t") | |||||
ds.set_input("raw_sentence") | |||||
ds.set_target("label") | |||||
return ds | |||||
DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive') | |||||
class RawDataSetLoader(DataSetLoader): | class RawDataSetLoader(DataSetLoader): | ||||
def __init__(self): | def __init__(self): | ||||
super(RawDataSetLoader, self).__init__() | super(RawDataSetLoader, self).__init__() | ||||
@@ -104,7 +104,9 @@ class RawDataSetLoader(DataSetLoader): | |||||
return convert_seq_dataset(data) | return convert_seq_dataset(data) | ||||
@DataSet.set_reader('read_pos') | |||||
DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') | |||||
class POSDataSetLoader(DataSetLoader): | class POSDataSetLoader(DataSetLoader): | ||||
"""Dataset Loader for POS Tag datasets. | """Dataset Loader for POS Tag datasets. | ||||
@@ -174,7 +176,9 @@ class POSDataSetLoader(DataSetLoader): | |||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
@DataSet.set_reader('read_tokenize') | |||||
DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos') | |||||
class TokenizeDataSetLoader(DataSetLoader): | class TokenizeDataSetLoader(DataSetLoader): | ||||
""" | """ | ||||
Data set loader for tokenization data sets | Data set loader for tokenization data sets | ||||
@@ -234,7 +238,6 @@ class TokenizeDataSetLoader(DataSetLoader): | |||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
@DataSet.set_reader('read_class') | |||||
class ClassDataSetLoader(DataSetLoader): | class ClassDataSetLoader(DataSetLoader): | ||||
"""Loader for classification data sets""" | """Loader for classification data sets""" | ||||
@@ -273,7 +276,6 @@ class ClassDataSetLoader(DataSetLoader): | |||||
return convert_seq2tag_dataset(data) | return convert_seq2tag_dataset(data) | ||||
@DataSet.set_reader('read_conll') | |||||
class ConllLoader(DataSetLoader): | class ConllLoader(DataSetLoader): | ||||
"""loader for conll format files""" | """loader for conll format files""" | ||||
@@ -315,7 +317,6 @@ class ConllLoader(DataSetLoader): | |||||
pass | pass | ||||
@DataSet.set_reader('read_lm') | |||||
class LMDataSetLoader(DataSetLoader): | class LMDataSetLoader(DataSetLoader): | ||||
"""Language Model Dataset Loader | """Language Model Dataset Loader | ||||
@@ -352,7 +353,6 @@ class LMDataSetLoader(DataSetLoader): | |||||
pass | pass | ||||
@DataSet.set_reader('read_people_daily') | |||||
class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
""" | """ | ||||
People Daily Corpus: Chinese word segmentation, POS tag, NER | People Daily Corpus: Chinese word segmentation, POS tag, NER | ||||
@@ -368,6 +368,8 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||||
pos_tag_examples = [] | pos_tag_examples = [] | ||||
ner_examples = [] | ner_examples = [] | ||||
for sent in sents: | for sent in sents: | ||||
if len(sent) <= 2: | |||||
continue | |||||
inside_ne = False | inside_ne = False | ||||
sent_pos_tag = [] | sent_pos_tag = [] | ||||
sent_words = [] | sent_words = [] | ||||
@@ -400,10 +402,20 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||||
sent_words.append(token) | sent_words.append(token) | ||||
pos_tag_examples.append([sent_words, sent_pos_tag]) | pos_tag_examples.append([sent_words, sent_pos_tag]) | ||||
ner_examples.append([sent_words, sent_ner]) | ner_examples.append([sent_words, sent_ner]) | ||||
return pos_tag_examples, ner_examples | |||||
# List[List[List[str], List[str]]] | |||||
# ner_examples not used | |||||
return self.convert(pos_tag_examples) | |||||
def convert(self, data): | def convert(self, data): | ||||
pass | |||||
data_set = DataSet() | |||||
for item in data: | |||||
sent_words, sent_pos_tag = item[0], item[1] | |||||
data_set.append(Instance(words=sent_words, tags=sent_pos_tag)) | |||||
data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
data_set.set_target("tags") | |||||
data_set.set_input("sent_words") | |||||
data_set.set_input("seq_len") | |||||
return data_set | |||||
class SNLIDataSetLoader(DataSetLoader): | class SNLIDataSetLoader(DataSetLoader): | ||||
@@ -459,17 +471,13 @@ class SNLIDataSetLoader(DataSetLoader): | |||||
for example in data: | for example in data: | ||||
p, h, l = example | p, h, l = example | ||||
# list, list, str | # list, list, str | ||||
x1 = TextField(p, is_target=False) | |||||
x2 = TextField(h, is_target=False) | |||||
x1_len = TextField([1] * len(p), is_target=False) | |||||
x2_len = TextField([1] * len(h), is_target=False) | |||||
y = LabelField(l, is_target=True) | |||||
instance = Instance() | instance = Instance() | ||||
instance.add_field("premise", x1) | |||||
instance.add_field("hypothesis", x2) | |||||
instance.add_field("premise_len", x1_len) | |||||
instance.add_field("hypothesis_len", x2_len) | |||||
instance.add_field("truth", y) | |||||
instance.add_field("premise", p) | |||||
instance.add_field("hypothesis", h) | |||||
instance.add_field("truth", l) | |||||
data_set.append(instance) | data_set.append(instance) | ||||
data_set.apply(lambda ins: len(ins["premise"]), new_field_name="premise_len") | |||||
data_set.apply(lambda ins: len(ins["hypothesis"]), new_field_name="hypothesis_len") | |||||
data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len") | |||||
data_set.set_target("truth") | |||||
return data_set | return data_set |
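# Illustrative sketch (not part of this PR): the pattern the converters above
# now follow, building a DataSet from plain Python fields and then marking
# inputs and targets. Field names and data are made up.
def _demo_convert():
    data = [(["I", "like", "it"], "positive"), (["so", "bad"], "negative")]
    ds = DataSet()
    for word_seq, label in data:
        ds.append(Instance(word_seq=word_seq, label=label))
    ds.apply(lambda ins: len(ins["word_seq"]), new_field_name="seq_len")
    ds.set_input("word_seq", "seq_len")
    ds.set_target("label")
    return ds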
@@ -0,0 +1,113 @@ | |||||
import numpy as np | |||||
import torch | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.io.base_loader import BaseLoader | |||||
class EmbedLoader(BaseLoader): | |||||
"""docstring for EmbedLoader""" | |||||
def __init__(self): | |||||
super(EmbedLoader, self).__init__() | |||||
@staticmethod | |||||
def _load_glove(emb_file): | |||||
"""Read file as a glove embedding | |||||
file format: | |||||
embeddings are split by line, | |||||
for one embedding, word and numbers split by space | |||||
Example:: | |||||
word_1 float_1 float_2 ... float_emb_dim | |||||
word_2 float_1 float_2 ... float_emb_dim | |||||
... | |||||
""" | |||||
emb = {} | |||||
with open(emb_file, 'r', encoding='utf-8') as f: | |||||
for line in f: | |||||
line = list(filter(lambda w: len(w) > 0, line.strip().split(' '))) | |||||
if len(line) > 2: | |||||
emb[line[0]] = torch.Tensor(list(map(float, line[1:]))) | |||||
return emb | |||||
@staticmethod | |||||
def _load_pretrain(emb_file, emb_type): | |||||
"""Read txt data from embedding file and convert to np.array as pre-trained embedding | |||||
:param str emb_file: the pre-trained embedding file path | |||||
:param str emb_type: the pre-trained embedding data format | |||||
:return dict embedding: `{str: np.array}` | |||||
""" | |||||
if emb_type == 'glove': | |||||
return EmbedLoader._load_glove(emb_file) | |||||
else: | |||||
raise Exception("embedding type {} not support yet".format(emb_type)) | |||||
@staticmethod | |||||
def load_embedding(emb_dim, emb_file, emb_type, vocab): | |||||
"""Load the pre-trained embedding and combine with the given dictionary. | |||||
:param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding. | |||||
:param str emb_file: the pre-trained embedding file path. | |||||
:param str emb_type: the pre-trained embedding format, support glove now | |||||
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding | |||||
:return embedding_tensor: Tensor of shape (len(word_dict), emb_dim) | |||||
vocab: input vocab or vocab built by pre-train | |||||
""" | |||||
pretrain = EmbedLoader._load_pretrain(emb_file, emb_type) | |||||
if vocab is None: | |||||
# build vocabulary from pre-trained embedding | |||||
vocab = Vocabulary() | |||||
for w in pretrain.keys(): | |||||
vocab.add(w) | |||||
embedding_tensor = torch.randn(len(vocab), emb_dim) | |||||
for w, v in pretrain.items(): | |||||
if len(v.shape) > 1 or emb_dim != v.shape[0]: | |||||
raise ValueError( | |||||
"Pretrained embedding dim is {}. Dimension dismatched. Required {}".format(v.shape, (emb_dim,))) | |||||
if vocab.has_word(w): | |||||
embedding_tensor[vocab[w]] = v | |||||
return embedding_tensor, vocab | |||||
@staticmethod | |||||
def parse_glove_line(line): | |||||
line = list(filter(lambda w: len(w) > 0, line.strip().split(" "))) | |||||
if len(line) <= 2: | |||||
raise RuntimeError("something goes wrong in parsing glove embedding") | |||||
return line[0], torch.Tensor(list(map(float, line[1:]))) | |||||
@staticmethod | |||||
def fast_load_embedding(emb_dim, emb_file, vocab): | |||||
"""Fast load the pre-trained embedding and combine with the given dictionary. | |||||
This method reads the embedding file line by line instead of loading it entirely into memory.
:param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding. | |||||
:param str emb_file: the pre-trained embedding file path. | |||||
:param Vocabulary vocab: a mapping from word to index; must be provided, this method does not build one from the embedding file
:return numpy.ndarray embedding_matrix: shape (len(vocab), emb_dim)
""" | |||||
if vocab is None: | |||||
raise RuntimeError("You must provide a vocabulary.") | |||||
embedding_matrix = np.zeros(shape=(len(vocab), emb_dim)) | |||||
hit_flags = np.zeros(shape=(len(vocab),), dtype=int) | |||||
with open(emb_file, "r", encoding="utf-8") as f: | |||||
for line in f: | |||||
word, vector = EmbedLoader.parse_glove_line(line) | |||||
if word in vocab: | |||||
if len(vector.shape) > 1 or emb_dim != vector.shape[0]: | |||||
raise ValueError("Pre-trained embedding dim is {}. Expect {}.".format(vector.shape, (emb_dim,))) | |||||
embedding_matrix[vocab[word]] = vector | |||||
hit_flags[vocab[word]] = 1 | |||||
if np.sum(hit_flags) < len(vocab): | |||||
# some words in vocab are missing from the pre-trained embedding,
# so sample each of their dimensions from a normal distribution fitted to the vectors that were found
vocab_embed = embedding_matrix[np.where(hit_flags)] | |||||
sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0), | |||||
size=(len(vocab) - np.sum(hit_flags), emb_dim)) | |||||
embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors | |||||
return embedding_matrix |
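fast_load_embedding() scans a GloVe-format file line by line, copies the vectors of in-vocabulary words into the matrix, and samples the remaining rows from a normal distribution fitted to the vectors it did find. A usage sketch with a hypothetical 5-dimensional file toy_glove.txt; the module path fastNLP.io.embed_loader is an assumption::

    from fastNLP.core.vocabulary import Vocabulary
    from fastNLP.io.embed_loader import EmbedLoader  # assumed module path

    vocab = Vocabulary()
    for w in ["the", "cat", "sat"]:
        vocab.add(w)

    # toy_glove.txt (hypothetical) contains lines such as:
    #   the 0.1 0.2 0.3 0.4 0.5
    #   cat 0.5 0.4 0.3 0.2 0.1
    matrix = EmbedLoader.fast_load_embedding(emb_dim=5, emb_file="toy_glove.txt", vocab=vocab)
    print(matrix.shape)  # (len(vocab), 5); "sat" gets a sampled vector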
@@ -0,0 +1,56 @@ | |||||
import torch | |||||
from fastNLP.io.base_loader import BaseLoader | |||||
class ModelLoader(BaseLoader): | |||||
""" | |||||
Loader for models. | |||||
""" | |||||
def __init__(self): | |||||
super(ModelLoader, self).__init__() | |||||
@staticmethod | |||||
def load_pytorch(empty_model, model_path): | |||||
""" | |||||
Load model parameters from .pkl files into the empty PyTorch model. | |||||
:param empty_model: a PyTorch model with initialized parameters. | |||||
:param model_path: str, the path to the saved model. | |||||
""" | |||||
empty_model.load_state_dict(torch.load(model_path)) | |||||
@staticmethod | |||||
def load_pytorch_model(model_path): | |||||
"""Load the entire model. | |||||
""" | |||||
return torch.load(model_path) | |||||
class ModelSaver(object): | |||||
"""Save a model | |||||
Example:: | |||||
saver = ModelSaver("./save/model_ckpt_100.pkl") | |||||
saver.save_pytorch(model) | |||||
""" | |||||
def __init__(self, save_path): | |||||
""" | |||||
:param save_path: str, the path of the file to save the model to.
""" | |||||
self.save_path = save_path | |||||
def save_pytorch(self, model, param_only=True): | |||||
"""Save a pytorch model into .pkl file. | |||||
:param model: a PyTorch model | |||||
:param param_only: bool, whether only to save the model parameters or the entire model. | |||||
""" | |||||
if param_only is True: | |||||
torch.save(model.state_dict(), self.save_path) | |||||
else: | |||||
torch.save(model, self.save_path) |
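ModelSaver and ModelLoader are symmetric: param_only=True pairs with ModelLoader.load_pytorch (a state dict loaded into an already-built model), while param_only=False pairs with ModelLoader.load_pytorch_model (the whole pickled module). A round-trip sketch using the classes above; the ./save directory is assumed to exist::

    import torch

    model = torch.nn.Linear(4, 2)

    # parameters only: load back into a freshly constructed model of the same shape
    ModelSaver("./save/linear_params.pkl").save_pytorch(model, param_only=True)
    empty = torch.nn.Linear(4, 2)
    ModelLoader.load_pytorch(empty, "./save/linear_params.pkl")

    # whole module: no empty model is needed when loading
    ModelSaver("./save/linear_full.pkl").save_pytorch(model, param_only=False)
    restored = ModelLoader.load_pytorch_model("./save/linear_full.pkl")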
@@ -1,32 +0,0 @@ | |||||
class BaseLoader(object): | |||||
def __init__(self): | |||||
super(BaseLoader, self).__init__() | |||||
@staticmethod | |||||
def load_lines(data_path): | |||||
with open(data_path, "r", encoding="utf=8") as f: | |||||
text = f.readlines() | |||||
return [line.strip() for line in text] | |||||
@staticmethod | |||||
def load(data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f: | |||||
text = f.readlines() | |||||
return [[word for word in sent.strip()] for sent in text] | |||||
class ToyLoader0(BaseLoader): | |||||
""" | |||||
For CharLM | |||||
""" | |||||
def __init__(self, data_path): | |||||
super(ToyLoader0, self).__init__(data_path) | |||||
def load(self): | |||||
with open(self.data_path, 'r') as f: | |||||
corpus = f.read().lower() | |||||
import re | |||||
corpus = re.sub(r"<unk>", "unk", corpus) | |||||
return corpus.split() |
@@ -1,149 +0,0 @@ | |||||
import configparser | |||||
import json | |||||
import os | |||||
from fastNLP.loader.base_loader import BaseLoader | |||||
class ConfigLoader(BaseLoader): | |||||
"""loader for configuration files""" | |||||
def __init__(self, data_path=None): | |||||
super(ConfigLoader, self).__init__() | |||||
if data_path is not None: | |||||
self.config = self.parse(super(ConfigLoader, self).load(data_path)) | |||||
@staticmethod | |||||
def parse(string): | |||||
raise NotImplementedError | |||||
@staticmethod | |||||
def load_config(file_path, sections): | |||||
""" | |||||
:param file_path: the path of config file | |||||
:param sections: the dict of {section_name(string): Section instance} | |||||
Example: | |||||
test_args = ConfigSection() | |||||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
:return: return nothing, but the value of attributes are saved in sessions | |||||
""" | |||||
assert isinstance(sections, dict) | |||||
cfg = configparser.ConfigParser() | |||||
if not os.path.exists(file_path): | |||||
raise FileNotFoundError("config file {} not found. ".format(file_path)) | |||||
cfg.read(file_path) | |||||
for s in sections: | |||||
attr_list = [i for i in sections[s].__dict__.keys() if | |||||
not callable(getattr(sections[s], i)) and not i.startswith("__")] | |||||
if s not in cfg: | |||||
print('section %s not found in config file' % (s)) | |||||
continue | |||||
gen_sec = cfg[s] | |||||
for attr in gen_sec.keys(): | |||||
try: | |||||
val = json.loads(gen_sec[attr]) | |||||
# print(s, attr, val, type(val)) | |||||
if attr in attr_list: | |||||
assert type(val) == type(getattr(sections[s], attr)), \ | |||||
'type not match, except %s but got %s' % \ | |||||
(type(getattr(sections[s], attr)), type(val)) | |||||
""" | |||||
if attr in attr_list then check its type and | |||||
update its value. | |||||
else add a new attr in sections[s] | |||||
""" | |||||
setattr(sections[s], attr, val) | |||||
except Exception as e: | |||||
print("cannot load attribute %s in section %s" | |||||
% (attr, s)) | |||||
pass | |||||
class ConfigSection(object): | |||||
def __init__(self): | |||||
pass | |||||
def __getitem__(self, key): | |||||
""" | |||||
:param key: str, the name of the attribute | |||||
:return attr: the value of this attribute | |||||
if key not in self.__dict__.keys(): | |||||
return self[key] | |||||
else: | |||||
raise AttributeError | |||||
""" | |||||
if key in self.__dict__.keys(): | |||||
return getattr(self, key) | |||||
raise AttributeError("do NOT have attribute %s" % key) | |||||
def __setitem__(self, key, value): | |||||
""" | |||||
:param key: str, the name of the attribute | |||||
:param value: the value of this attribute | |||||
if key not in self.__dict__.keys(): | |||||
self[key] will be added | |||||
else: | |||||
self[key] will be updated | |||||
""" | |||||
if key in self.__dict__.keys(): | |||||
if not isinstance(value, type(getattr(self, key))): | |||||
raise AttributeError("attr %s except %s but got %s" % | |||||
(key, str(type(getattr(self, key))), str(type(value)))) | |||||
setattr(self, key, value) | |||||
def __contains__(self, item): | |||||
""" | |||||
:param item: The key of item. | |||||
:return: True if the key in self.__dict__.keys() else False. | |||||
""" | |||||
return item in self.__dict__.keys() | |||||
def __eq__(self, other): | |||||
"""Overwrite the == operator | |||||
:param other: Another ConfigSection() object which to be compared. | |||||
:return: True if value of each key in each ConfigSection() object are equal to the other, else False. | |||||
""" | |||||
for k in self.__dict__.keys(): | |||||
if k not in other.__dict__.keys(): | |||||
return False | |||||
if getattr(self, k) != getattr(self, k): | |||||
return False | |||||
for k in other.__dict__.keys(): | |||||
if k not in self.__dict__.keys(): | |||||
return False | |||||
if getattr(self, k) != getattr(self, k): | |||||
return False | |||||
return True | |||||
def __ne__(self, other): | |||||
"""Overwrite the != operator | |||||
:param other: | |||||
:return: | |||||
""" | |||||
return not self.__eq__(other) | |||||
@property | |||||
def data(self): | |||||
return self.__dict__ | |||||
if __name__ == "__main__": | |||||
config = ConfigLoader('there is no data') | |||||
section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} | |||||
""" | |||||
General and My can be found in config file, so the attr and | |||||
value will be updated | |||||
A cannot be found in config file, so nothing will be done | |||||
""" | |||||
config.load_config("../../test/data_for_tests/config", section) | |||||
for s in section: | |||||
print(s) | |||||
for attr in section[s].__dict__.keys(): | |||||
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) |
@@ -1,85 +0,0 @@ | |||||
import _pickle | |||||
import os | |||||
import torch | |||||
from fastNLP.loader.base_loader import BaseLoader | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
class EmbedLoader(BaseLoader): | |||||
"""docstring for EmbedLoader""" | |||||
def __init__(self): | |||||
super(EmbedLoader, self).__init__() | |||||
@staticmethod | |||||
def _load_glove(emb_file): | |||||
"""Read file as a glove embedding | |||||
file format: | |||||
embeddings are split by line, | |||||
for one embedding, word and numbers split by space | |||||
Example:: | |||||
word_1 float_1 float_2 ... float_emb_dim | |||||
word_2 float_1 float_2 ... float_emb_dim | |||||
... | |||||
""" | |||||
emb = {} | |||||
with open(emb_file, 'r', encoding='utf-8') as f: | |||||
for line in f: | |||||
line = list(filter(lambda w: len(w)>0, line.strip().split(' '))) | |||||
if len(line) > 0: | |||||
emb[line[0]] = torch.Tensor(list(map(float, line[1:]))) | |||||
return emb | |||||
@staticmethod | |||||
def _load_pretrain(emb_file, emb_type): | |||||
"""Read txt data from embedding file and convert to np.array as pre-trained embedding | |||||
:param emb_file: str, the pre-trained embedding file path | |||||
:param emb_type: str, the pre-trained embedding data format | |||||
:return dict: {str: np.array} | |||||
""" | |||||
if emb_type == 'glove': | |||||
return EmbedLoader._load_glove(emb_file) | |||||
else: | |||||
raise Exception("embedding type {} not support yet".format(emb_type)) | |||||
@staticmethod | |||||
def load_embedding(emb_dim, emb_file, emb_type, vocab, emb_pkl): | |||||
"""Load the pre-trained embedding and combine with the given dictionary. | |||||
:param emb_dim: int, the dimension of the embedding. Should be the same as pre-trained embedding. | |||||
:param emb_file: str, the pre-trained embedding file path. | |||||
:param emb_type: str, the pre-trained embedding format, support glove now | |||||
:param vocab: Vocabulary, a mapping from word to index, can be provided by user or built from pre-trained embedding | |||||
:param emb_pkl: str, the embedding pickle file. | |||||
:return embedding_tensor: Tensor of shape (len(word_dict), emb_dim) | |||||
vocab: input vocab or vocab built by pre-train | |||||
TODO: fragile code | |||||
""" | |||||
# If the embedding pickle exists, load it and return. | |||||
if os.path.exists(emb_pkl): | |||||
with open(emb_pkl, "rb") as f: | |||||
embedding_tensor, vocab = _pickle.load(f) | |||||
return embedding_tensor, vocab | |||||
# Otherwise, load the pre-trained embedding. | |||||
pretrain = EmbedLoader._load_pretrain(emb_file, emb_type) | |||||
if vocab is None: | |||||
# build vocabulary from pre-trained embedding | |||||
vocab = Vocabulary() | |||||
for w in pretrain.keys(): | |||||
vocab.update(w) | |||||
embedding_tensor = torch.randn(len(vocab), emb_dim) | |||||
for w, v in pretrain.items(): | |||||
if len(v.shape) > 1 or emb_dim != v.shape[0]: | |||||
raise ValueError('pretrian embedding dim is {}, dismatching required {}'.format(v.shape, (emb_dim,))) | |||||
if vocab.has_word(w): | |||||
embedding_tensor[vocab[w]] = v | |||||
# save and return the result | |||||
with open(emb_pkl, "wb") as f: | |||||
_pickle.dump((embedding_tensor, vocab), f) | |||||
return embedding_tensor, vocab |
@@ -1,21 +0,0 @@ | |||||
import torch | |||||
from fastNLP.loader.base_loader import BaseLoader | |||||
class ModelLoader(BaseLoader): | |||||
""" | |||||
Loader for models. | |||||
""" | |||||
def __init__(self, data_path): | |||||
super(ModelLoader, self).__init__(data_path) | |||||
@staticmethod | |||||
def load_pytorch(empty_model, model_path): | |||||
""" | |||||
Load model parameters from .pkl files into the empty PyTorch model. | |||||
:param empty_model: a PyTorch model with initialized parameters. | |||||
:param model_path: str, the path to the saved model. | |||||
""" | |||||
empty_model.load_state_dict(torch.load(model_path)) |
@@ -0,0 +1,6 @@ | |||||
from .base_model import BaseModel | |||||
from .biaffine_parser import BiaffineParser, GraphParser | |||||
from .char_language_model import CharLM | |||||
from .cnn_text_classification import CNNText | |||||
from .sequence_modeling import SeqLabeling, AdvSeqLabel | |||||
from .snli import SNLI |
@@ -1,6 +1,6 @@ | |||||
import torch | import torch | ||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.modules.decoder.MLP import MLP | |||||
class BaseModel(torch.nn.Module): | class BaseModel(torch.nn.Module): | ||||
@@ -11,5 +11,19 @@ class BaseModel(torch.nn.Module): | |||||
super(BaseModel, self).__init__() | super(BaseModel, self).__init__() | ||||
def fit(self, train_data, dev_data=None, **train_args): | def fit(self, train_data, dev_data=None, **train_args): | ||||
trainer = Trainer(**train_args) | |||||
trainer.train(self, train_data, dev_data) | |||||
pass | |||||
def predict(self, *args, **kwargs): | |||||
raise NotImplementedError | |||||
class NaiveClassifier(BaseModel): | |||||
def __init__(self, in_feature_dim, out_feature_dim): | |||||
super(NaiveClassifier, self).__init__() | |||||
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) | |||||
def forward(self, x): | |||||
return {"predict": torch.sigmoid(self.mlp(x))} | |||||
def predict(self, x): | |||||
return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} |
@@ -9,6 +9,8 @@ from torch.nn import functional as F | |||||
from fastNLP.modules.utils import initial_parameter | from fastNLP.modules.utils import initial_parameter | ||||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | from fastNLP.modules.encoder.variational_rnn import VarLSTM | ||||
from fastNLP.modules.dropout import TimestepDropout | from fastNLP.modules.dropout import TimestepDropout | ||||
from fastNLP.models.base_model import BaseModel | |||||
from fastNLP.modules.utils import seq_mask | |||||
def mst(scores): | def mst(scores): | ||||
""" | """ | ||||
@@ -16,10 +18,9 @@ def mst(scores): | |||||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | ||||
""" | """ | ||||
length = scores.shape[0] | length = scores.shape[0] | ||||
min_score = -np.inf | |||||
mask = np.zeros((length, length)) | |||||
np.fill_diagonal(mask, -np.inf) | |||||
scores = scores + mask | |||||
min_score = scores.min() - 1 | |||||
eye = np.eye(length) | |||||
scores = scores * (1 - eye) + min_score * eye | |||||
heads = np.argmax(scores, axis=1) | heads = np.argmax(scores, axis=1) | ||||
heads[0] = 0 | heads[0] = 0 | ||||
tokens = np.arange(1, length) | tokens = np.arange(1, length) | ||||
@@ -114,7 +115,7 @@ def _find_cycle(vertices, edges): | |||||
return [SCC for SCC in _SCCs if len(SCC) > 1] | return [SCC for SCC in _SCCs if len(SCC) > 1] | ||||
class GraphParser(nn.Module): | |||||
class GraphParser(BaseModel): | |||||
"""Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding | """Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
@@ -123,22 +124,31 @@ class GraphParser(nn.Module): | |||||
def forward(self, x): | def forward(self, x): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def _greedy_decoder(self, arc_matrix, seq_mask=None): | |||||
def _greedy_decoder(self, arc_matrix, mask=None): | |||||
_, seq_len, _ = arc_matrix.shape | _, seq_len, _ = arc_matrix.shape | ||||
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) | matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) | ||||
flip_mask = (mask == 0).byte() | |||||
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||||
_, heads = torch.max(matrix, dim=2) | _, heads = torch.max(matrix, dim=2) | ||||
if seq_mask is not None: | |||||
heads *= seq_mask.long() | |||||
if mask is not None: | |||||
heads *= mask.long() | |||||
return heads | return heads | ||||
def _mst_decoder(self, arc_matrix, seq_mask=None): | |||||
def _mst_decoder(self, arc_matrix, mask=None): | |||||
batch_size, seq_len, _ = arc_matrix.shape | batch_size, seq_len, _ = arc_matrix.shape | ||||
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) | matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) | ||||
ans = matrix.new_zeros(batch_size, seq_len).long() | ans = matrix.new_zeros(batch_size, seq_len).long() | ||||
lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len | |||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device) | |||||
mask[batch_idx, lens-1] = 0 | |||||
for i, graph in enumerate(matrix): | for i, graph in enumerate(matrix): | ||||
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) | |||||
if seq_mask is not None: | |||||
ans *= seq_mask.long() | |||||
len_i = lens[i] | |||||
if len_i == seq_len: | |||||
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) | |||||
else: | |||||
ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device) | |||||
if mask is not None: | |||||
ans *= mask.long() | |||||
return ans | return ans | ||||
@@ -175,15 +185,13 @@ class LabelBilinear(nn.Module): | |||||
def __init__(self, in1_features, in2_features, num_label, bias=True): | def __init__(self, in1_features, in2_features, num_label, bias=True): | ||||
super(LabelBilinear, self).__init__() | super(LabelBilinear, self).__init__() | ||||
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) | self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) | ||||
self.lin1 = nn.Linear(in1_features, num_label, bias=False) | |||||
self.lin2 = nn.Linear(in2_features, num_label, bias=False) | |||||
self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) | |||||
def forward(self, x1, x2): | def forward(self, x1, x2): | ||||
output = self.bilinear(x1, x2) | output = self.bilinear(x1, x2) | ||||
output += self.lin1(x1) + self.lin2(x2) | |||||
output += self.lin(torch.cat([x1, x2], dim=2)) | |||||
return output | return output | ||||
class BiaffineParser(GraphParser): | class BiaffineParser(GraphParser): | ||||
"""Biaffine Dependency Parser implemantation. | """Biaffine Dependency Parser implemantation. | ||||
refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) | refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) | ||||
@@ -194,6 +202,8 @@ class BiaffineParser(GraphParser): | |||||
word_emb_dim, | word_emb_dim, | ||||
pos_vocab_size, | pos_vocab_size, | ||||
pos_emb_dim, | pos_emb_dim, | ||||
word_hid_dim, | |||||
pos_hid_dim, | |||||
rnn_layers, | rnn_layers, | ||||
rnn_hidden_size, | rnn_hidden_size, | ||||
arc_mlp_size, | arc_mlp_size, | ||||
@@ -204,10 +214,15 @@ class BiaffineParser(GraphParser): | |||||
use_greedy_infer=False): | use_greedy_infer=False): | ||||
super(BiaffineParser, self).__init__() | super(BiaffineParser, self).__init__() | ||||
rnn_out_size = 2 * rnn_hidden_size | |||||
self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) | self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) | ||||
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | ||||
self.word_fc = nn.Linear(word_emb_dim, word_hid_dim) | |||||
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) | |||||
self.word_norm = nn.LayerNorm(word_hid_dim) | |||||
self.pos_norm = nn.LayerNorm(pos_hid_dim) | |||||
if use_var_lstm: | if use_var_lstm: | ||||
self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim, | |||||
self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | hidden_size=rnn_hidden_size, | ||||
num_layers=rnn_layers, | num_layers=rnn_layers, | ||||
bias=True, | bias=True, | ||||
@@ -216,7 +231,7 @@ class BiaffineParser(GraphParser): | |||||
hidden_dropout=dropout, | hidden_dropout=dropout, | ||||
bidirectional=True) | bidirectional=True) | ||||
else: | else: | ||||
self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim, | |||||
self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | hidden_size=rnn_hidden_size, | ||||
num_layers=rnn_layers, | num_layers=rnn_layers, | ||||
bias=True, | bias=True, | ||||
@@ -224,141 +239,153 @@ class BiaffineParser(GraphParser): | |||||
dropout=dropout, | dropout=dropout, | ||||
bidirectional=True) | bidirectional=True) | ||||
rnn_out_size = 2 * rnn_hidden_size | |||||
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), | self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), | ||||
nn.ELU()) | |||||
nn.LayerNorm(arc_mlp_size), | |||||
nn.ELU(), | |||||
TimestepDropout(p=dropout),) | |||||
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) | self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) | ||||
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), | self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), | ||||
nn.ELU()) | |||||
nn.LayerNorm(label_mlp_size), | |||||
nn.ELU(), | |||||
TimestepDropout(p=dropout),) | |||||
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | ||||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | ||||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | ||||
self.normal_dropout = nn.Dropout(p=dropout) | self.normal_dropout = nn.Dropout(p=dropout) | ||||
self.timestep_dropout = TimestepDropout(p=dropout) | |||||
self.use_greedy_infer = use_greedy_infer | self.use_greedy_infer = use_greedy_infer | ||||
initial_parameter(self) | |||||
self.reset_parameters() | |||||
self.explore_p = 0.2 | |||||
def reset_parameters(self): | |||||
for m in self.modules(): | |||||
if isinstance(m, nn.Embedding): | |||||
continue | |||||
elif isinstance(m, nn.LayerNorm): | |||||
nn.init.constant_(m.weight, 0.1) | |||||
nn.init.constant_(m.bias, 0) | |||||
else: | |||||
for p in m.parameters(): | |||||
nn.init.normal_(p, 0, 0.1) | |||||
def forward(self, word_seq, pos_seq, seq_mask, gold_heads=None, **_): | |||||
def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_): | |||||
""" | """ | ||||
:param word_seq: [batch_size, seq_len] sequence of word's indices | :param word_seq: [batch_size, seq_len] sequence of word's indices | ||||
:param pos_seq: [batch_size, seq_len] sequence of POS tag indices | :param pos_seq: [batch_size, seq_len] sequence of POS tag indices
:param seq_mask: [batch_size, seq_len] sequence of length masks | |||||
:param word_seq_origin_len: [batch_size, ] the original (unpadded) lengths of the sequences
:param gold_heads: [batch_size, seq_len] sequence of golden heads | :param gold_heads: [batch_size, seq_len] sequence of golden heads | ||||
:return dict: parsing results | :return dict: parsing results | ||||
arc_pred: [batch_size, seq_len, seq_len] | arc_pred: [batch_size, seq_len, seq_len] | ||||
label_pred: [batch_size, seq_len, seq_len] | label_pred: [batch_size, seq_len, seq_len] | ||||
seq_mask: [batch_size, seq_len] | |||||
mask: [batch_size, seq_len] | |||||
head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads | head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads | ||||
""" | """ | ||||
# prepare embeddings | # prepare embeddings | ||||
device = self.parameters().__next__().device | |||||
word_seq = word_seq.long().to(device) | |||||
pos_seq = pos_seq.long().to(device) | |||||
word_seq_origin_len = word_seq_origin_len.long().to(device).view(-1) | |||||
batch_size, seq_len = word_seq.shape | batch_size, seq_len = word_seq.shape | ||||
# print('forward {} {}'.format(batch_size, seq_len)) | # print('forward {} {}'.format(batch_size, seq_len)) | ||||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) | |||||
# get sequence mask | # get sequence mask | ||||
seq_mask = seq_mask.long() | |||||
mask = seq_mask(word_seq_origin_len, seq_len).long() | |||||
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] | word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] | ||||
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] | pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] | ||||
word, pos = self.word_fc(word), self.pos_fc(pos) | |||||
word, pos = self.word_norm(word), self.pos_norm(pos) | |||||
x = torch.cat([word, pos], dim=2) # -> [N,L,C] | x = torch.cat([word, pos], dim=2) # -> [N,L,C] | ||||
del word, pos | |||||
# lstm, extract features | # lstm, extract features | ||||
sort_lens, sort_idx = torch.sort(word_seq_origin_len, dim=0, descending=True) | |||||
x = x[sort_idx] | |||||
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) | |||||
feat, _ = self.lstm(x) # -> [N,L,C] | feat, _ = self.lstm(x) # -> [N,L,C] | ||||
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) | |||||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||||
feat = feat[unsort_idx] | |||||
# for arc biaffine | # for arc biaffine | ||||
# mlp, reduce dim | # mlp, reduce dim | ||||
arc_dep = self.timestep_dropout(self.arc_dep_mlp(feat)) | |||||
arc_head = self.timestep_dropout(self.arc_head_mlp(feat)) | |||||
label_dep = self.timestep_dropout(self.label_dep_mlp(feat)) | |||||
label_head = self.timestep_dropout(self.label_head_mlp(feat)) | |||||
arc_dep = self.arc_dep_mlp(feat) | |||||
arc_head = self.arc_head_mlp(feat) | |||||
label_dep = self.label_dep_mlp(feat) | |||||
label_head = self.label_head_mlp(feat) | |||||
del feat | |||||
# biaffine arc classifier | # biaffine arc classifier | ||||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | ||||
flip_mask = (seq_mask == 0) | |||||
arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||||
# use gold or predicted arc to predict label | # use gold or predicted arc to predict label | ||||
if gold_heads is None: | |||||
if gold_heads is None or not self.training: | |||||
# use greedy decoding in training | # use greedy decoding in training | ||||
if self.training or self.use_greedy_infer: | if self.training or self.use_greedy_infer: | ||||
heads = self._greedy_decoder(arc_pred, seq_mask) | |||||
heads = self._greedy_decoder(arc_pred, mask) | |||||
else: | else: | ||||
heads = self._mst_decoder(arc_pred, seq_mask) | |||||
heads = self._mst_decoder(arc_pred, mask) | |||||
head_pred = heads | head_pred = heads | ||||
else: | else: | ||||
head_pred = None | |||||
heads = gold_heads | |||||
assert self.training # must be training mode | |||||
if torch.rand(1).item() < self.explore_p: | |||||
heads = self._greedy_decoder(arc_pred, mask) | |||||
head_pred = heads | |||||
else: | |||||
head_pred = None | |||||
heads = gold_heads | |||||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) | |||||
label_head = label_head[batch_range, heads].contiguous() | label_head = label_head[batch_range, heads].contiguous() | ||||
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] | label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] | ||||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask} | |||||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask} | |||||
if head_pred is not None: | if head_pred is not None: | ||||
res_dict['head_pred'] = head_pred | res_dict['head_pred'] = head_pred | ||||
return res_dict | return res_dict | ||||
def loss(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **_): | |||||
def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_): | |||||
""" | """ | ||||
Compute loss. | Compute loss. | ||||
:param arc_pred: [batch_size, seq_len, seq_len] | :param arc_pred: [batch_size, seq_len, seq_len] | ||||
:param label_pred: [batch_size, seq_len, seq_len] | |||||
:param label_pred: [batch_size, seq_len, n_tags] | |||||
:param head_indices: [batch_size, seq_len] | :param head_indices: [batch_size, seq_len] | ||||
:param head_labels: [batch_size, seq_len] | :param head_labels: [batch_size, seq_len] | ||||
:param seq_mask: [batch_size, seq_len] | |||||
:param mask: [batch_size, seq_len] | |||||
:return: loss value | :return: loss value | ||||
""" | """ | ||||
batch_size, seq_len, _ = arc_pred.shape | batch_size, seq_len, _ = arc_pred.shape | ||||
arc_logits = F.log_softmax(arc_pred, dim=2) | |||||
flip_mask = (mask == 0) | |||||
_arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred) | |||||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||||
arc_logits = F.log_softmax(_arc_pred, dim=2) | |||||
label_logits = F.log_softmax(label_pred, dim=2) | label_logits = F.log_softmax(label_pred, dim=2) | ||||
batch_index = torch.arange(start=0, end=batch_size, device=arc_logits.device).long().unsqueeze(1) | |||||
child_index = torch.arange(start=0, end=seq_len, device=arc_logits.device).long().unsqueeze(0) | |||||
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) | |||||
child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0) | |||||
arc_loss = arc_logits[batch_index, child_index, head_indices] | arc_loss = arc_logits[batch_index, child_index, head_indices] | ||||
label_loss = label_logits[batch_index, child_index, head_labels] | label_loss = label_logits[batch_index, child_index, head_labels] | ||||
arc_loss = arc_loss[:, 1:] | arc_loss = arc_loss[:, 1:] | ||||
label_loss = label_loss[:, 1:] | label_loss = label_loss[:, 1:] | ||||
float_mask = seq_mask[:, 1:].float() | |||||
length = (seq_mask.sum() - batch_size).float() | |||||
arc_nll = -(arc_loss*float_mask).sum() / length | |||||
label_nll = -(label_loss*float_mask).sum() / length | |||||
float_mask = mask[:, 1:].float() | |||||
arc_nll = -(arc_loss*float_mask).mean() | |||||
label_nll = -(label_loss*float_mask).mean() | |||||
return arc_nll + label_nll | return arc_nll + label_nll | ||||
def evaluate(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **kwargs): | |||||
""" | |||||
Evaluate the performance of prediction. | |||||
:return dict: performance results. | |||||
head_pred_corrct: number of correct predicted heads. | |||||
label_pred_correct: number of correct predicted labels. | |||||
total_tokens: number of predicted tokens | |||||
def predict(self, word_seq, pos_seq, word_seq_origin_len): | |||||
""" | """ | ||||
if 'head_pred' in kwargs: | |||||
head_pred = kwargs['head_pred'] | |||||
elif self.use_greedy_infer: | |||||
head_pred = self._greedy_decoder(arc_pred, seq_mask) | |||||
else: | |||||
head_pred = self._mst_decoder(arc_pred, seq_mask) | |||||
head_pred_correct = (head_pred == head_indices).long() * seq_mask | |||||
_, label_preds = torch.max(label_pred, dim=2) | |||||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct | |||||
return {"head_pred_correct": head_pred_correct.sum(dim=1), | |||||
"label_pred_correct": label_pred_correct.sum(dim=1), | |||||
"total_tokens": seq_mask.sum(dim=1)} | |||||
def metrics(self, head_pred_correct, label_pred_correct, total_tokens, **_): | |||||
""" | |||||
Compute the metrics of model | |||||
:param head_pred_corrct: number of correct predicted heads. | |||||
:param label_pred_correct: number of correct predicted labels. | |||||
:param total_tokens: number of predicted tokens | |||||
:return dict: the metrics results | |||||
UAS: the head predicted accuracy | |||||
LAS: the label predicted accuracy | |||||
:param word_seq: [batch_size, seq_len] sequence of word indices
:param pos_seq: [batch_size, seq_len] sequence of POS tag indices
:param word_seq_origin_len: [batch_size, ] the original (unpadded) lengths of the sequences
:return dict: head_pred: [batch_size, seq_len]
label_pred: [batch_size, seq_len]
""" | """ | ||||
return {"UAS": head_pred_correct.sum().float() / total_tokens.sum().float() * 100, | |||||
"LAS": label_pred_correct.sum().float() / total_tokens.sum().float() * 100} | |||||
res = self(word_seq, pos_seq, word_seq_origin_len) | |||||
output = {} | |||||
output['head_pred'] = res.pop('head_pred') | |||||
_, label_pred = res.pop('label_pred').max(2) | |||||
output['label_pred'] = label_pred | |||||
return output |
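The mst() helper at the top of this file decodes a single sentence from a score matrix in which scores[i, j] is the score of token j heading token i, with token 0 acting as ROOT. A toy call with hand-picked values; the expected output follows from the argmax/cycle-resolution logic shown above::

    import numpy as np
    from fastNLP.models.biaffine_parser import mst  # assumed import path

    scores = np.array([[0.0, 0.0, 0.0],
                       [9.0, 0.0, 1.0],
                       [1.0, 8.0, 0.0]])
    heads = mst(scores)
    print(heads)  # expected [0 0 1]: token 1 attaches to ROOT, token 2 to token 1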
@@ -15,33 +15,43 @@ class CNNText(torch.nn.Module): | |||||
Classification.' | Classification.' | ||||
""" | """ | ||||
def __init__(self, args): | |||||
def __init__(self, embed_num, | |||||
embed_dim, | |||||
num_classes, | |||||
kernel_nums=(3, 4, 5), | |||||
kernel_sizes=(3, 4, 5), | |||||
padding=0, | |||||
dropout=0.5): | |||||
super(CNNText, self).__init__() | super(CNNText, self).__init__() | ||||
num_classes = args["num_classes"] | |||||
kernel_nums = [100, 100, 100] | |||||
kernel_sizes = [3, 4, 5] | |||||
vocab_size = args["vocab_size"] | |||||
embed_dim = 300 | |||||
pretrained_embed = None | |||||
drop_prob = 0.5 | |||||
# no support for pre-trained embedding currently | # no support for pre-trained embedding currently | ||||
self.embed = encoder.embedding.Embedding(vocab_size, embed_dim) | |||||
self.conv_pool = encoder.conv_maxpool.ConvMaxpool( | |||||
self.embed = encoder.Embedding(embed_num, embed_dim) | |||||
self.conv_pool = encoder.ConvMaxpool( | |||||
in_channels=embed_dim, | in_channels=embed_dim, | ||||
out_channels=kernel_nums, | out_channels=kernel_nums, | ||||
kernel_sizes=kernel_sizes) | |||||
self.dropout = nn.Dropout(drop_prob) | |||||
self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes) | |||||
kernel_sizes=kernel_sizes, | |||||
padding=padding) | |||||
self.dropout = nn.Dropout(dropout) | |||||
self.fc = encoder.Linear(sum(kernel_nums), num_classes) | |||||
def forward(self, word_seq): | def forward(self, word_seq): | ||||
""" | """ | ||||
:param word_seq: torch.LongTensor, [batch_size, seq_len] | :param word_seq: torch.LongTensor, [batch_size, seq_len] | ||||
:return x: torch.LongTensor, [batch_size, num_classes] | |||||
:return output: dict of torch.FloatTensor, [batch_size, num_classes]
""" | """ | ||||
x = self.embed(word_seq) # [N,L] -> [N,L,C] | x = self.embed(word_seq) # [N,L] -> [N,L,C] | ||||
x = self.conv_pool(x) # [N,L,C] -> [N,C] | x = self.conv_pool(x) # [N,L,C] -> [N,C] | ||||
x = self.dropout(x) | x = self.dropout(x) | ||||
x = self.fc(x) # [N,C] -> [N, N_class] | x = self.fc(x) # [N,C] -> [N, N_class] | ||||
return x | |||||
return {'pred': x} | |||||
def predict(self, word_seq): | |||||
""" | |||||
:param word_seq: torch.LongTensor, [batch_size, seq_len] | |||||
:return predict: dict of torch.LongTensor, [batch_size, ]
""" | |||||
output = self(word_seq) | |||||
_, predict = output['pred'].max(dim=1) | |||||
return {'pred': predict} |
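With the configuration moved into explicit constructor arguments, CNNText can be built and run without an args dict. A small sketch with hypothetical sizes::

    import torch

    model = CNNText(embed_num=100, embed_dim=50, num_classes=5,
                    kernel_nums=(30, 40, 50), kernel_sizes=(1, 3, 5))
    word_seq = torch.randint(0, 100, (8, 20)).long()  # 8 sentences, 20 token ids each
    out = model(word_seq)            # {"pred": [8, 5] class scores}
    pred = model.predict(word_seq)   # {"pred": [8] predicted class indices}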
@@ -1,21 +1,9 @@ | |||||
import torch | import torch | ||||
import numpy as np | |||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules import decoder, encoder | from fastNLP.modules import decoder, encoder | ||||
def seq_mask(seq_len, max_len): | |||||
"""Create a mask for the sequences. | |||||
:param seq_len: list or torch.LongTensor | |||||
:param max_len: int | |||||
:return mask: torch.LongTensor | |||||
""" | |||||
if isinstance(seq_len, list): | |||||
seq_len = torch.LongTensor(seq_len) | |||||
mask = [torch.ge(seq_len, i + 1) for i in range(max_len)] | |||||
mask = torch.stack(mask, 1) | |||||
return mask | |||||
from fastNLP.modules.utils import seq_mask | |||||
class SeqLabeling(BaseModel): | class SeqLabeling(BaseModel): | ||||
@@ -44,6 +32,9 @@ class SeqLabeling(BaseModel): | |||||
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | ||||
If truth is not None, return loss, a scalar. Used in training. | If truth is not None, return loss, a scalar. Used in training. | ||||
""" | """ | ||||
assert word_seq.shape[0] == word_seq_origin_len.shape[0] | |||||
if truth is not None: | |||||
assert truth.shape == word_seq.shape | |||||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | self.mask = self.make_mask(word_seq, word_seq_origin_len) | ||||
x = self.Embedding(word_seq) | x = self.Embedding(word_seq) | ||||
@@ -52,10 +43,8 @@ class SeqLabeling(BaseModel): | |||||
# [batch_size, max_len, hidden_size * direction] | # [batch_size, max_len, hidden_size * direction] | ||||
x = self.Linear(x) | x = self.Linear(x) | ||||
# [batch_size, max_len, num_classes] | # [batch_size, max_len, num_classes] | ||||
if truth is not None: | |||||
return self._internal_loss(x, truth) | |||||
else: | |||||
return self.decode(x) | |||||
return {"loss": self._internal_loss(x, truth) if truth is not None else None, | |||||
"predict": self.decode(x)} | |||||
def loss(self, x, y): | def loss(self, x, y): | ||||
""" Since the loss has been computed in forward(), this function simply returns x.""" | """ Since the loss has been computed in forward(), this function simply returns x.""" | ||||
@@ -79,8 +68,8 @@ class SeqLabeling(BaseModel): | |||||
def make_mask(self, x, seq_len): | def make_mask(self, x, seq_len): | ||||
batch_size, max_len = x.size(0), x.size(1) | batch_size, max_len = x.size(0), x.size(1) | ||||
mask = seq_mask(seq_len, max_len) | mask = seq_mask(seq_len, max_len) | ||||
mask = mask.byte().view(batch_size, max_len) | |||||
mask = mask.to(x) | |||||
mask = mask.view(batch_size, max_len) | |||||
mask = mask.to(x).float() | |||||
return mask | return mask | ||||
def decode(self, x, pad=True): | def decode(self, x, pad=True): | ||||
@@ -111,42 +100,119 @@ class AdvSeqLabel(SeqLabeling): | |||||
word_emb_dim = args["word_emb_dim"] | word_emb_dim = args["word_emb_dim"] | ||||
hidden_dim = args["rnn_hidden_units"] | hidden_dim = args["rnn_hidden_units"] | ||||
num_classes = args["num_classes"] | num_classes = args["num_classes"] | ||||
dropout = args['dropout'] | |||||
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | ||||
self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=3, dropout=0.3, bidirectional=True) | |||||
self.norm1 = torch.nn.LayerNorm(word_emb_dim) | |||||
# self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True) | |||||
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True) | |||||
self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | ||||
self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | |||||
self.relu = torch.nn.ReLU() | |||||
self.drop = torch.nn.Dropout(0.3) | |||||
self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3) | |||||
# self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | |||||
self.relu = torch.nn.LeakyReLU() | |||||
self.drop = torch.nn.Dropout(dropout) | |||||
self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) | self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) | ||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | |||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) | |||||
def forward(self, word_seq, word_seq_origin_len, truth=None): | def forward(self, word_seq, word_seq_origin_len, truth=None): | ||||
""" | """ | ||||
:param word_seq: LongTensor, [batch_size, max_len] | :param word_seq: LongTensor, [batch_size, max_len]
:param word_seq_origin_len: list of int. | |||||
:param word_seq_origin_len: LongTensor, [batch_size, ] | |||||
:param truth: LongTensor, [batch_size, max_len] | :param truth: LongTensor, [batch_size, max_len] | ||||
:return y: | |||||
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | |||||
If truth is not None, return loss, a scalar. Used in training. | |||||
""" | """ | ||||
word_seq = word_seq.long() | |||||
word_seq_origin_len = word_seq_origin_len.long() | |||||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | self.mask = self.make_mask(word_seq, word_seq_origin_len) | ||||
sent_len, idx_sort = torch.sort(word_seq_origin_len, descending=True) | |||||
_, idx_unsort = torch.sort(idx_sort, descending=False) | |||||
# word_seq_origin_len = word_seq_origin_len.long() | |||||
truth = truth.long() if truth is not None else None | |||||
batch_size = word_seq.size(0) | batch_size = word_seq.size(0) | ||||
max_len = word_seq.size(1) | max_len = word_seq.size(1) | ||||
if next(self.parameters()).is_cuda: | |||||
word_seq = word_seq.cuda() | |||||
idx_sort = idx_sort.cuda() | |||||
idx_unsort = idx_unsort.cuda() | |||||
self.mask = self.mask.cuda() | |||||
x = self.Embedding(word_seq) | x = self.Embedding(word_seq) | ||||
x = self.norm1(x) | |||||
# [batch_size, max_len, word_emb_dim] | # [batch_size, max_len, word_emb_dim] | ||||
x = self.Rnn(x) | |||||
sent_variable = x[idx_sort] | |||||
sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True) | |||||
x, _ = self.Rnn(sent_packed) | |||||
# print(x) | |||||
# [batch_size, max_len, hidden_size * direction] | # [batch_size, max_len, hidden_size * direction] | ||||
sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0] | |||||
x = sent_output[idx_unsort] | |||||
x = x.contiguous() | x = x.contiguous() | ||||
x = x.view(batch_size * max_len, -1) | |||||
# x = x.view(batch_size * max_len, -1) | |||||
x = self.Linear1(x) | x = self.Linear1(x) | ||||
x = self.batch_norm(x) | |||||
# x = self.batch_norm(x) | |||||
x = self.norm2(x) | |||||
x = self.relu(x) | x = self.relu(x) | ||||
x = self.drop(x) | x = self.drop(x) | ||||
x = self.Linear2(x) | x = self.Linear2(x) | ||||
x = x.view(batch_size, max_len, -1) | |||||
# x = x.view(batch_size, max_len, -1) | |||||
# [batch_size, max_len, num_classes] | # [batch_size, max_len, num_classes] | ||||
if truth is not None: | |||||
return self._internal_loss(x, truth) | |||||
else: | |||||
return self.decode(x) | |||||
# TODO: using 'word_seq_origin_len' as the key for the sequence lengths here is not ideal
return {"loss": self._internal_loss(x, truth) if truth is not None else None, | |||||
"predict": self.decode(x), | |||||
'word_seq_origin_len': word_seq_origin_len} | |||||
def predict(self, **x): | |||||
out = self.forward(**x) | |||||
return {"predict": out["predict"]} | |||||
def loss(self, **kwargs): | |||||
assert 'loss' in kwargs | |||||
return kwargs['loss'] | |||||
if __name__ == '__main__': | |||||
args = { | |||||
'vocab_size': 20, | |||||
'word_emb_dim': 100, | |||||
'rnn_hidden_units': 100, | |||||
'num_classes': 10, | |||||
} | |||||
model = AdvSeqLabel(args) | |||||
data = [] | |||||
for i in range(20): | |||||
word_seq = torch.randint(20, (15,)).long() | |||||
word_seq_len = torch.LongTensor([15]) | |||||
truth = torch.randint(10, (15,)).long() | |||||
data.append((word_seq, word_seq_len, truth)) | |||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) | |||||
print(model) | |||||
curidx = 0 | |||||
for i in range(1000): | |||||
endidx = min(len(data), curidx + 5) | |||||
b_word, b_len, b_truth = [], [], [] | |||||
for word_seq, word_seq_len, truth in data[curidx: endidx]: | |||||
b_word.append(word_seq) | |||||
b_len.append(word_seq_len) | |||||
b_truth.append(truth) | |||||
word_seq = torch.stack(b_word, dim=0) | |||||
word_seq_len = torch.cat(b_len, dim=0) | |||||
truth = torch.stack(b_truth, dim=0) | |||||
res = model(word_seq, word_seq_len, truth) | |||||
loss = res['loss'] | |||||
pred = res['predict'] | |||||
print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
curidx = endidx | |||||
if curidx == len(data): | |||||
curidx = 0 | |||||
@@ -1,11 +1,14 @@ | |||||
from . import aggregator | from . import aggregator | ||||
from . import decoder | from . import decoder | ||||
from . import encoder | from . import encoder | ||||
from . import interactor | |||||
from .aggregator import * | |||||
from .decoder import * | |||||
from .encoder import * | |||||
from .dropout import TimestepDropout | |||||
__version__ = '0.0.0' | __version__ = '0.0.0' | ||||
__all__ = ['encoder', | __all__ = ['encoder', | ||||
'decoder', | 'decoder', | ||||
'aggregator', | 'aggregator', | ||||
'interactor'] | |||||
'TimestepDropout'] |
@@ -1,5 +1,7 @@ | |||||
from .max_pool import MaxPool | from .max_pool import MaxPool | ||||
from .avg_pool import AvgPool | |||||
from .kmax_pool import KMaxPool | |||||
from .attention import Attention | |||||
from .self_attention import SelfAttention | |||||
__all__ = [ | |||||
'MaxPool' | |||||
] |
@@ -1,5 +1,6 @@ | |||||
import torch | import torch | ||||
from torch import nn | |||||
import math | |||||
from fastNLP.modules.utils import mask_softmax | from fastNLP.modules.utils import mask_softmax | ||||
@@ -17,3 +18,47 @@ class Attention(torch.nn.Module): | |||||
def _atten_forward(self, query, memory): | def _atten_forward(self, query, memory): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
class DotAtte(nn.Module): | |||||
def __init__(self, key_size, value_size): | |||||
# TODO: never tested
super(DotAtte, self).__init__() | |||||
self.key_size = key_size | |||||
self.value_size = value_size | |||||
self.scale = math.sqrt(key_size) | |||||
def forward(self, Q, K, V, seq_mask=None): | |||||
""" | |||||
:param Q: [batch, seq_len, key_size] | |||||
:param K: [batch, seq_len, key_size] | |||||
:param V: [batch, seq_len, value_size] | |||||
:param seq_mask: [batch, seq_len] | |||||
""" | |||||
output = torch.matmul(Q, K.transpose(1, 2)) / self.scale | |||||
if seq_mask is not None: | |||||
output.masked_fill_(seq_mask.lt(1), -float('inf')) | |||||
output = nn.functional.softmax(output, dim=2) | |||||
return torch.matmul(output, V) | |||||
class MultiHeadAtte(nn.Module): | |||||
def __init__(self, input_size, output_size, key_size, value_size, num_atte): | |||||
raise NotImplementedError | |||||
# TODO: never tested
super(MultiHeadAtte, self).__init__() | |||||
self.in_linear = nn.ModuleList() | |||||
for i in range(num_atte * 3): | |||||
out_feat = key_size if (i % 3) != 2 else value_size | |||||
self.in_linear.append(nn.Linear(input_size, out_feat)) | |||||
self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)]) | |||||
self.out_linear = nn.Linear(value_size * num_atte, output_size) | |||||
def forward(self, Q, K, V, seq_mask=None): | |||||
heads = [] | |||||
for i in range(len(self.attes)): | |||||
j = i * 3 | |||||
qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V) | |||||
headi = self.attes[i](qi, ki, vi, seq_mask) | |||||
heads.append(headi) | |||||
output = torch.cat(heads, dim=2) | |||||
return self.out_linear(output) |
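DotAtte implements scaled dot-product attention, softmax(QK^T / sqrt(key_size)) V. Since both new classes still carry a "never tested" TODO, the sketch below only exercises the unmasked path::

    import torch

    atte = DotAtte(key_size=8, value_size=16)
    Q = torch.randn(2, 5, 8)    # [batch, seq_len, key_size]
    K = torch.randn(2, 5, 8)
    V = torch.randn(2, 5, 16)   # [batch, seq_len, value_size]
    out = atte(Q, K, V)         # [2, 5, 16]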
@@ -3,6 +3,7 @@ from torch import nn | |||||
from fastNLP.modules.utils import initial_parameter | from fastNLP.modules.utils import initial_parameter | ||||
def log_sum_exp(x, dim=-1): | def log_sum_exp(x, dim=-1): | ||||
max_value, _ = x.max(dim=dim, keepdim=True) | max_value, _ = x.max(dim=dim, keepdim=True) | ||||
res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value | res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value | ||||
@@ -20,7 +21,7 @@ def seq_len_to_byte_mask(seq_lens): | |||||
class ConditionalRandomField(nn.Module): | class ConditionalRandomField(nn.Module): | ||||
def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None): | |||||
def __init__(self, tag_size, include_start_end_trans=False ,initial_method = None): | |||||
""" | """ | ||||
:param tag_size: int, num of tags | :param tag_size: int, num of tags | ||||
:param include_start_end_trans: bool, whether to include start/end tag | :param include_start_end_trans: bool, whether to include start/end tag | ||||
@@ -31,7 +32,7 @@ class ConditionalRandomField(nn.Module): | |||||
self.tag_size = tag_size | self.tag_size = tag_size | ||||
# the meaning of entry in this matrix is (from_tag_id, to_tag_id) score | # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score | ||||
self.transition_m = nn.Parameter(torch.randn(tag_size, tag_size)) | |||||
self.trans_m = nn.Parameter(torch.randn(tag_size, tag_size)) | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
self.start_scores = nn.Parameter(torch.randn(tag_size)) | self.start_scores = nn.Parameter(torch.randn(tag_size)) | ||||
self.end_scores = nn.Parameter(torch.randn(tag_size)) | self.end_scores = nn.Parameter(torch.randn(tag_size)) | ||||
@@ -39,137 +40,121 @@ class ConditionalRandomField(nn.Module): | |||||
# self.reset_parameter() | # self.reset_parameter() | ||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def reset_parameter(self): | def reset_parameter(self): | ||||
nn.init.xavier_normal_(self.transition_m) | |||||
nn.init.xavier_normal_(self.trans_m) | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
nn.init.normal_(self.start_scores) | nn.init.normal_(self.start_scores) | ||||
nn.init.normal_(self.end_scores) | nn.init.normal_(self.end_scores) | ||||
def _normalizer_likelihood(self, feats, masks): | |||||
def _normalizer_likelihood(self, logits, mask): | |||||
""" | """ | ||||
Computes the (batch_size,) denominator term for the log-likelihood, which is the | Computes the (batch_size,) denominator term for the log-likelihood, which is the | ||||
sum of the likelihoods across all possible state sequences. | sum of the likelihoods across all possible state sequences. | ||||
:param feats:FloatTensor, batch_size x max_len x tag_size | |||||
:param masks:ByteTensor, batch_size x max_len | |||||
:param logits:FloatTensor, max_len x batch_size x tag_size | |||||
:param mask:ByteTensor, max_len x batch_size | |||||
:return:FloatTensor, batch_size | :return:FloatTensor, batch_size | ||||
""" | """ | ||||
batch_size, max_len, _ = feats.size() | |||||
# alpha, batch_size x tag_size | |||||
seq_len, batch_size, n_tags = logits.size() | |||||
alpha = logits[0] | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
alpha = self.start_scores.view(1, -1) + feats[:, 0] | |||||
else: | |||||
alpha = feats[:, 0] | |||||
# broadcast_trans_m, the meaning of entry in this matrix is [batch_idx, to_tag_id, from_tag_id] | |||||
broadcast_trans_m = self.transition_m.permute( | |||||
1, 0).unsqueeze(0).repeat(batch_size, 1, 1) | |||||
# loop | |||||
for i in range(1, max_len): | |||||
emit_score = feats[:, i].unsqueeze(2) | |||||
new_alpha = broadcast_trans_m + alpha.unsqueeze(1) + emit_score | |||||
alpha += self.start_scores.view(1, -1) | |||||
new_alpha = log_sum_exp(new_alpha, dim=2) | |||||
alpha = new_alpha * \ | |||||
masks[:, i:i + 1].float() + alpha * \ | |||||
(1 - masks[:, i:i + 1].float()) | |||||
for i in range(1, seq_len): | |||||
emit_score = logits[i].view(batch_size, 1, n_tags) | |||||
trans_score = self.trans_m.view(1, n_tags, n_tags) | |||||
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score | |||||
alpha = log_sum_exp(tmp, 1) * mask[i].view(batch_size, 1) + alpha * (1 - mask[i]).view(batch_size, 1) | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
alpha = alpha + self.end_scores.view(1, -1) | |||||
alpha += self.end_scores.view(1, -1) | |||||
return log_sum_exp(alpha) | |||||
return log_sum_exp(alpha, 1) | |||||
def _glod_score(self, feats, tags, masks): | |||||
def _glod_score(self, logits, tags, mask): | |||||
""" | """ | ||||
Compute the score for the gold path. | Compute the score for the gold path. | ||||
:param feats: FloatTensor, batch_size x max_len x tag_size | |||||
:param tags: LongTensor, batch_size x max_len | |||||
:param masks: ByteTensor, batch_size x max_len | |||||
:param logits: FloatTensor, max_len x batch_size x tag_size | |||||
:param tags: LongTensor, max_len x batch_size | |||||
:param mask: ByteTensor, max_len x batch_size | |||||
:return:FloatTensor, batch_size | :return:FloatTensor, batch_size | ||||
""" | """ | ||||
batch_size, max_len, _ = feats.size() | |||||
# alpha, B x 1 | |||||
if self.include_start_end_trans: | |||||
alpha = self.start_scores.view(1, -1).repeat(batch_size, 1).gather(dim=1, index=tags[:, :1]) + \ | |||||
feats[:, 0].gather(dim=1, index=tags[:, :1]) | |||||
else: | |||||
alpha = feats[:, 0].gather(dim=1, index=tags[:, :1]) | |||||
for i in range(1, max_len): | |||||
trans_score = self.transition_m[( | |||||
tags[:, i - 1], tags[:, i])].unsqueeze(1) | |||||
emit_score = feats[:, i].gather(dim=1, index=tags[:, i:i + 1]) | |||||
new_alpha = alpha + trans_score + emit_score | |||||
alpha = new_alpha * \ | |||||
masks[:, i:i + 1].float() + alpha * \ | |||||
(1 - masks[:, i:i + 1].float()) | |||||
seq_len, batch_size, _ = logits.size() | |||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | |||||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | |||||
# trans_score [L-1, B] | |||||
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :] | |||||
# emit_score [L, B] | |||||
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask | |||||
# score [L-1, B] | |||||
score = trans_score + emit_score[:seq_len-1, :] | |||||
score = score.sum(0) + emit_score[-1] * mask[-1] | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
last_tag_index = masks.cumsum(dim=1, dtype=torch.long)[:, -1:] - 1 | |||||
last_from_tag_id = tags.gather(dim=1, index=last_tag_index) | |||||
trans_score = self.end_scores.view( | |||||
1, -1).repeat(batch_size, 1).gather(dim=1, index=last_from_tag_id) | |||||
alpha = alpha + trans_score | |||||
return alpha.squeeze(1) | |||||
def forward(self, feats, tags, masks): | |||||
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | |||||
last_idx = mask.long().sum(0) - 1 | |||||
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | |||||
score += st_scores + ed_scores | |||||
# return [B,] | |||||
return score | |||||
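The vectorized gathers above can be sanity-checked against a plain loop over a single unpadded sequence; the values below are illustrative and start/end transitions are omitted:

    import torch

    seq_len, n_tags = 4, 3
    logits = torch.randn(seq_len, n_tags)     # emission scores for one sequence
    trans = torch.randn(n_tags, n_tags)       # trans[from_tag, to_tag]
    tags = torch.tensor([0, 2, 1, 1])         # gold tag sequence

    score = logits[0, tags[0]]
    for i in range(1, seq_len):
        score = score + trans[tags[i - 1], tags[i]] + logits[i, tags[i]]
    # `score` is the gold-path numerator that forward() subtracts the normalizer from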
def forward(self, feats, tags, mask): | |||||
""" | """ | ||||
Calculate the neg log likelihood | Calculate the neg log likelihood | ||||
:param feats:FloatTensor, batch_size x max_len x tag_size | :param feats:FloatTensor, batch_size x max_len x tag_size | ||||
:param tags:LongTensor, batch_size x max_len | :param tags:LongTensor, batch_size x max_len | ||||
:param masks:ByteTensor batch_size x max_len | |||||
:param mask:ByteTensor batch_size x max_len | |||||
:return:FloatTensor, batch_size | :return:FloatTensor, batch_size | ||||
""" | """ | ||||
all_path_score = self._normalizer_likelihood(feats, masks) | |||||
gold_path_score = self._glod_score(feats, tags, masks) | |||||
feats = feats.transpose(0, 1) | |||||
tags = tags.transpose(0, 1).long() | |||||
mask = mask.transpose(0, 1).float() | |||||
all_path_score = self._normalizer_likelihood(feats, mask) | |||||
gold_path_score = self._glod_score(feats, tags, mask) | |||||
return all_path_score - gold_path_score | return all_path_score - gold_path_score | ||||
def viterbi_decode(self, feats, masks, get_score=False): | |||||
def viterbi_decode(self, data, mask, get_score=False): | |||||
""" | """ | ||||
Given a feature matrix, return the best decoding path and, optionally, its score. | Given a feature matrix, return the best decoding path and, optionally, its score. | ||||
:param feats: | |||||
:param masks: | |||||
:param data:FloatTensor, batch_size x max_len x tag_size | |||||
:param mask:ByteTensor batch_size x max_len | |||||
:param get_score: bool, whether to output the decode score. | :param get_score: bool, whether to output the decode score. | ||||
:return:List[Tuple(List, float)], | |||||
:return: (best_scores, decoded_paths) if get_score is True, otherwise decoded_paths | |||||
""" | """ | ||||
batch_size, max_len, tag_size = feats.size() | |||||
batch_size, seq_len, n_tags = data.size() | |||||
data = data.transpose(0, 1).data # L, B, H | |||||
mask = mask.transpose(0, 1).data.float() # L, B | |||||
paths = torch.zeros(batch_size, max_len - 1, self.tag_size) | |||||
# dp | |||||
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||||
vscore = data[0] | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
alpha = self.start_scores.repeat(batch_size, 1) + feats[:, 0] | |||||
else: | |||||
alpha = feats[:, 0] | |||||
for i in range(1, max_len): | |||||
new_alpha = alpha.clone() | |||||
for t in range(self.tag_size): | |||||
pre_scores = self.transition_m[:, t].view( | |||||
1, self.tag_size) + alpha | |||||
max_score, indices = pre_scores.max(dim=1) | |||||
new_alpha[:, t] = max_score + feats[:, i, t] | |||||
paths[:, i - 1, t] = indices | |||||
alpha = new_alpha * masks[:, i:i + 1].float() + alpha * (1 - masks[:, i:i + 1].float()) | |||||
vscore += self.start_scores.view(1, -1) | |||||
for i in range(1, seq_len): | |||||
prev_score = vscore.view(batch_size, n_tags, 1) | |||||
cur_score = data[i].view(batch_size, 1, n_tags) | |||||
trans_score = self.trans_m.view(1, n_tags, n_tags).data | |||||
score = prev_score + trans_score + cur_score | |||||
best_score, best_dst = score.max(1) | |||||
vpath[i] = best_dst | |||||
vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1) | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
alpha += self.end_scores.view(1, -1) | |||||
max_scores, indices = alpha.max(dim=1) | |||||
indices = indices.cpu().numpy() | |||||
final_paths = [] | |||||
paths = paths.cpu().numpy().astype(int) | |||||
seq_lens = masks.cumsum(dim=1, dtype=torch.long)[:, -1] | |||||
vscore += self.end_scores.view(1, -1) | |||||
# backtrace | |||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device) | |||||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device) | |||||
lens = (mask.long().sum(0) - 1) | |||||
# idxes [L, B], batched idx from seq_len-1 to 0 | |||||
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len | |||||
ans = data.new_empty((seq_len, batch_size), dtype=torch.long) | |||||
ans_score, last_tags = vscore.max(1) | |||||
ans[idxes[0], batch_idx] = last_tags | |||||
for i in range(seq_len - 1): | |||||
last_tags = vpath[idxes[i], batch_idx, last_tags] | |||||
ans[idxes[i+1], batch_idx] = last_tags | |||||
for b in range(batch_size): | |||||
path = [indices[b]] | |||||
for i in range(seq_lens[b] - 2, -1, -1): | |||||
index = paths[b, i, path[-1]] | |||||
path.append(index) | |||||
final_paths.append(path[::-1]) | |||||
if get_score: | if get_score: | ||||
return list(zip(final_paths, max_scores.detach().cpu().numpy())) | |||||
else: | |||||
return final_paths | |||||
return ans_score, ans.transpose(0, 1) | |||||
return ans.transpose(0, 1) |
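For readers new to the batched backtrace above, the same dynamic program written for a single sequence without batching (toy scores, no start/end transitions):

    import torch

    emit = torch.randn(5, 3)         # seq_len x n_tags emission scores
    trans = torch.randn(3, 3)        # trans[from_tag, to_tag]

    vscore = emit[0]                 # best score ending in each tag at step 0
    backptr = []
    for i in range(1, emit.size(0)):
        # score[prev, cur] = best score so far ending in prev + transition + emission of cur
        score = vscore.view(-1, 1) + trans + emit[i].view(1, -1)
        vscore, best_prev = score.max(dim=0)
        backptr.append(best_prev)

    best_score, last_tag = vscore.max(dim=0)
    path = [last_tag.item()]
    for bp in reversed(backptr):
        path.append(bp[path[-1]].item())
    path.reverse()                   # the best tag sequence, front to back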
@@ -4,12 +4,13 @@ from fastNLP.modules.utils import initial_parameter | |||||
class MLP(nn.Module): | class MLP(nn.Module): | ||||
def __init__(self, size_layer, activation='relu', initial_method=None): | |||||
def __init__(self, size_layer, activation='relu', initial_method=None, dropout=0.0): | |||||
"""Multilayer Perceptrons as a decoder | """Multilayer Perceptrons as a decoder | ||||
:param size_layer: list of int, defines the sizes of the MLP layers. | :param size_layer: list of int, defines the sizes of the MLP layers. | ||||
:param activation: str or function, the activation function for hidden layers. | :param activation: str or function, the activation function for hidden layers. | ||||
:param initial_method: str, the name of init method. | :param initial_method: str, the name of init method. | ||||
:param dropout: float, the probability of dropout. | |||||
.. note:: | .. note:: | ||||
No activation function is applied to the output layer. | No activation function is applied to the output layer. | ||||
@@ -24,6 +25,8 @@ class MLP(nn.Module): | |||||
else: | else: | ||||
self.hiddens.append(nn.Linear(size_layer[i-1], size_layer[i])) | self.hiddens.append(nn.Linear(size_layer[i-1], size_layer[i])) | ||||
self.dropout = nn.Dropout(p=dropout) | |||||
actives = { | actives = { | ||||
'relu': nn.ReLU(), | 'relu': nn.ReLU(), | ||||
'tanh': nn.Tanh(), | 'tanh': nn.Tanh(), | ||||
@@ -38,8 +41,8 @@ class MLP(nn.Module): | |||||
def forward(self, x): | def forward(self, x): | ||||
for layer in self.hiddens: | for layer in self.hiddens: | ||||
x = self.hidden_active(layer(x)) | |||||
x = self.output(x) | |||||
x = self.dropout(self.hidden_active(layer(x))) | |||||
x = self.dropout(self.output(x)) | |||||
return x | return x | ||||
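A hedged usage sketch of the updated decoder (layer sizes and the dropout rate are illustrative); with this change dropout is applied after every hidden activation and after the output layer:

    import torch
    from fastNLP.modules.decoder.MLP import MLP

    mlp = MLP([64, 32, 10], activation='relu', dropout=0.1)
    x = torch.randn(8, 64)
    y = mlp(x)    # shape: (8, 10)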
@@ -1,13 +1,15 @@ | |||||
import torch | import torch | ||||
class TimestepDropout(torch.nn.Dropout): | class TimestepDropout(torch.nn.Dropout): | ||||
"""This module accepts a `[batch_size, num_timesteps, embedding_dim)]` and use a single | """This module accepts a `[batch_size, num_timesteps, embedding_dim)]` and use a single | ||||
dropout mask of shape `(batch_size, embedding_dim)` to apply on every time step. | dropout mask of shape `(batch_size, embedding_dim)` to apply on every time step. | ||||
""" | """ | ||||
def forward(self, x): | def forward(self, x): | ||||
dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) | dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) | ||||
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) | torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) | ||||
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim] | |||||
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim] | |||||
if self.inplace: | if self.inplace: | ||||
x *= dropout_mask | x *= dropout_mask | ||||
return | return | ||||
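A quick illustration of the shared-mask behaviour; the non-inplace branch is outside this hunk and is assumed to return `x * dropout_mask`:

    import torch

    drop = TimestepDropout(p=0.5, inplace=True)
    drop.train()
    x = torch.ones(2, 4, 6)    # [batch_size, num_timesteps, embedding_dim]
    drop(x)
    # For a given (batch, feature) position, every time step is scaled identically:
    # x[0, :, 3] is either all zeros or all 1 / (1 - p).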
@@ -43,7 +43,7 @@ class ConvCharEmbedding(nn.Module): | |||||
# [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1] | # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1] | ||||
y = torch.squeeze(y, 2) | y = torch.squeeze(y, 2) | ||||
# [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1] | # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1] | ||||
y = F.tanh(y) | |||||
y = torch.tanh(y) | |||||
y, __ = torch.max(y, 2) | y, __ = torch.max(y, 2) | ||||
# [batch_size*sent_length, feature_maps[i]] | # [batch_size*sent_length, feature_maps[i]] | ||||
feats.append(y) | feats.append(y) | ||||
@@ -34,8 +34,6 @@ class ConvMaxpool(nn.Module): | |||||
bias=bias) | bias=bias) | ||||
for oc, ks in zip(out_channels, kernel_sizes)]) | for oc, ks in zip(out_channels, kernel_sizes)]) | ||||
for conv in self.convs: | |||||
xavier_uniform_(conv.weight) # weight initialization | |||||
else: | else: | ||||
raise Exception( | raise Exception( | ||||
'Incorrect kernel sizes: should be list, tuple or int') | 'Incorrect kernel sizes: should be list, tuple or int') | ||||
@@ -0,0 +1,32 @@ | |||||
import torch | |||||
from torch import nn | |||||
import torch.nn.functional as F | |||||
from ..aggregator.attention import MultiHeadAtte | |||||
from ..other_modules import LayerNormalization | |||||
class TransformerEncoder(nn.Module): | |||||
class SubLayer(nn.Module): | |||||
def __init__(self, input_size, output_size, key_size, value_size, num_atte): | |||||
super(TransformerEncoder.SubLayer, self).__init__() | |||||
self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte) | |||||
self.norm1 = LayerNormalization(output_size) | |||||
self.ffn = nn.Sequential(nn.Linear(output_size, output_size), | |||||
nn.ReLU(), | |||||
nn.Linear(output_size, output_size)) | |||||
self.norm2 = LayerNormalization(output_size) | |||||
def forward(self, input, seq_mask): | |||||
attention = self.atte(input) | |||||
norm_atte = self.norm1(attention + input) | |||||
output = self.ffn(norm_atte) | |||||
return self.norm2(output + norm_atte) | |||||
def __init__(self, num_layers, **kargs): | |||||
super(TransformerEncoder, self).__init__() | |||||
# nn.Sequential cannot forward the extra seq_mask argument, so keep the sub-layers in a ModuleList | |||||
self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | |||||
def forward(self, x, seq_mask=None): | |||||
for layer in self.layers: | |||||
x = layer(x, seq_mask) | |||||
return x | |||||
@@ -101,14 +101,14 @@ class VarRNNBase(nn.Module): | |||||
mask_x = input.new_ones((batch_size, self.input_size)) | mask_x = input.new_ones((batch_size, self.input_size)) | ||||
mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions)) | mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions)) | ||||
mask_h = input.new_ones((batch_size, self.hidden_size)) | |||||
mask_h_ones = input.new_ones((batch_size, self.hidden_size)) | |||||
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | ||||
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | ||||
nn.functional.dropout(mask_h, p=self.hidden_dropout, training=self.training, inplace=True) | |||||
hidden_list = [] | hidden_list = [] | ||||
for layer in range(self.num_layers): | for layer in range(self.num_layers): | ||||
output_list = [] | output_list = [] | ||||
mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | |||||
for direction in range(self.num_directions): | for direction in range(self.num_directions): | ||||
input_x = input if direction == 0 else flip(input, [0]) | input_x = input if direction == 0 else flip(input, [0]) | ||||
idx = self.num_directions * layer + direction | idx = self.num_directions * layer + direction | ||||
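This hunk keeps an all-ones template (`mask_h_ones`) and re-samples `mask_h` once per layer, so within a layer the same hidden-state mask is reused at every time step. A stripped-down illustration of that variational-dropout idea (names and values are illustrative):

    import torch
    import torch.nn.functional as F

    batch_size, hidden_size, n_steps = 4, 8, 5
    ones = torch.ones(batch_size, hidden_size)
    mask_h = F.dropout(ones, p=0.3, training=True, inplace=False)   # sampled once per layer

    h = torch.zeros(batch_size, hidden_size)
    for t in range(n_steps):
        h = torch.tanh(h + 1.0)     # stand-in for the real recurrent cell update
        h = h * mask_h              # identical mask at every time step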
@@ -31,12 +31,12 @@ class GroupNorm(nn.Module): | |||||
class LayerNormalization(nn.Module): | class LayerNormalization(nn.Module): | ||||
""" Layer normalization module """ | """ Layer normalization module """ | ||||
def __init__(self, d_hid, eps=1e-3): | |||||
def __init__(self, layer_size, eps=1e-3): | |||||
super(LayerNormalization, self).__init__() | super(LayerNormalization, self).__init__() | ||||
self.eps = eps | self.eps = eps | ||||
self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True) | |||||
self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True) | |||||
self.a_2 = nn.Parameter(torch.ones(1, layer_size, requires_grad=True)) | |||||
self.b_2 = nn.Parameter(torch.zeros(1, layer_size, requires_grad=True)) | |||||
def forward(self, z): | def forward(self, z): | ||||
if z.size(1) == 1: | if z.size(1) == 1: | ||||
@@ -44,9 +44,8 @@ class LayerNormalization(nn.Module): | |||||
mu = torch.mean(z, keepdim=True, dim=-1) | mu = torch.mean(z, keepdim=True, dim=-1) | ||||
sigma = torch.std(z, keepdim=True, dim=-1) | sigma = torch.std(z, keepdim=True, dim=-1) | ||||
ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps) | |||||
ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out) | |||||
ln_out = (z - mu) / (sigma + self.eps) | |||||
ln_out = ln_out * self.a_2 + self.b_2 | |||||
return ln_out | return ln_out | ||||
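With broadcasting, the forward pass reduces to a few lines; a standalone check of the same computation on toy values:

    import torch

    z = torch.randn(2, 5)
    eps = 1e-3
    a_2 = torch.ones(1, 5)      # gain
    b_2 = torch.zeros(1, 5)     # bias

    mu = torch.mean(z, keepdim=True, dim=-1)
    sigma = torch.std(z, keepdim=True, dim=-1)
    ln_out = (z - mu) / (sigma + eps) * a_2 + b_2   # broadcasts over the batch dimension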
@@ -32,9 +32,9 @@ def initial_parameter(net, initial_method=None): | |||||
elif initial_method == 'xavier_normal': | elif initial_method == 'xavier_normal': | ||||
init_method = init.xavier_normal_ | init_method = init.xavier_normal_ | ||||
elif initial_method == 'kaiming_normal' or initial_method == 'msra': | elif initial_method == 'kaiming_normal' or initial_method == 'msra': | ||||
init_method = init.kaiming_normal | |||||
init_method = init.kaiming_normal_ | |||||
elif initial_method == 'kaiming_uniform': | elif initial_method == 'kaiming_uniform': | ||||
init_method = init.kaiming_normal | |||||
init_method = init.kaiming_uniform_ | |||||
elif initial_method == 'orthogonal': | elif initial_method == 'orthogonal': | ||||
init_method = init.orthogonal_ | init_method = init.orthogonal_ | ||||
elif initial_method == 'sparse': | elif initial_method == 'sparse': | ||||
@@ -42,7 +42,7 @@ def initial_parameter(net, initial_method=None): | |||||
elif initial_method == 'normal': | elif initial_method == 'normal': | ||||
init_method = init.normal_ | init_method = init.normal_ | ||||
elif initial_method == 'uniform': | elif initial_method == 'uniform': | ||||
initial_method = init.uniform_ | |||||
init_method = init.uniform_ | |||||
else: | else: | ||||
init_method = init.xavier_normal_ | init_method = init.xavier_normal_ | ||||
@@ -77,11 +77,13 @@ def initial_parameter(net, initial_method=None): | |||||
def seq_mask(seq_len, max_len): | def seq_mask(seq_len, max_len): | ||||
"""Create sequence mask. | """Create sequence mask. | ||||
:param seq_len: list of int, the lengths of sequences in a batch. | |||||
:param seq_len: list or torch.Tensor, the lengths of sequences in a batch. | |||||
:param max_len: int, the maximum sequence length in a batch. | :param max_len: int, the maximum sequence length in a batch. | ||||
:return mask: torch.ByteTensor of 0/1 (BoolTensor on newer torch), [batch_size, max_len] | :return mask: torch.ByteTensor of 0/1 (BoolTensor on newer torch), [batch_size, max_len] | ||||
""" | """ | ||||
mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)] | |||||
mask = torch.stack(mask, 1) | |||||
return mask | |||||
if not isinstance(seq_len, torch.Tensor): | |||||
seq_len = torch.LongTensor(seq_len) | |||||
seq_len = seq_len.view(-1, 1).long() # [batch_size, 1] | |||||
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len] | |||||
return torch.gt(seq_len, seq_range) # [batch_size, max_len] |
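The rewritten `seq_mask` accepts either a Python list or a tensor of lengths, for example:

    mask = seq_mask([3, 1, 4], max_len=4)
    # tensor([[1, 1, 1, 0],
    #         [1, 0, 0, 0],
    #         [1, 1, 1, 1]], dtype=torch.uint8)
    # (torch.gt returns a ByteTensor on the torch versions targeted here; newer torch returns bool)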
@@ -1,150 +0,0 @@ | |||||
import os | |||||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||||
from fastNLP.saver.logger import create_logger | |||||
class ConfigSaver(object): | |||||
def __init__(self, file_path): | |||||
self.file_path = file_path | |||||
if not os.path.exists(self.file_path): | |||||
raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) | |||||
def _get_section(self, sect_name): | |||||
"""This is the function to get the section with the section name. | |||||
:param sect_name: The name of section what wants to load. | |||||
:return: The section. | |||||
""" | |||||
sect = ConfigSection() | |||||
ConfigLoader().load_config(self.file_path, {sect_name: sect}) | |||||
return sect | |||||
def _read_section(self): | |||||
"""This is the function to read sections from the config file. | |||||
:return: sect_list, sect_key_list | |||||
sect_list: A list of ConfigSection(). | |||||
sect_key_list: A list of names in sect_list. | |||||
""" | |||||
sect_name = None | |||||
sect_list = {} | |||||
sect_key_list = [] | |||||
single_section = {} | |||||
single_section_key = [] | |||||
with open(self.file_path, 'r') as f: | |||||
lines = f.readlines() | |||||
for line in lines: | |||||
if line.startswith('[') and line.endswith(']\n'): | |||||
if sect_name is None: | |||||
pass | |||||
else: | |||||
sect_list[sect_name] = single_section, single_section_key | |||||
single_section = {} | |||||
single_section_key = [] | |||||
sect_key_list.append(sect_name) | |||||
sect_name = line[1: -2] | |||||
continue | |||||
if line.startswith('#'): | |||||
single_section[line] = '#' | |||||
single_section_key.append(line) | |||||
continue | |||||
if line.startswith('\n'): | |||||
single_section_key.append('\n') | |||||
continue | |||||
if '=' not in line: | |||||
log = create_logger(__name__, './config_saver.log') | |||||
log.error("can NOT load config file [%s]" % self.file_path) | |||||
raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | |||||
key = line.split('=', maxsplit=1)[0].strip() | |||||
value = line.split('=', maxsplit=1)[1].strip() + '\n' | |||||
single_section[key] = value | |||||
single_section_key.append(key) | |||||
if sect_name is not None: | |||||
sect_list[sect_name] = single_section, single_section_key | |||||
sect_key_list.append(sect_name) | |||||
return sect_list, sect_key_list | |||||
def _write_section(self, sect_list, sect_key_list): | |||||
"""This is the function to write config file with section list and name list. | |||||
:param sect_list: A list of ConfigSection() need to be writen into file. | |||||
:param sect_key_list: A list of name of sect_list. | |||||
:return: | |||||
""" | |||||
with open(self.file_path, 'w') as f: | |||||
for sect_key in sect_key_list: | |||||
single_section, single_section_key = sect_list[sect_key] | |||||
f.write('[' + sect_key + ']\n') | |||||
for key in single_section_key: | |||||
if key == '\n': | |||||
f.write('\n') | |||||
continue | |||||
if single_section[key] == '#': | |||||
f.write(key) | |||||
continue | |||||
f.write(key + ' = ' + single_section[key]) | |||||
f.write('\n') | |||||
def save_config_file(self, section_name, section): | |||||
"""This is the function to be called to change the config file with a single section and its name. | |||||
:param section_name: The name of section what needs to be changed and saved. | |||||
:param section: The section with key and value what needs to be changed and saved. | |||||
:return: | |||||
""" | |||||
section_file = self._get_section(section_name) | |||||
if len(section_file.__dict__.keys()) == 0: # the section not in the file before | |||||
# append this section to config file | |||||
with open(self.file_path, 'a') as f: | |||||
f.write('[' + section_name + ']\n') | |||||
for k in section.__dict__.keys(): | |||||
f.write(k + ' = ') | |||||
if isinstance(section[k], str): | |||||
f.write('\"' + str(section[k]) + '\"\n\n') | |||||
else: | |||||
f.write(str(section[k]) + '\n\n') | |||||
else: | |||||
# the section exists | |||||
change_file = False | |||||
for k in section.__dict__.keys(): | |||||
if k not in section_file: | |||||
# find a new key in this section | |||||
change_file = True | |||||
break | |||||
if section_file[k] != section[k]: | |||||
logger = create_logger(__name__, "./config_loader.log") | |||||
logger.warning("section [%s] in config file [%s] has been changed" % ( | |||||
section_name, self.file_path | |||||
)) | |||||
change_file = True | |||||
break | |||||
if not change_file: | |||||
return | |||||
sect_list, sect_key_list = self._read_section() | |||||
if section_name not in sect_key_list: | |||||
raise AttributeError() | |||||
sect, sect_key = sect_list[section_name] | |||||
for k in section.__dict__.keys(): | |||||
if k not in sect_key: | |||||
if sect_key[-1] != '\n': | |||||
sect_key.append('\n') | |||||
sect_key.append(k) | |||||
sect[k] = str(section[k]) | |||||
if isinstance(section[k], str): | |||||
sect[k] = "\"" + sect[k] + "\"" | |||||
sect[k] = sect[k] + "\n" | |||||
sect_list[section_name] = sect, sect_key | |||||
self._write_section(sect_list, sect_key_list) |
@@ -1,24 +0,0 @@ | |||||
import torch | |||||
class ModelSaver(object): | |||||
"""Save a model | |||||
Example:: | |||||
saver = ModelSaver("./save/model_ckpt_100.pkl") | |||||
saver.save_pytorch(model) | |||||
""" | |||||
def __init__(self, save_path): | |||||
""" | |||||
:param save_path: str, the path to the saving directory. | |||||
""" | |||||
self.save_path = save_path | |||||
def save_pytorch(self, model): | |||||
"""Save a pytorch model into .pkl file. | |||||
:param model: a PyTorch model | |||||
""" | |||||
torch.save(model.state_dict(), self.save_path) |
@@ -1,37 +1,40 @@ | |||||
[train] | [train] | ||||
epochs = 50 | |||||
epochs = -1 | |||||
batch_size = 16 | batch_size = 16 | ||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
validate = true | validate = true | ||||
save_best_dev = false | |||||
save_best_dev = true | |||||
eval_sort_key = "UAS" | |||||
use_cuda = true | use_cuda = true | ||||
model_saved_path = "./save/" | model_saved_path = "./save/" | ||||
task = "parse" | |||||
print_every_step = 20 | |||||
use_golden_train=true | |||||
[test] | [test] | ||||
save_output = true | save_output = true | ||||
validate_in_training = true | validate_in_training = true | ||||
save_dev_input = false | save_dev_input = false | ||||
save_loss = true | save_loss = true | ||||
batch_size = 16 | |||||
batch_size = 64 | |||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
use_cuda = true | use_cuda = true | ||||
task = "parse" | |||||
[model] | [model] | ||||
word_vocab_size = -1 | word_vocab_size = -1 | ||||
word_emb_dim = 100 | word_emb_dim = 100 | ||||
pos_vocab_size = -1 | pos_vocab_size = -1 | ||||
pos_emb_dim = 100 | pos_emb_dim = 100 | ||||
word_hid_dim = 100 | |||||
pos_hid_dim = 100 | |||||
rnn_layers = 3 | rnn_layers = 3 | ||||
rnn_hidden_size = 400 | rnn_hidden_size = 400 | ||||
arc_mlp_size = 500 | arc_mlp_size = 500 | ||||
label_mlp_size = 100 | label_mlp_size = 100 | ||||
num_label = -1 | num_label = -1 | ||||
dropout = 0.33 | dropout = 0.33 | ||||
use_var_lstm=true | |||||
use_var_lstm=false | |||||
use_greedy_infer=false | use_greedy_infer=false | ||||
[optim] | [optim] | ||||
lr = 2e-3 | lr = 2e-3 | ||||
weight_decay = 5e-5 |
@@ -0,0 +1,83 @@ | |||||
import os | |||||
import sys | |||||
sys.path.extend(['/home/yfshao/workdir/dev_fastnlp']) | |||||
from fastNLP.api.processor import * | |||||
from fastNLP.models.biaffine_parser import BiaffineParser | |||||
from fastNLP.io.config_io import ConfigSection, ConfigLoader | |||||
import _pickle as pickle | |||||
import torch | |||||
def _load(path): | |||||
with open(path, 'rb') as f: | |||||
obj = pickle.load(f) | |||||
return obj | |||||
def _load_all(src): | |||||
model_path = src | |||||
src = os.path.dirname(src) | |||||
word_v = _load(src+'/word_v.pkl') | |||||
pos_v = _load(src+'/pos_v.pkl') | |||||
tag_v = _load(src+'/tag_v.pkl') | |||||
pos_pp = torch.load(src+'/pos_pp.pkl')['pipeline'] | |||||
model_args = ConfigSection() | |||||
ConfigLoader.load_config('cfg.cfg', {'model': model_args}) | |||||
model_args['word_vocab_size'] = len(word_v) | |||||
model_args['pos_vocab_size'] = len(pos_v) | |||||
model_args['num_label'] = len(tag_v) | |||||
model = BiaffineParser(**model_args.data) | |||||
model.load_state_dict(torch.load(model_path)) | |||||
return { | |||||
'word_v': word_v, | |||||
'pos_v': pos_v, | |||||
'tag_v': tag_v, | |||||
'model': model, | |||||
'pos_pp':pos_pp, | |||||
} | |||||
def build(load_path, save_path): | |||||
BOS = '<BOS>' | |||||
NUM = '<NUM>' | |||||
_dict = _load_all(load_path) | |||||
word_vocab = _dict['word_v'] | |||||
pos_vocab = _dict['pos_v'] | |||||
tag_vocab = _dict['tag_v'] | |||||
pos_pp = _dict['pos_pp'] | |||||
model = _dict['model'] | |||||
print('load model from {}'.format(load_path)) | |||||
word_seq = 'raw_word_seq' | |||||
pos_seq = 'raw_pos_seq' | |||||
# build pipeline | |||||
# input | |||||
pipe = pos_pp | |||||
pipe.pipeline.pop(-1) | |||||
pipe.add_processor(Num2TagProcessor(NUM, 'word_list', word_seq)) | |||||
pipe.add_processor(PreAppendProcessor(BOS, word_seq)) | |||||
pipe.add_processor(PreAppendProcessor(BOS, 'pos_list', pos_seq)) | |||||
pipe.add_processor(IndexerProcessor(word_vocab, word_seq, 'word_seq')) | |||||
pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, 'pos_seq')) | |||||
pipe.add_processor(SeqLenProcessor('word_seq', 'word_seq_origin_len')) | |||||
pipe.add_processor(SetTensorProcessor({'word_seq':True, 'pos_seq':True, 'word_seq_origin_len':True}, default=False)) | |||||
pipe.add_processor(ModelProcessor(model, 'word_seq_origin_len')) | |||||
pipe.add_processor(SliceProcessor(1, None, None, 'head_pred', 'heads')) | |||||
pipe.add_processor(SliceProcessor(1, None, None, 'label_pred', 'label_pred')) | |||||
pipe.add_processor(Index2WordProcessor(tag_vocab, 'label_pred', 'labels')) | |||||
if not os.path.exists(save_path): | |||||
os.makedirs(save_path) | |||||
with open(save_path+'/pipeline.pkl', 'wb') as f: | |||||
torch.save({'pipeline': pipe}, f) | |||||
print('save pipeline in {}'.format(save_path)) | |||||
import argparse | |||||
parser = argparse.ArgumentParser(description='build pipeline for parser.') | |||||
parser.add_argument('--src', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/save') | |||||
parser.add_argument('--dst', type=str, default='/home/yfshao/workdir/dev_fastnlp/reproduction/Biaffine_parser/pipe') | |||||
args = parser.parse_args() | |||||
build(args.src, args.dst) |
@@ -0,0 +1,114 @@ | |||||
import sys | |||||
sys.path.extend(['/home/yfshao/workdir/dev_fastnlp']) | |||||
import torch | |||||
import argparse | |||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument('--pipe', type=str, default='') | |||||
parser.add_argument('--gold_data', type=str, default='') | |||||
parser.add_argument('--new_data', type=str) | |||||
args = parser.parse_args() | |||||
pipe = torch.load(args.pipe)['pipeline'] | |||||
for p in pipe: | |||||
if p.field_name == 'word_list': | |||||
print(p.field_name) | |||||
p.field_name = 'gold_words' | |||||
elif p.field_name == 'pos_list': | |||||
print(p.field_name) | |||||
p.field_name = 'gold_pos' | |||||
data = ConllxDataLoader().load(args.gold_data) | |||||
ds = DataSet() | |||||
for ins1, ins2 in zip(add_seg_tag(data), data): | |||||
ds.append(Instance(words=ins1[0], tag=ins1[1], | |||||
gold_words=ins2[0], gold_pos=ins2[1], | |||||
gold_heads=ins2[2], gold_head_tags=ins2[3])) | |||||
ds = pipe(ds) | |||||
seg_threshold = 0. | |||||
pos_threshold = 0. | |||||
parse_threshold = 0.74 | |||||
def get_heads(ins, head_f, word_f): | |||||
head_pred = [] | |||||
for i, idx in enumerate(ins[head_f]): | |||||
j = idx - 1 if idx != 0 else i | |||||
head_pred.append(ins[word_f][j]) | |||||
return head_pred | |||||
def evaluate(ins): | |||||
seg_count = sum([1 for i, j in zip(ins['word_list'], ins['gold_words']) if i == j]) | |||||
pos_count = sum([1 for i, j in zip(ins['pos_list'], ins['gold_pos']) if i == j]) | |||||
head_count = sum([1 for i, j in zip(ins['heads'], ins['gold_heads']) if i == j]) | |||||
total = len(ins['gold_words']) | |||||
return seg_count / total, pos_count / total, head_count / total | |||||
def is_ok(x): | |||||
seg, pos, head = x[1] | |||||
return seg > seg_threshold and pos > pos_threshold and head > parse_threshold | |||||
res_list = [] | |||||
for i, ins in enumerate(ds): | |||||
res_list.append((i, evaluate(ins))) | |||||
res_list = list(filter(is_ok, res_list)) | |||||
print('{} {}'.format(len(ds), len(res_list))) | |||||
seg_cor, pos_cor, head_cor, label_cor, total = 0,0,0,0,0 | |||||
for i, _ in res_list: | |||||
ins = ds[i] | |||||
# print(i) | |||||
# print('gold_words:\t', ins['gold_words']) | |||||
# print('predict_words:\t', ins['word_list']) | |||||
# print('gold_tag:\t', ins['gold_pos']) | |||||
# print('predict_tag:\t', ins['pos_list']) | |||||
# print('gold_heads:\t', ins['gold_heads']) | |||||
# print('predict_heads:\t', ins['heads'].tolist()) | |||||
# print('gold_head_tags:\t', ins['gold_head_tags']) | |||||
# print('predict_labels:\t', ins['labels']) | |||||
# print() | |||||
head_pred = ins['heads'] | |||||
head_gold = ins['gold_heads'] | |||||
label_pred = ins['labels'] | |||||
label_gold = ins['gold_head_tags'] | |||||
total += len(head_gold) | |||||
seg_cor += sum([1 for i, j in zip(ins['word_list'], ins['gold_words']) if i == j]) | |||||
pos_cor += sum([1 for i, j in zip(ins['pos_list'], ins['gold_pos']) if i == j]) | |||||
length = len(head_gold) | |||||
for i in range(length): | |||||
head_cor += 1 if head_pred[i] == head_gold[i] else 0 | |||||
label_cor += 1 if head_pred[i] == head_gold[i] and label_gold[i] == label_pred[i] else 0 | |||||
print('SEG: {}, POS: {}, UAS: {}, LAS: {}'.format(seg_cor/total, pos_cor/total, head_cor/total, label_cor/total)) | |||||
colln_path = args.gold_data | |||||
new_colln_path = args.new_data | |||||
index_list = [x[0] for x in res_list] | |||||
with open(colln_path, 'r', encoding='utf-8') as f1, \ | |||||
open(new_colln_path, 'w', encoding='utf-8') as f2: | |||||
for idx, ins in enumerate(ds): | |||||
if idx in index_list: | |||||
length = len(ins['gold_words']) | |||||
pad = ['_' for _ in range(length)] | |||||
for x in zip( | |||||
map(str, range(1, length+1)), ins['gold_words'], ins['gold_words'], ins['gold_pos'], | |||||
pad, pad, map(str, ins['gold_heads']), ins['gold_head_tags']): | |||||
new_lines = '\t'.join(x) | |||||
f2.write(new_lines) | |||||
f2.write('\n') | |||||
f2.write('\n') |
@@ -3,34 +3,33 @@ import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | ||||
from collections import defaultdict | |||||
import math | |||||
import torch | import torch | ||||
import re | |||||
from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
from fastNLP.core.metrics import Evaluator | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.field import TextField, SeqLabelField | from fastNLP.core.field import TextField, SeqLabelField | ||||
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | |||||
from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.loader.model_loader import ModelLoader | |||||
from fastNLP.loader.embed_loader import EmbedLoader | |||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
from fastNLP.models.biaffine_parser import BiaffineParser | from fastNLP.models.biaffine_parser import BiaffineParser | ||||
from fastNLP.saver.model_saver import ModelSaver | |||||
BOS = '<BOS>' | |||||
EOS = '<EOS>' | |||||
UNK = '<OOV>' | |||||
NUM = '<NUM>' | |||||
ENG = '<ENG>' | |||||
# not in the file's dir | # not in the file's dir | ||||
if len(os.path.dirname(__file__)) != 0: | if len(os.path.dirname(__file__)) != 0: | ||||
os.chdir(os.path.dirname(__file__)) | os.chdir(os.path.dirname(__file__)) | ||||
class MyDataLoader(object): | |||||
def __init__(self, pickle_path): | |||||
self.pickle_path = pickle_path | |||||
def load(self, path, word_v=None, pos_v=None, headtag_v=None): | |||||
class ConlluDataLoader(object): | |||||
def load(self, path): | |||||
datalist = [] | datalist = [] | ||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
sample = [] | sample = [] | ||||
@@ -49,23 +48,18 @@ class MyDataLoader(object): | |||||
for sample in datalist: | for sample in datalist: | ||||
# print(sample) | # print(sample) | ||||
res = self.get_one(sample) | res = self.get_one(sample) | ||||
if word_v is not None: | |||||
word_v.update(res[0]) | |||||
pos_v.update(res[1]) | |||||
headtag_v.update(res[3]) | |||||
ds.append(Instance(word_seq=TextField(res[0], is_target=False), | ds.append(Instance(word_seq=TextField(res[0], is_target=False), | ||||
pos_seq=TextField(res[1], is_target=False), | pos_seq=TextField(res[1], is_target=False), | ||||
head_indices=SeqLabelField(res[2], is_target=True), | head_indices=SeqLabelField(res[2], is_target=True), | ||||
head_labels=TextField(res[3], is_target=True), | |||||
seq_mask=SeqLabelField([1 for _ in range(len(res[0]))], is_target=False))) | |||||
head_labels=TextField(res[3], is_target=True))) | |||||
return ds | return ds | ||||
def get_one(self, sample): | def get_one(self, sample): | ||||
text = ['<root>'] | |||||
pos_tags = ['<root>'] | |||||
heads = [0] | |||||
head_tags = ['root'] | |||||
text = [] | |||||
pos_tags = [] | |||||
heads = [] | |||||
head_tags = [] | |||||
for w in sample: | for w in sample: | ||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | ||||
if t3 == '_': | if t3 == '_': | ||||
@@ -76,17 +70,60 @@ class MyDataLoader(object): | |||||
head_tags.append(t4) | head_tags.append(t4) | ||||
return (text, pos_tags, heads, head_tags) | return (text, pos_tags, heads, head_tags) | ||||
def index_data(self, dataset, word_v, pos_v, tag_v): | |||||
dataset.index_field('word_seq', word_v) | |||||
dataset.index_field('pos_seq', pos_v) | |||||
dataset.index_field('head_labels', tag_v) | |||||
class CTBDataLoader(object): | |||||
def load(self, data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f: | |||||
lines = f.readlines() | |||||
data = self.parse(lines) | |||||
return self.convert(data) | |||||
def parse(self, lines): | |||||
""" | |||||
[ | |||||
[word], [pos], [head_index], [head_tag] | |||||
] | |||||
""" | |||||
sample = [] | |||||
data = [] | |||||
for i, line in enumerate(lines): | |||||
line = line.strip() | |||||
if len(line) == 0 or i+1 == len(lines): | |||||
data.append(list(map(list, zip(*sample)))) | |||||
sample = [] | |||||
else: | |||||
sample.append(line.split()) | |||||
return data | |||||
def convert(self, data): | |||||
dataset = DataSet() | |||||
for sample in data: | |||||
word_seq = [BOS] + sample[0] + [EOS] | |||||
pos_seq = [BOS] + sample[1] + [EOS] | |||||
heads = [0] + list(map(int, sample[2])) + [0] | |||||
head_tags = [BOS] + sample[3] + [EOS] | |||||
dataset.append(Instance(word_seq=TextField(word_seq, is_target=False), | |||||
pos_seq=TextField(pos_seq, is_target=False), | |||||
gold_heads=SeqLabelField(heads, is_target=False), | |||||
head_indices=SeqLabelField(heads, is_target=True), | |||||
head_labels=TextField(head_tags, is_target=True))) | |||||
return dataset | |||||
# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | # datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | ||||
datadir = "/home/yfshao/UD_English-EWT" | |||||
# datadir = "/home/yfshao/UD_English-EWT" | |||||
# train_data_name = "en_ewt-ud-train.conllu" | |||||
# dev_data_name = "en_ewt-ud-dev.conllu" | |||||
# emb_file_name = '/home/yfshao/glove.6B.100d.txt' | |||||
# loader = ConlluDataLoader() | |||||
datadir = '/home/yfshao/workdir/parser-data/' | |||||
train_data_name = "train_ctb5.txt" | |||||
dev_data_name = "dev_ctb5.txt" | |||||
test_data_name = "test_ctb5.txt" | |||||
emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt" | |||||
# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec" | |||||
loader = CTBDataLoader() | |||||
cfgfile = './cfg.cfg' | cfgfile = './cfg.cfg' | ||||
train_data_name = "en_ewt-ud-train.conllu" | |||||
dev_data_name = "en_ewt-ud-dev.conllu" | |||||
emb_file_name = '/home/yfshao/glove.6B.100d.txt' | |||||
processed_datadir = './save' | processed_datadir = './save' | ||||
# Config Loader | # Config Loader | ||||
@@ -95,8 +132,12 @@ test_args = ConfigSection() | |||||
model_args = ConfigSection() | model_args = ConfigSection() | ||||
optim_args = ConfigSection() | optim_args = ConfigSection() | ||||
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args}) | ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args}) | ||||
print('train Args:', train_args.data) | |||||
print('test Args:', test_args.data) | |||||
print('optim Args:', optim_args.data) | |||||
# Data Loader | |||||
# Pickle Loader | |||||
def save_data(dirpath, **kwargs): | def save_data(dirpath, **kwargs): | ||||
import _pickle | import _pickle | ||||
if not os.path.exists(dirpath): | if not os.path.exists(dirpath): | ||||
@@ -117,38 +158,57 @@ def load_data(dirpath): | |||||
datas[name] = _pickle.load(f) | datas[name] = _pickle.load(f) | ||||
return datas | return datas | ||||
class MyTester(object): | |||||
def __init__(self, batch_size, use_cuda=False, **kwagrs): | |||||
self.batch_size = batch_size | |||||
self.use_cuda = use_cuda | |||||
def test(self, model, dataset): | |||||
self.model = model.cuda() if self.use_cuda else model | |||||
self.model.eval() | |||||
batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda) | |||||
eval_res = defaultdict(list) | |||||
i = 0 | |||||
for batch_x, batch_y in batchiter: | |||||
with torch.no_grad(): | |||||
pred_y = self.model(**batch_x) | |||||
eval_one = self.model.evaluate(**pred_y, **batch_y) | |||||
i += self.batch_size | |||||
for eval_name, tensor in eval_one.items(): | |||||
eval_res[eval_name].append(tensor) | |||||
tmp = {} | |||||
for eval_name, tensorlist in eval_res.items(): | |||||
tmp[eval_name] = torch.cat(tensorlist, dim=0) | |||||
self.res = self.model.metrics(**tmp) | |||||
def show_metrics(self): | |||||
s = "" | |||||
for name, val in self.res.items(): | |||||
s += '{}: {:.2f}\t'.format(name, val) | |||||
return s | |||||
loader = MyDataLoader('') | |||||
def P2(data, field, length): | |||||
ds = [ins for ins in data if ins[field].get_length() >= length] | |||||
data.clear() | |||||
data.extend(ds) | |||||
return ds | |||||
def P1(data, field): | |||||
def reeng(w): | |||||
return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG | |||||
def renum(w): | |||||
return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM | |||||
for ins in data: | |||||
ori = ins[field].contents() | |||||
s = list(map(renum, map(reeng, ori))) | |||||
if s != ori: | |||||
# print(ori) | |||||
# print(s) | |||||
# print() | |||||
ins[field] = ins[field].new(s) | |||||
return data | |||||
class ParserEvaluator(Evaluator): | |||||
def __init__(self, ignore_label): | |||||
super(ParserEvaluator, self).__init__() | |||||
self.ignore = ignore_label | |||||
def __call__(self, predict_list, truth_list): | |||||
head_all, label_all, total_all = 0, 0, 0 | |||||
for pred, truth in zip(predict_list, truth_list): | |||||
head, label, total = self.evaluate(**pred, **truth) | |||||
head_all += head | |||||
label_all += label | |||||
total_all += total | |||||
return {'UAS': head_all*1.0 / total_all, 'LAS': label_all*1.0 / total_all} | |||||
def evaluate(self, head_pred, label_pred, head_indices, head_labels, seq_mask, **_): | |||||
""" | |||||
Evaluate the performance of prediction. | |||||
:return : performance results. | |||||
head_pred_correct: number of correctly predicted heads. | |||||
label_pred_correct: number of correctly predicted labels. | |||||
total_tokens: number of evaluated tokens | |||||
""" | |||||
seq_mask *= (head_labels != self.ignore).long() | |||||
head_pred_correct = (head_pred == head_indices).long() * seq_mask | |||||
_, label_preds = torch.max(label_pred, dim=2) | |||||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct | |||||
return head_pred_correct.sum().item(), label_pred_correct.sum().item(), seq_mask.sum().item() | |||||
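A toy computation following the same logic, with made-up predictions for a four-token sentence (labels are shown already argmax-ed for brevity):

    import torch

    head_pred  = torch.tensor([[2, 0, 2, 3]])
    head_gold  = torch.tensor([[2, 0, 2, 2]])
    label_pred = torch.tensor([[1, 0, 1, 1]])
    label_gold = torch.tensor([[1, 0, 1, 2]])
    mask       = torch.tensor([[1, 1, 1, 1]])

    head_correct  = (head_pred == head_gold).long() * mask
    label_correct = (label_pred == label_gold).long() * head_correct
    uas = head_correct.sum().item() / mask.sum().item()    # 0.75
    las = label_correct.sum().item() / mask.sum().item()   # 0.75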
try: | try: | ||||
data_dict = load_data(processed_datadir) | data_dict = load_data(processed_datadir) | ||||
word_v = data_dict['word_v'] | word_v = data_dict['word_v'] | ||||
@@ -156,62 +216,90 @@ try: | |||||
tag_v = data_dict['tag_v'] | tag_v = data_dict['tag_v'] | ||||
train_data = data_dict['train_data'] | train_data = data_dict['train_data'] | ||||
dev_data = data_dict['dev_data'] | dev_data = data_dict['dev_data'] | ||||
test_data = data_dict['test_data'] | |||||
print('use saved pickles') | print('use saved pickles') | ||||
except Exception as _: | except Exception as _: | ||||
print('load raw data and preprocess') | print('load raw data and preprocess') | ||||
# use pretrain embedding | |||||
word_v = Vocabulary(need_default=True, min_freq=2) | word_v = Vocabulary(need_default=True, min_freq=2) | ||||
word_v.unknown_label = UNK | |||||
pos_v = Vocabulary(need_default=True) | pos_v = Vocabulary(need_default=True) | ||||
tag_v = Vocabulary(need_default=False) | tag_v = Vocabulary(need_default=False) | ||||
train_data = loader.load(os.path.join(datadir, train_data_name), word_v, pos_v, tag_v) | |||||
train_data = loader.load(os.path.join(datadir, train_data_name)) | |||||
dev_data = loader.load(os.path.join(datadir, dev_data_name)) | dev_data = loader.load(os.path.join(datadir, dev_data_name)) | ||||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data) | |||||
test_data = loader.load(os.path.join(datadir, test_data_name)) | |||||
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v) | |||||
datasets = (train_data, dev_data, test_data) | |||||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data) | |||||
loader.index_data(train_data, word_v, pos_v, tag_v) | |||||
loader.index_data(dev_data, word_v, pos_v, tag_v) | |||||
print(len(train_data)) | |||||
print(len(dev_data)) | |||||
ep = train_args['epochs'] | |||||
train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep | |||||
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | |||||
print(len(word_v)) | |||||
print(embed.size()) | |||||
# Model | |||||
model_args['word_vocab_size'] = len(word_v) | model_args['word_vocab_size'] = len(word_v) | ||||
model_args['pos_vocab_size'] = len(pos_v) | model_args['pos_vocab_size'] = len(pos_v) | ||||
model_args['num_label'] = len(tag_v) | model_args['num_label'] = len(tag_v) | ||||
model = BiaffineParser(**model_args.data) | |||||
model.reset_parameters() | |||||
datasets = (train_data, dev_data, test_data) | |||||
for ds in datasets: | |||||
ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v) | |||||
ds.set_origin_len('word_seq') | |||||
if train_args['use_golden_train']: | |||||
train_data.set_target(gold_heads=False) | |||||
else: | |||||
train_data.set_target(gold_heads=None) | |||||
train_args.data.pop('use_golden_train') | |||||
ignore_label = pos_v['P'] | |||||
print(test_data[0]) | |||||
print(len(train_data)) | |||||
print(len(dev_data)) | |||||
print(len(test_data)) | |||||
def train(): | |||||
def train(path): | |||||
# Trainer | # Trainer | ||||
trainer = Trainer(**train_args.data) | trainer = Trainer(**train_args.data) | ||||
def _define_optim(obj): | def _define_optim(obj): | ||||
obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data) | |||||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4)) | |||||
lr = optim_args.data['lr'] | |||||
embed_params = set(obj._model.word_embedding.parameters()) | |||||
decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters()) | |||||
params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params] | |||||
obj._optimizer = torch.optim.Adam([ | |||||
{'params': list(embed_params), 'lr':lr*0.1}, | |||||
{'params': list(decay_params), **optim_args.data}, | |||||
{'params': params} | |||||
], lr=lr, betas=(0.9, 0.9)) | |||||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05)) | |||||
def _update(obj): | def _update(obj): | ||||
# torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0) | |||||
obj._scheduler.step() | obj._scheduler.step() | ||||
obj._optimizer.step() | obj._optimizer.step() | ||||
trainer.define_optimizer = lambda: _define_optim(trainer) | trainer.define_optimizer = lambda: _define_optim(trainer) | ||||
trainer.update = lambda: _update(trainer) | trainer.update = lambda: _update(trainer) | ||||
trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth) | |||||
trainer._create_validator = lambda x: MyTester(**test_args.data) | |||||
# Model | |||||
model = BiaffineParser(**model_args.data) | |||||
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))) | |||||
# use pretrain embedding | |||||
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | |||||
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | ||||
model.word_embedding.padding_idx = word_v.padding_idx | model.word_embedding.padding_idx = word_v.padding_idx | ||||
model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | ||||
model.pos_embedding.padding_idx = pos_v.padding_idx | model.pos_embedding.padding_idx = pos_v.padding_idx | ||||
model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0) | model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0) | ||||
try: | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print('model parameter loaded!') | |||||
except Exception as _: | |||||
print("No saved model. Continue.") | |||||
pass | |||||
# try: | |||||
# ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
# print('model parameter loaded!') | |||||
# except Exception as _: | |||||
# print("No saved model. Continue.") | |||||
# pass | |||||
# Start training | # Start training | ||||
trainer.train(model, train_data, dev_data) | trainer.train(model, train_data, dev_data) | ||||
@@ -223,24 +311,27 @@ def train(): | |||||
print("Model saved!") | print("Model saved!") | ||||
def test(): | |||||
def test(path): | |||||
# Tester | # Tester | ||||
tester = MyTester(**test_args.data) | |||||
tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)) | |||||
# Model | # Model | ||||
model = BiaffineParser(**model_args.data) | model = BiaffineParser(**model_args.data) | ||||
model.eval() | |||||
try: | try: | ||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
ModelLoader.load_pytorch(model, path) | |||||
print('model parameter loaded!') | print('model parameter loaded!') | ||||
except Exception as _: | except Exception as _: | ||||
print("No saved model. Abort test.") | print("No saved model. Abort test.") | ||||
raise | raise | ||||
# Start training | # Start training | ||||
print("Testing Train data") | |||||
tester.test(model, train_data) | |||||
print("Testing Dev data") | |||||
tester.test(model, dev_data) | tester.test(model, dev_data) | ||||
print(tester.show_metrics()) | |||||
print("Testing finished!") | |||||
print("Testing Test data") | |||||
tester.test(model, test_data) | |||||
@@ -248,13 +339,14 @@ if __name__ == "__main__": | |||||
import argparse | import argparse | ||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | ||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | ||||
parser.add_argument('--path', type=str, default='') | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if args.mode == 'train': | if args.mode == 'train': | ||||
train() | |||||
train(args.path) | |||||
elif args.mode == 'test': | elif args.mode == 'test': | ||||
test() | |||||
test(args.path) | |||||
elif args.mode == 'infer': | elif args.mode == 'infer': | ||||
infer() | |||||
pass | |||||
else: | else: | ||||
print('no mode specified for model!') | print('no mode specified for model!') | ||||
parser.print_help() | parser.print_help() |
@@ -0,0 +1,78 @@ | |||||
class ConllxDataLoader(object): | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
data = [self.get_one(sample) for sample in datalist] | |||||
return list(filter(lambda x: x is not None, data)) | |||||
def get_one(self, sample): | |||||
sample = list(map(list, zip(*sample))) | |||||
if len(sample) == 0: | |||||
return None | |||||
for w in sample[7]: | |||||
if w == '_': | |||||
print('Error Sample {}'.format(sample)) | |||||
return None | |||||
# return word_seq, pos_seq, head_seq, head_tag_seq | |||||
return sample[1], sample[3], list(map(int, sample[6])), sample[7] | |||||
class MyDataloader: | |||||
def load(self, data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f: | |||||
lines = f.readlines() | |||||
data = self.parse(lines) | |||||
return data | |||||
def parse(self, lines): | |||||
""" | |||||
[ | |||||
[word], [pos], [head_index], [head_tag] | |||||
] | |||||
""" | |||||
sample = [] | |||||
data = [] | |||||
for i, line in enumerate(lines): | |||||
line = line.strip() | |||||
if len(line) == 0 or i + 1 == len(lines): | |||||
data.append(list(map(list, zip(*sample)))) | |||||
sample = [] | |||||
else: | |||||
sample.append(line.split()) | |||||
if len(sample) > 0: | |||||
data.append(list(map(list, zip(*sample)))) | |||||
return data | |||||
def add_seg_tag(data): | |||||
""" | |||||
:param data: list of ([word], [pos], [heads], [head_tags]) | |||||
:return: list of ([word], [pos]) | |||||
""" | |||||
_processed = [] | |||||
for word_list, pos_list, _, _ in data: | |||||
new_sample = [] | |||||
for word, pos in zip(word_list, pos_list): | |||||
if len(word) == 1: | |||||
new_sample.append((word, 'S-' + pos)) | |||||
else: | |||||
new_sample.append((word[0], 'B-' + pos)) | |||||
for c in word[1:-1]: | |||||
new_sample.append((c, 'M-' + pos)) | |||||
new_sample.append((word[-1], 'E-' + pos)) | |||||
_processed.append(list(map(list, zip(*new_sample)))) | |||||
return _processed |
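For example, on a hypothetical two-word sample:

    data = [(['中国', '人'], ['NR', 'NN'], [2, 0], ['nn', 'root'])]
    print(add_seg_tag(data))
    # [[['中', '国', '人'], ['B-NR', 'E-NR', 'S-NN']]]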
@@ -4,8 +4,9 @@ import torch.nn.functional as F | |||||
class CNN_text(nn.Module): | class CNN_text(nn.Module): | ||||
def __init__(self, kernel_h=[3, 4, 5], kernel_num=100, embed_num=1000, embed_dim=300, dropout=0.5, L2_constrain=3, | |||||
batchsize=50, pretrained_embeddings=None): | |||||
def __init__(self, kernel_h=[3, 4, 5], kernel_num=100, embed_num=1000, embed_dim=300, num_classes=2, dropout=0.5, | |||||
L2_constrain=3, | |||||
pretrained_embeddings=None): | |||||
super(CNN_text, self).__init__() | super(CNN_text, self).__init__() | ||||
self.embedding = nn.Embedding(embed_num, embed_dim) | self.embedding = nn.Embedding(embed_num, embed_dim) | ||||
@@ -15,11 +16,11 @@ class CNN_text(nn.Module): | |||||
# the network structure | # the network structure | ||||
# Conv2d: input- N,C,H,W output- (50,100,62,1) | # Conv2d: input- N,C,H,W output- (50,100,62,1) | ||||
self.conv1 = nn.ModuleList([nn.Conv2d(1, 100, (K, 300)) for K in kernel_h]) | |||||
self.fc1 = nn.Linear(300, 2) | |||||
self.conv1 = nn.ModuleList([nn.Conv2d(1, kernel_num, (K, embed_dim)) for K in kernel_h]) | |||||
self.fc1 = nn.Linear(len(kernel_h) * kernel_num, num_classes) | |||||
def max_pooling(self, x): | def max_pooling(self, x): | ||||
x = F.relu(conv(x)).squeeze(3) # N,C,L - (50,100,62) | |||||
x = F.relu(self.conv1(x)).squeeze(3) # N,C,L - (50,100,62) | |||||
x = F.max_pool1d(x, x.size(2)).squeeze(2) | x = F.max_pool1d(x, x.size(2)).squeeze(2) | ||||
# x.size(2)=62 squeeze: (50,100,1) -> (50,100) | # x.size(2)=62 squeeze: (50,100,1) -> (50,100) | ||||
return x | return x | ||||
@@ -33,3 +34,9 @@ class CNN_text(nn.Module): | |||||
x = self.dropout(x) | x = self.dropout(x) | ||||
x = self.fc1(x) | x = self.fc1(x) | ||||
return x | return x | ||||
if __name__ == '__main__': | |||||
model = CNN_text(kernel_h=[1, 2, 3, 4], embed_num=3, embed_dim=2) | |||||
x = torch.LongTensor([[1, 2, 1, 2, 0]]) | |||||
print(model(x)) |
@@ -1,10 +1,10 @@ | |||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from fastNLP.core.preprocess import ClassPreprocess as Preprocess | |||||
from fastNLP.core.trainer import ClassificationTrainer | from fastNLP.core.trainer import ClassificationTrainer | ||||
from fastNLP.loader.config_loader import ConfigLoader | |||||
from fastNLP.loader.config_loader import ConfigSection | |||||
from fastNLP.loader.dataset_loader import ClassDataSetLoader as Dataset_loader | |||||
from fastNLP.core.utils import ClassPreprocess as Preprocess | |||||
from fastNLP.io.config_io import ConfigLoader | |||||
from fastNLP.io.config_io import ConfigSection | |||||
from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader | |||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules.aggregator.self_attention import SelfAttention | from fastNLP.modules.aggregator.self_attention import SelfAttention | ||||
from fastNLP.modules.decoder.MLP import MLP | from fastNLP.modules.decoder.MLP import MLP | ||||
@@ -1,6 +1,6 @@ | |||||
[train] | [train] | ||||
epochs = 30 | |||||
batch_size = 64 | |||||
epochs = 40 | |||||
batch_size = 8 | |||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
validate = true | validate = true | ||||
save_best_dev = true | save_best_dev = true | ||||
@@ -0,0 +1,176 @@ | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.io.dataset_loader import DataSetLoader | |||||
def cut_long_sentence(sent, max_sample_length=200): | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
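# Illustrative example (values made up): with max_sample_length=5,
#   cut_long_sentence('这是 一个 很长 的 句子', max_sample_length=5)
# returns ['这是 一个 很长', '的 句子']; the sentence is split at word boundaries once the
# accumulated character count (spaces excluded) exceeds the limit, while a sentence
# within the limit is returned unchanged as a single-element list.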
class NaiveCWSReader(DataSetLoader): | |||||
""" | |||||
    This reader assumes the segmentation dataset is already whitespace-segmented, e.g.
    这是 fastNLP , 一个 非常 good 的 包 .
    or that every part is additionally followed by a POS tag, e.g.
    也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY
""" | |||||
def __init__(self, in_word_splitter=None): | |||||
super().__init__() | |||||
self.in_word_splitter = in_word_splitter | |||||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||||
""" | |||||
        Accepted input formats (by default \t or a space is used as the separator):
            这是 fastNLP , 一个 非常 good 的 包 .
        and
            也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY
        If in_word_splitter is not None, the second format is assumed: each part such as "也/D" is split
        by the splitter and only the first piece is kept, e.g. "也/D".split('/')[0].
:param filepath: | |||||
:param in_word_splitter: | |||||
:return: | |||||
""" | |||||
        if in_word_splitter is None:
in_word_splitter = self.in_word_splitter | |||||
dataset = DataSet() | |||||
with open(filepath, 'r') as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
                if len(line.replace(' ', '')) == 0:  # empty lines are not accepted
continue | |||||
                if in_word_splitter is not None:
words = [] | |||||
for part in line.split(): | |||||
word = part.split(in_word_splitter)[0] | |||||
words.append(word) | |||||
line = ' '.join(words) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for sent in sents: | |||||
instance = Instance(raw_sentence=sent) | |||||
dataset.append(instance) | |||||
return dataset | |||||
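# Minimal usage sketch (the file path is hypothetical):
#   reader = NaiveCWSReader()
#   ds = reader.load('train.txt', cut_long_sent=True)
#   # every Instance in ds carries a single field, 'raw_sentence'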
class POSCWSReader(DataSetLoader): | |||||
""" | |||||
    Supports data in the following format: one token per line, with an empty line marking the boundary between two sentences.
迈 N | |||||
向 N | |||||
充 N | |||||
... | |||||
泽 I-PER | |||||
民 I-PER | |||||
( N | |||||
一 N | |||||
九 N | |||||
... | |||||
:param filepath: | |||||
:return: | |||||
""" | |||||
def __init__(self, in_word_splitter=None): | |||||
super().__init__() | |||||
self.in_word_splitter = in_word_splitter | |||||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||||
if in_word_splitter is None: | |||||
in_word_splitter = self.in_word_splitter | |||||
dataset = DataSet() | |||||
with open(filepath, 'r') as f: | |||||
words = [] | |||||
for line in f: | |||||
line = line.strip() | |||||
if len(line) == 0: # new line | |||||
                    if len(words) == 0:  # empty lines are not accepted
continue | |||||
line = ' '.join(words) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for sent in sents: | |||||
instance = Instance(raw_sentence=sent) | |||||
dataset.append(instance) | |||||
words = [] | |||||
else: | |||||
line = line.split()[0] | |||||
if in_word_splitter is None: | |||||
words.append(line) | |||||
else: | |||||
words.append(line.split(in_word_splitter)[0]) | |||||
return dataset | |||||
class ConlluCWSReader(object): | |||||
    # The returned DataSet contains two fields: words (a list of lists whose inner lists hold characters)
    # and tag (a list of str, each str being a BMES-style tag).
def __init__(self): | |||||
pass | |||||
def load(self, path, cut_long_sent=False): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
line = ' '.join(res) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for raw_sentence in sents: | |||||
ds.append(Instance(raw_sentence=raw_sentence)) | |||||
return ds | |||||
def get_one(self, sample): | |||||
if len(sample)==0: | |||||
return None | |||||
text = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
return text | |||||
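# Note on the columns used above (an assumption about the CoNLL-style input): w[1], w[3], w[6]
# and w[7] correspond to the FORM, coarse POS, HEAD and DEPREL columns; a sample whose HEAD
# column is '_' is treated as unannotated and skipped.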
@@ -0,0 +1,172 @@ | |||||
from torch import nn | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from fastNLP.modules.decoder.MLP import MLP | |||||
from fastNLP.models.base_model import BaseModel | |||||
from reproduction.chinese_word_segment.utils import seq_lens_to_mask | |||||
class CWSBiLSTMEncoder(BaseModel): | |||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | |||||
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1): | |||||
super().__init__() | |||||
self.input_size = 0 | |||||
self.num_bigram_per_char = num_bigram_per_char | |||||
self.bidirectional = bidirectional | |||||
self.num_layers = num_layers | |||||
self.embed_drop_p = embed_drop_p | |||||
if self.bidirectional: | |||||
self.hidden_size = hidden_size//2 | |||||
self.num_directions = 2 | |||||
else: | |||||
self.hidden_size = hidden_size | |||||
self.num_directions = 1 | |||||
        if bigram_vocab_num is not None:
            assert num_bigram_per_char is not None, "Specify num_bigram_per_char."
if vocab_num is not None: | |||||
self.char_embedding = nn.Embedding(num_embeddings=vocab_num, embedding_dim=embed_dim) | |||||
self.input_size += embed_dim | |||||
if bigram_vocab_num is not None: | |||||
self.bigram_embedding = nn.Embedding(num_embeddings=bigram_vocab_num, embedding_dim=bigram_embed_dim) | |||||
self.input_size += self.num_bigram_per_char*bigram_embed_dim | |||||
        if self.embed_drop_p is not None:
self.embedding_drop = nn.Dropout(p=self.embed_drop_p) | |||||
self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, bidirectional=self.bidirectional, | |||||
batch_first=True, num_layers=self.num_layers) | |||||
self.reset_parameters() | |||||
def reset_parameters(self): | |||||
for name, param in self.named_parameters(): | |||||
if 'bias_hh' in name: | |||||
nn.init.constant_(param, 0) | |||||
elif 'bias_ih' in name: | |||||
nn.init.constant_(param, 1) | |||||
else: | |||||
nn.init.xavier_uniform_(param) | |||||
def init_embedding(self, embedding, embed_name): | |||||
if embed_name == 'bigram': | |||||
self.bigram_embedding.weight.data = torch.from_numpy(embedding) | |||||
elif embed_name == 'char': | |||||
self.char_embedding.weight.data = torch.from_numpy(embedding) | |||||
def forward(self, chars, bigrams=None, seq_lens=None): | |||||
batch_size, max_len = chars.size() | |||||
x_tensor = self.char_embedding(chars) | |||||
        if bigrams is not None:
bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1) | |||||
x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2) | |||||
sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True) | |||||
packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True) | |||||
outputs, _ = self.lstm(packed_x) | |||||
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) | |||||
_, desorted_indices = torch.sort(sorted_indices, descending=False) | |||||
outputs = outputs[desorted_indices] | |||||
return outputs | |||||
class CWSBiLSTMSegApp(BaseModel): | |||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | |||||
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=2): | |||||
super(CWSBiLSTMSegApp, self).__init__() | |||||
self.tag_size = tag_size | |||||
self.encoder_model = CWSBiLSTMEncoder(vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char, | |||||
hidden_size, bidirectional, embed_drop_p, num_layers) | |||||
size_layer = [hidden_size, 200, tag_size] | |||||
self.decoder_model = MLP(size_layer) | |||||
def forward(self, chars, seq_lens, bigrams=None): | |||||
device = self.parameters().__next__().device | |||||
chars = chars.to(device).long() | |||||
        if bigrams is not None:
            bigrams = bigrams.to(device).long()
seq_lens = seq_lens.to(device).long() | |||||
feats = self.encoder_model(chars, bigrams, seq_lens) | |||||
probs = self.decoder_model(feats) | |||||
pred_dict = {} | |||||
pred_dict['seq_lens'] = seq_lens | |||||
pred_dict['pred_probs'] = probs | |||||
return pred_dict | |||||
def predict(self, chars, seq_lens, bigrams=None): | |||||
pred_dict = self.forward(chars, seq_lens, bigrams) | |||||
pred_probs = pred_dict['pred_probs'] | |||||
_, pred_tags = pred_probs.max(dim=-1) | |||||
return {'pred_tags': pred_tags} | |||||
from fastNLP.modules.decoder.CRF import ConditionalRandomField | |||||
class CWSBiLSTMCRF(BaseModel): | |||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | |||||
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=4): | |||||
super(CWSBiLSTMCRF, self).__init__() | |||||
self.tag_size = tag_size | |||||
self.encoder_model = CWSBiLSTMEncoder(vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim, num_bigram_per_char, | |||||
hidden_size, bidirectional, embed_drop_p, num_layers) | |||||
size_layer = [hidden_size, 200, tag_size] | |||||
self.decoder_model = MLP(size_layer) | |||||
self.crf = ConditionalRandomField(tag_size=tag_size, include_start_end_trans=False) | |||||
def forward(self, chars, tags, seq_lens, bigrams=None): | |||||
device = self.parameters().__next__().device | |||||
chars = chars.to(device).long() | |||||
        if bigrams is not None:
            bigrams = bigrams.to(device).long()
seq_lens = seq_lens.to(device).long() | |||||
masks = seq_lens_to_mask(seq_lens) | |||||
feats = self.encoder_model(chars, bigrams, seq_lens) | |||||
feats = self.decoder_model(feats) | |||||
losses = self.crf(feats, tags, masks) | |||||
pred_dict = {} | |||||
pred_dict['seq_lens'] = seq_lens | |||||
pred_dict['loss'] = torch.mean(losses) | |||||
return pred_dict | |||||
def predict(self, chars, seq_lens, bigrams=None): | |||||
device = self.parameters().__next__().device | |||||
chars = chars.to(device).long() | |||||
        if bigrams is not None:
            bigrams = bigrams.to(device).long()
seq_lens = seq_lens.to(device).long() | |||||
masks = seq_lens_to_mask(seq_lens) | |||||
feats = self.encoder_model(chars, bigrams, seq_lens) | |||||
feats = self.decoder_model(feats) | |||||
probs = self.crf.viterbi_decode(feats, masks, get_score=False) | |||||
return {'pred_tags': probs} | |||||
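# Minimal forward-pass sketch (all sizes are made up for illustration):
#   model = CWSBiLSTMSegApp(vocab_num=100, embed_dim=64, hidden_size=128, tag_size=2)
#   chars = torch.randint(0, 100, (2, 5))   # (batch_size, max_len) character ids
#   seq_lens = torch.LongTensor([5, 3])     # true sequence lengths
#   out = model(chars, seq_lens)            # {'seq_lens': ..., 'pred_probs': (2, 5, 2) scores}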
@@ -0,0 +1,284 @@ | |||||
import re | |||||
from fastNLP.core.field import SeqLabelField | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.api.processor import Processor | |||||
from reproduction.chinese_word_segment.process.span_converter import SpanConverter | |||||
_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' | |||||
class SpeicalSpanProcessor(Processor): | |||||
    # This class converts special spans in a sentence into their corresponding replacement content.
def __init__(self, field_name, new_added_field_name=None): | |||||
super(SpeicalSpanProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.span_converters = [] | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
sentence = ins[self.field_name] | |||||
for span_converter in self.span_converters: | |||||
sentence = span_converter.find_certain_span_and_replace(sentence) | |||||
ins[self.new_added_field_name] = sentence | |||||
return dataset | |||||
def add_span_converter(self, converter): | |||||
        assert isinstance(converter, SpanConverter), "Only SpanConverter is allowed, not {}."\
            .format(type(converter))
self.span_converters.append(converter) | |||||
class CWSCharSegProcessor(Processor): | |||||
def __init__(self, field_name, new_added_field_name): | |||||
super(CWSCharSegProcessor, self).__init__(field_name, new_added_field_name) | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
sentence = ins[self.field_name] | |||||
chars = self._split_sent_into_chars(sentence) | |||||
ins[self.new_added_field_name] = chars | |||||
return dataset | |||||
def _split_sent_into_chars(self, sentence): | |||||
sp_tag_match_iter = re.finditer(_SPECIAL_TAG_PATTERN, sentence) | |||||
sp_spans = [match_span.span() for match_span in sp_tag_match_iter] | |||||
sp_span_idx = 0 | |||||
in_span_flag = False | |||||
chars = [] | |||||
num_spans = len(sp_spans) | |||||
for idx, char in enumerate(sentence): | |||||
if sp_span_idx<num_spans and idx == sp_spans[sp_span_idx][0]: | |||||
in_span_flag = True | |||||
elif in_span_flag and sp_span_idx<num_spans and idx == sp_spans[sp_span_idx][1] - 1: | |||||
                chars.append(sentence[sp_spans[sp_span_idx][0]:sp_spans[sp_span_idx][1]])
in_span_flag = False | |||||
sp_span_idx += 1 | |||||
elif not in_span_flag: | |||||
                # TODO: handling of whitespace here needs more careful consideration
if char != ' ': | |||||
chars.append(char) | |||||
else: | |||||
pass | |||||
return chars | |||||
class CWSTagProcessor(Processor): | |||||
def __init__(self, field_name, new_added_field_name=None): | |||||
super(CWSTagProcessor, self).__init__(field_name, new_added_field_name) | |||||
def _generate_tag(self, sentence): | |||||
sp_tag_match_iter = re.finditer(_SPECIAL_TAG_PATTERN, sentence) | |||||
sp_spans = [match_span.span() for match_span in sp_tag_match_iter] | |||||
sp_span_idx = 0 | |||||
in_span_flag = False | |||||
tag_list = [] | |||||
word_len = 0 | |||||
num_spans = len(sp_spans) | |||||
for idx, char in enumerate(sentence): | |||||
if sp_span_idx<num_spans and idx == sp_spans[sp_span_idx][0]: | |||||
in_span_flag = True | |||||
elif in_span_flag and sp_span_idx<num_spans and idx == sp_spans[sp_span_idx][1] - 1: | |||||
word_len += 1 | |||||
in_span_flag = False | |||||
sp_span_idx += 1 | |||||
elif not in_span_flag: | |||||
if char == ' ': | |||||
if word_len!=0: | |||||
tag_list.extend(self._tags_from_word_len(word_len)) | |||||
word_len = 0 | |||||
else: | |||||
word_len += 1 | |||||
else: | |||||
pass | |||||
if word_len!=0: | |||||
tag_list.extend(self._tags_from_word_len(word_len)) | |||||
return tag_list | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
sentence = ins[self.field_name] | |||||
tag_list = self._generate_tag(sentence) | |||||
ins[self.new_added_field_name] = tag_list | |||||
dataset.set_target(**{self.new_added_field_name:True}) | |||||
dataset._set_need_tensor(**{self.new_added_field_name:True}) | |||||
return dataset | |||||
def _tags_from_word_len(self, word_len): | |||||
raise NotImplementedError | |||||
class CWSBMESTagProcessor(CWSTagProcessor): | |||||
def __init__(self, field_name, new_added_field_name=None): | |||||
super(CWSBMESTagProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.tag_size = 4 | |||||
def _tags_from_word_len(self, word_len): | |||||
tag_list = [] | |||||
if word_len == 1: | |||||
tag_list.append(3) | |||||
else: | |||||
tag_list.append(0) | |||||
for _ in range(word_len-2): | |||||
tag_list.append(1) | |||||
tag_list.append(2) | |||||
return tag_list | |||||
class CWSSegAppTagProcessor(CWSTagProcessor): | |||||
def __init__(self, field_name, new_added_field_name=None): | |||||
super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.tag_size = 2 | |||||
def _tags_from_word_len(self, word_len): | |||||
tag_list = [] | |||||
for _ in range(word_len-1): | |||||
tag_list.append(0) | |||||
tag_list.append(1) | |||||
return tag_list | |||||
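# Illustrative note (not part of the original file): for the segmented sentence '中国 人',
# CWSBMESTagProcessor._tags_from_word_len produces [0, 2] + [3] = [0, 2, 3]
# (encoding B=0, M=1, E=2, S=3), while CWSSegAppTagProcessor produces
# [0, 1] + [1] = [0, 1, 1], where 1 marks the last character of each word.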
class BigramProcessor(Processor): | |||||
    def __init__(self, field_name, new_added_field_name=None):
        super(BigramProcessor, self).__init__(field_name, new_added_field_name)
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
characters = ins[self.field_name] | |||||
bigrams = self._generate_bigram(characters) | |||||
ins[self.new_added_field_name] = bigrams | |||||
return dataset | |||||
def _generate_bigram(self, characters): | |||||
pass | |||||
class Pre2Post2BigramProcessor(BigramProcessor): | |||||
    def __init__(self, field_name, new_added_field_name=None):
        super(Pre2Post2BigramProcessor, self).__init__(field_name, new_added_field_name)
def _generate_bigram(self, characters): | |||||
bigrams = [] | |||||
characters = ['<SOS>', '<SOS>'] + characters + ['<EOS>', '<EOS>'] | |||||
for idx in range(2, len(characters)-2): | |||||
cur_char = characters[idx] | |||||
pre_pre_char = characters[idx-2] | |||||
pre_char = characters[idx-1] | |||||
post_char = characters[idx+1] | |||||
post_post_char = characters[idx+2] | |||||
pre_pre_cur_bigram = pre_pre_char + cur_char | |||||
pre_cur_bigram = pre_char + cur_char | |||||
cur_post_bigram = cur_char + post_char | |||||
cur_post_post_bigram = cur_char + post_post_char | |||||
bigrams.extend([pre_pre_char, pre_char, post_char, post_post_char, | |||||
pre_pre_cur_bigram, pre_cur_bigram, | |||||
cur_post_bigram, cur_post_post_bigram]) | |||||
return bigrams | |||||
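# Illustrative example: for characters ['天', '气', '好'], the features generated for the
# middle character '气' are
#   ['<SOS>', '天', '好', '<EOS>', '<SOS>气', '天气', '气好', '气<EOS>'],
# i.e. 8 context features per character: the four surrounding unigrams plus the four
# bigrams pairing the current character with its -2/-1/+1/+2 neighbours.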
# A vocabulary needs to be built at this point, but there is a problem:
# (1) if the Processor interface were used, what is returned in this situation would not be a dataset,
#     so the vocabulary building is implemented in a different way instead of reusing Processor.
class VocabProcessor(Processor): | |||||
def __init__(self, field_name, min_count=1, max_vocab_size=None): | |||||
super(VocabProcessor, self).__init__(field_name, None) | |||||
self.vocab = Vocabulary(min_freq=min_count, max_size=max_vocab_size) | |||||
def process(self, *datasets): | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
tokens = ins[self.field_name] | |||||
self.vocab.update(tokens) | |||||
def get_vocab(self): | |||||
self.vocab.build_vocab() | |||||
return self.vocab | |||||
def get_vocab_size(self): | |||||
return len(self.vocab) | |||||
class SeqLenProcessor(Processor): | |||||
def __init__(self, field_name, new_added_field_name='seq_lens'): | |||||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
length = len(ins[self.field_name]) | |||||
ins[self.new_added_field_name] = length | |||||
dataset._set_need_tensor(**{self.new_added_field_name:True}) | |||||
return dataset | |||||
class SegApp2OutputProcessor(Processor): | |||||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | |||||
super(SegApp2OutputProcessor, self).__init__(None, None) | |||||
self.chars_field_name = chars_field_name | |||||
self.tag_field_name = tag_field_name | |||||
self.new_added_field_name = new_added_field_name | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
pred_tags = ins[self.tag_field_name] | |||||
chars = ins[self.chars_field_name] | |||||
words = [] | |||||
start_idx = 0 | |||||
for idx, tag in enumerate(pred_tags): | |||||
if tag==1: | |||||
                        # restoring the original (pre-conversion) text is not handled yet
words.append(''.join(chars[start_idx:idx+1])) | |||||
start_idx = idx + 1 | |||||
ins[self.new_added_field_name] = ' '.join(words) | |||||
class BMES2OutputProcessor(Processor): | |||||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | |||||
super(BMES2OutputProcessor, self).__init__(None, None) | |||||
self.chars_field_name = chars_field_name | |||||
self.tag_field_name = tag_field_name | |||||
self.new_added_field_name = new_added_field_name | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
pred_tags = ins[self.tag_field_name] | |||||
chars = ins[self.chars_field_name] | |||||
words = [] | |||||
start_idx = 0 | |||||
for idx, tag in enumerate(pred_tags): | |||||
if tag==3: | |||||
                        # restoring the original (pre-conversion) text is not handled yet
words.extend(chars[start_idx:idx+1]) | |||||
start_idx = idx + 1 | |||||
elif tag==2: | |||||
words.append(''.join(chars[start_idx:idx+1])) | |||||
start_idx = idx + 1 | |||||
ins[self.new_added_field_name] = ' '.join(words) |
@@ -0,0 +1,185 @@ | |||||
import re | |||||
class SpanConverter: | |||||
def __init__(self, replace_tag, pattern): | |||||
super(SpanConverter, self).__init__() | |||||
self.replace_tag = replace_tag | |||||
self.pattern = pattern | |||||
def find_certain_span_and_replace(self, sentence): | |||||
replaced_sentence = '' | |||||
prev_end = 0 | |||||
for match in re.finditer(self.pattern, sentence): | |||||
start, end = match.span() | |||||
span = sentence[start:end] | |||||
replaced_sentence += sentence[prev_end:start] + \ | |||||
self.span_to_special_tag(span) | |||||
prev_end = end | |||||
replaced_sentence += sentence[prev_end:] | |||||
return replaced_sentence | |||||
def span_to_special_tag(self, span): | |||||
return self.replace_tag | |||||
def find_certain_span(self, sentence): | |||||
spans = [] | |||||
for match in re.finditer(self.pattern, sentence): | |||||
spans.append(match.span()) | |||||
return spans | |||||
class AlphaSpanConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<ALPHA>' | |||||
        # ideally this handles only purely alphabetic spans; it deliberately excludes <[a-zA-Z]+> (since that should be a special tag).
pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%.!<\\-"])' | |||||
super(AlphaSpanConverter, self).__init__(replace_tag, pattern) | |||||
class DigitSpanConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<NUM>' | |||||
pattern = '\d[\d\\.]*(?=[\u4e00-\u9fff ,%.!<-])' | |||||
super(DigitSpanConverter, self).__init__(replace_tag, pattern) | |||||
def span_to_special_tag(self, span): | |||||
# return self.special_tag | |||||
if span[0] == '0' and len(span) > 2: | |||||
return '<NUM>' | |||||
        decimal_point_count = 0  # there may be more than one decimal point
        for char in span:
            if char == '.' or char == '﹒' or char == '·':
                decimal_point_count += 1
        if span[-1] == '.' or span[-1] == '﹒' or span[-1] == '·':
            # a trailing decimal point means this span is not a plain number
if decimal_point_count == 1: | |||||
return span | |||||
else: | |||||
return '<UNKDGT>' | |||||
if decimal_point_count == 1: | |||||
return '<DEC>' | |||||
elif decimal_point_count > 1: | |||||
return '<UNKDGT>' | |||||
else: | |||||
return '<NUM>' | |||||
class TimeConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<TOC>' | |||||
pattern = '\d+[::∶][\d::∶]+(?=[\u4e00-\u9fff ,%.!<-])' | |||||
super().__init__(replace_tag, pattern) | |||||
class MixNumAlphaConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<MIX>' | |||||
pattern = None | |||||
super().__init__(replace_tag, pattern) | |||||
def find_certain_span_and_replace(self, sentence): | |||||
replaced_sentence = '' | |||||
start = 0 | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
for idx in range(len(sentence)): | |||||
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): | |||||
if not matching_flag: | |||||
replaced_sentence += sentence[start:idx] | |||||
start = idx | |||||
if re.match('[0-9]', sentence[idx]): | |||||
number_flag = True | |||||
elif re.match('[\'′&\\-]', sentence[idx]): | |||||
link_flag = True | |||||
elif re.match('/', sentence[idx]): | |||||
slash_flag = True | |||||
elif re.match('[\\(\\)]', sentence[idx]): | |||||
bracket_flag = True | |||||
else: | |||||
alpha_flag = True | |||||
matching_flag = True | |||||
elif re.match('[\\.]', sentence[idx]): | |||||
pass | |||||
else: | |||||
if matching_flag: | |||||
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ | |||||
or (slash_flag and alpha_flag) or (link_flag and number_flag) \ | |||||
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): | |||||
span = sentence[start:idx] | |||||
start = idx | |||||
replaced_sentence += self.span_to_special_tag(span) | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
replaced_sentence += sentence[start:] | |||||
return replaced_sentence | |||||
def find_certain_span(self, sentence): | |||||
spans = [] | |||||
start = 0 | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
for idx in range(len(sentence)): | |||||
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): | |||||
if not matching_flag: | |||||
start = idx | |||||
if re.match('[0-9]', sentence[idx]): | |||||
number_flag = True | |||||
elif re.match('[\'′&\\-]', sentence[idx]): | |||||
link_flag = True | |||||
elif re.match('/', sentence[idx]): | |||||
slash_flag = True | |||||
elif re.match('[\\(\\)]', sentence[idx]): | |||||
bracket_flag = True | |||||
else: | |||||
alpha_flag = True | |||||
matching_flag = True | |||||
elif re.match('[\\.]', sentence[idx]): | |||||
pass | |||||
else: | |||||
if matching_flag: | |||||
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ | |||||
or (slash_flag and alpha_flag) or (link_flag and number_flag) \ | |||||
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): | |||||
spans.append((start, idx)) | |||||
start = idx | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
return spans | |||||
class EmailConverter(SpanConverter): | |||||
def __init__(self): | |||||
replaced_tag = "<EML>" | |||||
pattern = '[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\\-"$])' | |||||
super(EmailConverter, self).__init__(replaced_tag, pattern) |
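# Illustrative usage (the sentence is made up):
#   converter = AlphaSpanConverter()
#   converter.find_certain_span_and_replace('我用iPhone 打电话。')
#   # -> '我用<ALPHA> 打电话。'  (the alphabetic span is replaced by the <ALPHA> tag)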
@@ -3,17 +3,15 @@ import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||||
from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
from fastNLP.loader.dataset_loader import BaseLoader, TokenizeDataSetLoader | |||||
from fastNLP.core.preprocess import load_pickle | |||||
from fastNLP.saver.model_saver import ModelSaver | |||||
from fastNLP.loader.model_loader import ModelLoader | |||||
from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader | |||||
from fastNLP.core.utils import load_pickle | |||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | from fastNLP.models.sequence_modeling import AdvSeqLabel | ||||
from fastNLP.core.predictor import SeqLabelInfer | from fastNLP.core.predictor import SeqLabelInfer | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.preprocess import save_pickle | |||||
from fastNLP.core.utils import save_pickle | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | from fastNLP.core.metrics import SeqLabelEvaluator | ||||
# not in the file's dir | # not in the file's dir | ||||
@@ -0,0 +1,151 @@ | |||||
import torch | |||||
def seq_lens_to_mask(seq_lens): | |||||
batch_size = seq_lens.size(0) | |||||
max_len = seq_lens.max() | |||||
indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device) | |||||
masks = indexes.lt(seq_lens.unsqueeze(1)) | |||||
return masks | |||||
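# Illustrative example: seq_lens_to_mask(torch.LongTensor([3, 1])) returns the mask
#   [[1, 1, 1],
#    [1, 0, 0]]
# (a ByteTensor or BoolTensor depending on the torch version); padded positions are 0.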
from itertools import chain | |||||
def refine_ys_on_seq_len(ys, seq_lens): | |||||
refined_ys = [] | |||||
for b_idx, length in enumerate(seq_lens): | |||||
refined_ys.append(list(ys[b_idx][:length])) | |||||
return refined_ys | |||||
def flat_nested_list(nested_list): | |||||
return list(chain(*nested_list)) | |||||
def calculate_pre_rec_f1(model, batcher, type='segapp'): | |||||
true_ys, pred_ys = decode_iterator(model, batcher) | |||||
true_ys = flat_nested_list(true_ys) | |||||
pred_ys = flat_nested_list(pred_ys) | |||||
cor_num = 0 | |||||
start = 0 | |||||
if type=='segapp': | |||||
yp_wordnum = pred_ys.count(1) | |||||
yt_wordnum = true_ys.count(1) | |||||
if true_ys[0]==1 and pred_ys[0]==1: | |||||
cor_num += 1 | |||||
start = 1 | |||||
for i in range(1, len(true_ys)): | |||||
if true_ys[i] == 1: | |||||
flag = True | |||||
if true_ys[start-1] != pred_ys[start-1]: | |||||
flag = False | |||||
else: | |||||
for j in range(start, i + 1): | |||||
if true_ys[j] != pred_ys[j]: | |||||
flag = False | |||||
break | |||||
if flag: | |||||
cor_num += 1 | |||||
start = i + 1 | |||||
elif type=='bmes': | |||||
yp_wordnum = pred_ys.count(2) + pred_ys.count(3) | |||||
yt_wordnum = true_ys.count(2) + true_ys.count(3) | |||||
for i in range(len(true_ys)): | |||||
if true_ys[i] == 2 or true_ys[i] == 3: | |||||
flag = True | |||||
for j in range(start, i + 1): | |||||
if true_ys[j] != pred_ys[j]: | |||||
flag = False | |||||
break | |||||
if flag: | |||||
cor_num += 1 | |||||
start = i + 1 | |||||
P = cor_num / (float(yp_wordnum) + 1e-6) | |||||
R = cor_num / (float(yt_wordnum) + 1e-6) | |||||
F = 2 * P * R / (P + R + 1e-6) | |||||
# print(cor_num, yt_wordnum, yp_wordnum) | |||||
return P, R, F | |||||
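# Worked example (illustration only; BMES tags encoded as B=0, M=1, E=2, S=3):
#   gold segmentation '中国 人'       -> true_ys = [0, 2, 3]
#   predicted segmentation '中 国 人' -> pred_ys = [3, 3, 3]
# The gold side has 2 words, the prediction has 3, and only '人' is recovered exactly,
# so P = 1/3, R = 1/2 and F = 2*P*R/(P+R) = 0.4 (ignoring the 1e-6 smoothing terms).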
def decode_iterator(model, batcher): | |||||
true_ys = [] | |||||
pred_ys = [] | |||||
seq_lens = [] | |||||
with torch.no_grad(): | |||||
model.eval() | |||||
for batch_x, batch_y in batcher: | |||||
pred_dict = model.predict(**batch_x) | |||||
seq_len = batch_x['seq_lens'].cpu().numpy() | |||||
pred_y = pred_dict['pred_tags'] | |||||
true_y = batch_y['tags'] | |||||
pred_y = pred_y.cpu().numpy() | |||||
true_y = true_y.cpu().numpy() | |||||
true_ys.extend(true_y.tolist()) | |||||
pred_ys.extend(pred_y.tolist()) | |||||
seq_lens.extend(list(seq_len)) | |||||
model.train() | |||||
true_ys = refine_ys_on_seq_len(true_ys, seq_lens) | |||||
pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens) | |||||
return true_ys, pred_ys | |||||
from torch import nn | |||||
import torch.nn.functional as F | |||||
class FocalLoss(nn.Module): | |||||
r""" | |||||
        This criterion is an implementation of Focal Loss, which was proposed in
        Focal Loss for Dense Object Detection.
            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
        The losses are averaged across observations for each minibatch.
        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
                putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                However, if the field size_average is set to False, the losses are
                instead summed for each minibatch.
""" | |||||
def __init__(self, class_num, gamma=2, size_average=True, reduce=False): | |||||
super(FocalLoss, self).__init__() | |||||
self.gamma = gamma | |||||
self.class_num = class_num | |||||
self.size_average = size_average | |||||
self.reduce = reduce | |||||
def forward(self, inputs, targets): | |||||
N = inputs.size(0) | |||||
C = inputs.size(1) | |||||
P = F.softmax(inputs, dim=-1) | |||||
class_mask = inputs.data.new(N, C).fill_(0) | |||||
class_mask.requires_grad = True | |||||
ids = targets.view(-1, 1) | |||||
class_mask = class_mask.scatter(1, ids.data, 1.) | |||||
probs = (P * class_mask).sum(1).view(-1, 1) | |||||
log_p = probs.log() | |||||
batch_loss = - (torch.pow((1 - probs), self.gamma)) * log_p | |||||
if self.reduce: | |||||
if self.size_average: | |||||
loss = batch_loss.mean() | |||||
else: | |||||
loss = batch_loss.sum() | |||||
return loss | |||||
return batch_loss |
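# Minimal usage sketch (shapes and values are made up for illustration):
#   criterion = FocalLoss(class_num=4, reduce=True)
#   logits = torch.randn(8, 4)            # (batch_size, num_classes) raw scores
#   targets = torch.randint(0, 4, (8,))   # gold class indices
#   loss = criterion(logits, targets)     # scalar, averaged over the batch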
@@ -0,0 +1,89 @@ | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
def cut_long_sentence(sent, max_sample_length=200): | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
class ConlluPOSReader(object): | |||||
    # The returned DataSet contains two fields: words (a list of lists whose inner lists hold characters)
    # and tag (a list of str, each str being a BMES-style tag).
def __init__(self): | |||||
pass | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
char_seq = [] | |||||
pos_seq = [] | |||||
for word, tag in zip(res[0], res[1]): | |||||
if len(word)==1: | |||||
char_seq.append(word) | |||||
pos_seq.append('S-{}'.format(tag)) | |||||
elif len(word)>1: | |||||
pos_seq.append('B-{}'.format(tag)) | |||||
for _ in range(len(word)-2): | |||||
pos_seq.append('M-{}'.format(tag)) | |||||
pos_seq.append('E-{}'.format(tag)) | |||||
char_seq.extend(list(word)) | |||||
else: | |||||
raise ValueError("Zero length of word detected.") | |||||
ds.append(Instance(words=char_seq, | |||||
tag=pos_seq)) | |||||
return ds | |||||
def get_one(self, sample): | |||||
if len(sample)==0: | |||||
return None | |||||
text = [] | |||||
pos_tags = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
pos_tags.append(t2) | |||||
return text, pos_tags | |||||
if __name__ == '__main__': | |||||
reader = ConlluPOSReader() | |||||
d = reader.load('/home/hyan/train.conllx') | |||||
print('reader') |
@@ -1,14 +1,18 @@ | |||||
[train] | [train] | ||||
epochs = 30 | |||||
batch_size = 64 | |||||
epochs = 6 | |||||
batch_size = 32 | |||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
validate = true | validate = true | ||||
save_best_dev = true | save_best_dev = true | ||||
model_saved_path = "./save/" | model_saved_path = "./save/" | ||||
rnn_hidden_units = 100 | |||||
word_emb_dim = 100 | |||||
valid_step = 250 | |||||
eval_sort_key = 'accuracy' | |||||
[model] | |||||
rnn_hidden_units = 300 | |||||
word_emb_dim = 300 | |||||
dropout = 0.5 | |||||
use_crf = true | use_crf = true | ||||
use_cuda = true | |||||
print_every_step = 10 | print_every_step = 10 | ||||
[test] | [test] | ||||
@@ -0,0 +1,131 @@ | |||||
from collections import Counter | |||||
from fastNLP.api.processor import Processor | |||||
from fastNLP.core.dataset import DataSet | |||||
class CombineWordAndPosProcessor(Processor): | |||||
def __init__(self, word_field_name, pos_field_name): | |||||
super(CombineWordAndPosProcessor, self).__init__(None, None) | |||||
self.word_field_name = word_field_name | |||||
self.pos_field_name = pos_field_name | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
chars = ins[self.word_field_name] | |||||
bmes_pos = ins[self.pos_field_name] | |||||
word_list = [] | |||||
pos_list = [] | |||||
pos_stack_cnt = Counter() | |||||
char_stack = [] | |||||
for char, p in zip(chars, bmes_pos): | |||||
parts = p.split('-') | |||||
pre = parts[0] | |||||
post = parts[1] | |||||
if pre.lower() == 's': | |||||
if len(pos_stack_cnt) != 0: | |||||
pos = pos_stack_cnt.most_common(1)[0][0] | |||||
pos_list.append(pos) | |||||
word_list.append(''.join(char_stack)) | |||||
pos_list.append(post) | |||||
word_list.append(char) | |||||
char_stack.clear() | |||||
pos_stack_cnt.clear() | |||||
elif pre.lower() == 'e': | |||||
pos_stack_cnt.update([post]) | |||||
char_stack.append(char) | |||||
pos = pos_stack_cnt.most_common(1)[0][0] | |||||
pos_list.append(pos) | |||||
word_list.append(''.join(char_stack)) | |||||
char_stack.clear() | |||||
pos_stack_cnt.clear() | |||||
elif pre.lower() == 'b': | |||||
if len(pos_stack_cnt) != 0: | |||||
pos = pos_stack_cnt.most_common(1)[0][0] | |||||
pos_list.append(pos) | |||||
word_list.append(''.join(char_stack)) | |||||
char_stack.clear() | |||||
pos_stack_cnt.clear() | |||||
char_stack.append(char) | |||||
pos_stack_cnt.update([post]) | |||||
else: | |||||
char_stack.append(char) | |||||
pos_stack_cnt.update([post]) | |||||
ins['word_list'] = word_list | |||||
ins['pos_list'] = pos_list | |||||
return dataset | |||||
class PosOutputStrProcessor(Processor): | |||||
def __init__(self, word_field_name, pos_field_name): | |||||
super(PosOutputStrProcessor, self).__init__(None, None) | |||||
self.word_field_name = word_field_name | |||||
self.pos_field_name = pos_field_name | |||||
self.sep = '_' | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
word_list = ins[self.word_field_name] | |||||
pos_list = ins[self.pos_field_name] | |||||
word_pos_list = [] | |||||
for word, pos in zip(word_list, pos_list): | |||||
word_pos_list.append(word + self.sep + pos) | |||||
            # TODO: the output format should be customizable
ins['word_pos_output'] = ' '.join(word_pos_list) | |||||
return dataset | |||||
if __name__ == '__main__': | |||||
chars = ['迈', '向', '充', '满', '希', '望', '的', '新', '世', '纪', '—', '—', '一', '九', '九', '八', '年', '新', '年', '讲', '话', '(', '附', '图', '片', '1', '张', ')'] | |||||
bmes_pos = ['B-v', 'E-v', 'B-v', 'E-v', 'B-n', 'E-n', 'S-u', 'S-a', 'B-n', 'E-n', 'B-w', 'E-w', 'B-t', 'M-t', 'M-t', 'M-t', 'E-t', 'B-t', 'E-t', 'B-n', 'E-n', 'S-w', 'S-v', 'B-n', 'E-n', 'S-m', 'S-q', 'S-w'] | |||||
word_list = [] | |||||
pos_list = [] | |||||
pos_stack_cnt = Counter() | |||||
char_stack = [] | |||||
for char, p in zip(''.join(chars), bmes_pos): | |||||
parts = p.split('-') | |||||
pre = parts[0] | |||||
post = parts[1] | |||||
if pre.lower() == 's': | |||||
if len(pos_stack_cnt) != 0: | |||||
pos = pos_stack_cnt.most_common(1)[0][0] | |||||
pos_list.append(pos) | |||||
word_list.append(''.join(char_stack)) | |||||
pos_list.append(post) | |||||
word_list.append(char) | |||||
char_stack.clear() | |||||
pos_stack_cnt.clear() | |||||
elif pre.lower() == 'e': | |||||
pos_stack_cnt.update([post]) | |||||
char_stack.append(char) | |||||
pos = pos_stack_cnt.most_common(1)[0][0] | |||||
pos_list.append(pos) | |||||
word_list.append(''.join(char_stack)) | |||||
char_stack.clear() | |||||
pos_stack_cnt.clear() | |||||
elif pre.lower() == 'b': | |||||
if len(pos_stack_cnt) != 0: | |||||
pos = pos_stack_cnt.most_common(1)[0][0] | |||||
pos_list.append(pos) | |||||
word_list.append(''.join(char_stack)) | |||||
char_stack.clear() | |||||
pos_stack_cnt.clear() | |||||
char_stack.append(char) | |||||
pos_stack_cnt.update([post]) | |||||
else: | |||||
char_stack.append(char) | |||||
pos_stack_cnt.update([post]) | |||||
print(word_list) | |||||
print(pos_list) |
@@ -1,146 +0,0 @@ | |||||
import os | |||||
import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.core.trainer import SeqLabelTrainer | |||||
from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader, BaseLoader | |||||
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | |||||
from fastNLP.saver.model_saver import ModelSaver | |||||
from fastNLP.loader.model_loader import ModelLoader | |||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | |||||
from fastNLP.core.predictor import SeqLabelInfer | |||||
# not in the file's dir | |||||
if len(os.path.dirname(__file__)) != 0: | |||||
os.chdir(os.path.dirname(__file__)) | |||||
datadir = "/home/zyfeng/data/" | |||||
cfgfile = './pos_tag.cfg' | |||||
data_name = "CWS_POS_TAG_NER_people_daily.txt" | |||||
pos_tag_data_path = os.path.join(datadir, data_name) | |||||
pickle_path = "save" | |||||
data_infer_path = os.path.join(datadir, "infer.utf8") | |||||
def infer(): | |||||
# Config Loader | |||||
test_args = ConfigSection() | |||||
ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args}) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
test_args["vocab_size"] = len(word2index) | |||||
index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
test_args["num_classes"] = len(index2label) | |||||
# Define the same model | |||||
model = AdvSeqLabel(test_args) | |||||
try: | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print('model loaded!') | |||||
except Exception as e: | |||||
print('cannot load model!') | |||||
raise | |||||
# Data Loader | |||||
raw_data_loader = BaseLoader(data_infer_path) | |||||
infer_data = raw_data_loader.load_lines() | |||||
print('data loaded') | |||||
# Inference interface | |||||
infer = SeqLabelInfer(pickle_path) | |||||
results = infer.predict(model, infer_data) | |||||
print(results) | |||||
print("Inference finished!") | |||||
def train(): | |||||
# Config Loader | |||||
train_args = ConfigSection() | |||||
test_args = ConfigSection() | |||||
ConfigLoader("good_name").load_config(cfgfile, {"train": train_args, "test": test_args}) | |||||
# Data Loader | |||||
loader = PeopleDailyCorpusLoader() | |||||
train_data, _ = loader.load() | |||||
# Preprocessor | |||||
preprocessor = SeqLabelPreprocess() | |||||
data_train, data_dev = preprocessor.run(train_data, pickle_path=pickle_path, train_dev_split=0.3) | |||||
train_args["vocab_size"] = preprocessor.vocab_size | |||||
train_args["num_classes"] = preprocessor.num_classes | |||||
# Trainer | |||||
trainer = SeqLabelTrainer(**train_args.data) | |||||
# Model | |||||
model = AdvSeqLabel(train_args) | |||||
try: | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print('model parameter loaded!') | |||||
except Exception as e: | |||||
print("No saved model. Continue.") | |||||
pass | |||||
# Start training | |||||
trainer.train(model, data_train, data_dev) | |||||
print("Training finished!") | |||||
# Saver | |||||
saver = ModelSaver("./save/saved_model.pkl") | |||||
saver.save_pytorch(model) | |||||
print("Model saved!") | |||||
def test(): | |||||
# Config Loader | |||||
test_args = ConfigSection() | |||||
ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args}) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
test_args["vocab_size"] = len(word2index) | |||||
index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
test_args["num_classes"] = len(index2label) | |||||
# load dev data | |||||
dev_data = load_pickle(pickle_path, "data_dev.pkl") | |||||
# Define the same model | |||||
model = AdvSeqLabel(test_args) | |||||
# Dump trained parameters into the model | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print("model loaded!") | |||||
# Tester | |||||
tester = SeqLabelTester(**test_args.data) | |||||
# Start testing | |||||
tester.test(model, dev_data) | |||||
# print test results | |||||
print(tester.show_metrics()) | |||||
print("model tested!") | |||||
if __name__ == "__main__": | |||||
import argparse | |||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | |||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | |||||
args = parser.parse_args() | |||||
if args.mode == 'train': | |||||
train() | |||||
elif args.mode == 'test': | |||||
test() | |||||
elif args.mode == 'infer': | |||||
infer() | |||||
else: | |||||
print('no mode specified for model!') | |||||
parser.print_help() |
@@ -1,4 +1,4 @@ | |||||
numpy>=1.14.2 | numpy>=1.14.2 | ||||
torch==0.4.0 | |||||
torchvision>=0.1.8 | |||||
torch>=0.4.0 | |||||
tensorboardX | tensorboardX | ||||
tqdm>=4.28.1 |
@@ -12,12 +12,12 @@ with open('requirements.txt', encoding='utf-8') as f: | |||||
reqs = f.read() | reqs = f.read() | ||||
setup( | setup( | ||||
name='fastNLP', | |||||
version='0.1.0', | |||||
name='FastNLP', | |||||
version='0.1.1', | |||||
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | ||||
long_description=readme, | long_description=readme, | ||||
license=license, | license=license, | ||||
author='fudanNLP', | |||||
author='FudanNLP', | |||||
python_requires='>=3.5', | python_requires='>=3.5', | ||||
packages=find_packages(), | packages=find_packages(), | ||||
install_requires=reqs.strip().split('\n'), | install_requires=reqs.strip().split('\n'), | ||||
@@ -0,0 +1,12 @@ | |||||
import unittest | |||||
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor | |||||
from fastNLP.core.dataset import DataSet | |||||
class TestProcessor(unittest.TestCase): | |||||
def test_FullSpaceToHalfSpaceProcessor(self): | |||||
ds = DataSet({"word": ["00, u1, u), (u2, u2"]}) | |||||
proc = FullSpaceToHalfSpaceProcessor("word") | |||||
ds = proc(ds) | |||||
self.assertTrue(ds.field_arrays["word"].content, ["00, u1, u), (u2, u2"]) |
@@ -1,53 +1,33 @@ | |||||
import unittest | import unittest | ||||
import torch | |||||
import numpy as np | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.instance import Instance | |||||
raw_texts = ["i am a cat", | |||||
"this is a test of new batch", | |||||
"ha ha", | |||||
"I am a good boy .", | |||||
"This is the most beautiful girl ." | |||||
] | |||||
texts = [text.strip().split() for text in raw_texts] | |||||
labels = [0, 1, 0, 0, 1] | |||||
# prepare vocabulary | |||||
vocab = {} | |||||
for text in texts: | |||||
for tokens in text: | |||||
if tokens not in vocab: | |||||
vocab[tokens] = len(vocab) | |||||
from fastNLP.core.dataset import construct_dataset | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
class TestCase1(unittest.TestCase): | class TestCase1(unittest.TestCase): | ||||
def test(self): | |||||
data = DataSet() | |||||
for text, label in zip(texts, labels): | |||||
x = TextField(text, is_target=False) | |||||
y = LabelField(label, is_target=True) | |||||
ins = Instance(text=x, label=y) | |||||
data.append(ins) | |||||
# use vocabulary to index data | |||||
data.index_field("text", vocab) | |||||
# define naive sampler for batch class | |||||
class SeqSampler: | |||||
def __call__(self, dataset): | |||||
return list(range(len(dataset))) | |||||
# use batch to iterate dataset | |||||
data_iterator = Batch(data, 2, SeqSampler(), False) | |||||
total_data = 0 | |||||
for batch_x, batch_y in data_iterator: | |||||
total_data += batch_x["text"].size(0) | |||||
self.assertTrue(batch_x["text"].size(0) == 2 or total_data == len(raw_texts)) | |||||
self.assertTrue(isinstance(batch_x, dict)) | |||||
self.assertTrue(isinstance(batch_x["text"], torch.LongTensor)) | |||||
self.assertTrue(isinstance(batch_y, dict)) | |||||
self.assertTrue(isinstance(batch_y["label"], torch.LongTensor)) | |||||
def test_simple(self): | |||||
dataset = construct_dataset( | |||||
[["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | |||||
dataset.set_target() | |||||
batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
cnt = 0 | |||||
for _, _ in batch: | |||||
cnt += 1 | |||||
self.assertEqual(cnt, 10) | |||||
def test_dataset_batching(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
for x, y in iter: | |||||
self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray)) | |||||
self.assertEqual(len(x["x"]), 4) | |||||
self.assertEqual(len(y["y"]), 4) | |||||
self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4]) | |||||
self.assertListEqual(list(y["y"][-1]), [5, 6]) |
@@ -1,54 +1,200 @@ | |||||
import os | |||||
import unittest | import unittest | ||||
from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq_dataset | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.fieldarray import FieldArray | |||||
from fastNLP.core.instance import Instance | |||||
class TestDataSet(unittest.TestCase): | class TestDataSet(unittest.TestCase): | ||||
labeled_data_list = [ | |||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
] | |||||
unlabeled_data_list = [ | |||||
["a", "b", "e", "d"], | |||||
["a", "b", "e", "d"], | |||||
["a", "b", "e", "d"] | |||||
] | |||||
word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} | |||||
label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} | |||||
def test_case_1(self): | |||||
data_set = convert_seq2seq_dataset(self.labeled_data_list) | |||||
data_set.index_field("word_seq", self.word_vocab) | |||||
data_set.index_field("label_seq", self.label_vocab) | |||||
self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
self.assertTrue(len(data_set) > 0) | |||||
self.assertTrue(hasattr(data_set[0], "fields")) | |||||
self.assertTrue("word_seq" in data_set[0].fields) | |||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
[self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||||
self.assertTrue("label_seq" in data_set[0].fields) | |||||
self.assertTrue(hasattr(data_set[0].fields["label_seq"], "text")) | |||||
self.assertTrue(hasattr(data_set[0].fields["label_seq"], "_index")) | |||||
self.assertEqual(data_set[0].fields["label_seq"].text, self.labeled_data_list[0][1]) | |||||
self.assertEqual(data_set[0].fields["label_seq"]._index, | |||||
[self.label_vocab[c] for c in self.labeled_data_list[0][1]]) | |||||
def test_case_2(self): | |||||
data_set = convert_seq_dataset(self.unlabeled_data_list) | |||||
data_set.index_field("word_seq", self.word_vocab) | |||||
self.assertEqual(len(data_set), len(self.unlabeled_data_list)) | |||||
self.assertTrue(len(data_set) > 0) | |||||
self.assertTrue(hasattr(data_set[0], "fields")) | |||||
self.assertTrue("word_seq" in data_set[0].fields) | |||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
self.assertEqual(data_set[0].fields["word_seq"].text, self.unlabeled_data_list[0]) | |||||
self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
[self.word_vocab[c] for c in self.unlabeled_data_list[0]]) | |||||
def test_init_v1(self): | |||||
ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) | |||||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||||
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||||
def test_init_v2(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||||
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||||
def test_init_assert(self): | |||||
with self.assertRaises(AssertionError): | |||||
_ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100}) | |||||
with self.assertRaises(AssertionError): | |||||
_ = DataSet([[1, 2, 3, 4]] * 10) | |||||
with self.assertRaises(ValueError): | |||||
_ = DataSet(0.00001) | |||||
def test_append(self): | |||||
dd = DataSet() | |||||
for _ in range(3): | |||||
dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6])) | |||||
self.assertEqual(len(dd), 3) | |||||
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) | |||||
self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) | |||||
def test_add_append(self): | |||||
dd = DataSet() | |||||
dd.add_field("x", [[1, 2, 3]] * 10) | |||||
dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||||
dd.add_field("z", [[5, 6]] * 10) | |||||
self.assertEqual(len(dd), 10) | |||||
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10) | |||||
self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) | |||||
self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) | |||||
with self.assertRaises(RuntimeError): | |||||
dd.add_field("??", [[1, 2]] * 40) | |||||
def test_delete_field(self): | |||||
dd = DataSet() | |||||
dd.add_field("x", [[1, 2, 3]] * 10) | |||||
dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||||
dd.delete_field("x") | |||||
self.assertFalse("x" in dd.field_arrays) | |||||
self.assertTrue("y" in dd.field_arrays) | |||||
def test_getitem(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
ins_1, ins_0 = ds[0], ds[1] | |||||
self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance)) | |||||
self.assertEqual(ins_1["x"], [1, 2, 3, 4]) | |||||
self.assertEqual(ins_1["y"], [5, 6]) | |||||
self.assertEqual(ins_0["x"], [1, 2, 3, 4]) | |||||
self.assertEqual(ins_0["y"], [5, 6]) | |||||
sub_ds = ds[:10] | |||||
self.assertTrue(isinstance(sub_ds, DataSet)) | |||||
self.assertEqual(len(sub_ds), 10) | |||||
def test_get_item_error(self): | |||||
with self.assertRaises(RuntimeError): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
_ = ds[40:] | |||||
with self.assertRaises(KeyError): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
_ = ds["kom"] | |||||
def test_len_(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
self.assertEqual(len(ds), 40) | |||||
ds = DataSet() | |||||
self.assertEqual(len(ds), 0) | |||||
def test_apply(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx") | |||||
self.assertTrue("rx" in ds.field_arrays) | |||||
self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) | |||||
ds.apply(lambda ins: len(ins["y"]), new_field_name="y") | |||||
self.assertEqual(ds.field_arrays["y"].content[0], 2) | |||||
res = ds.apply(lambda ins: len(ins["x"])) | |||||
self.assertTrue(isinstance(res, list) and len(res) > 0) | |||||
self.assertEqual(res[0], 4)
def test_drop(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | |||||
ds.drop(lambda ins: len(ins["y"]) < 3) | |||||
self.assertEqual(len(ds), 20) | |||||
def test_contains(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
self.assertTrue("x" in ds) | |||||
self.assertTrue("y" in ds) | |||||
self.assertFalse("z" in ds) | |||||
def test_rename_field(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
ds.rename_field("x", "xx") | |||||
self.assertTrue("xx" in ds) | |||||
self.assertFalse("x" in ds) | |||||
with self.assertRaises(KeyError): | |||||
ds.rename_field("yyy", "oo") | |||||
def test_input_target(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
self.assertTrue(ds.field_arrays["x"].is_input) | |||||
self.assertTrue(ds.field_arrays["y"].is_target) | |||||
with self.assertRaises(KeyError): | |||||
ds.set_input("xxx") | |||||
with self.assertRaises(KeyError): | |||||
ds.set_input("yyy") | |||||
def test_get_input_name(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
self.assertEqual(ds.get_input_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_input]) | |||||
def test_get_target_name(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target]) | |||||
def test_apply2(self): | |||||
def split_sent(ins): | |||||
return ins['raw_sentence'].split() | |||||
dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), | |||||
sep='\t') | |||||
dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0) | |||||
dataset.apply(split_sent, new_field_name='words', is_input=True) | |||||
# print(dataset) | |||||
def test_add_field(self): | |||||
ds = DataSet({"x": [3, 4]}) | |||||
ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True) | |||||
# ds.apply(lambda x:[x['x']]*3, is_input=True, is_target=True, new_field_name='y') | |||||
print(ds) | |||||
def test_save_load(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
ds.save("./my_ds.pkl") | |||||
self.assertTrue(os.path.exists("./my_ds.pkl")) | |||||
ds_1 = DataSet.load("./my_ds.pkl")
self.assertTrue(isinstance(ds_1, DataSet))
os.remove("my_ds.pkl")
def test_get_all_fields(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
ans = ds.get_all_fields() | |||||
self.assertEqual(ans["x"].content, [[1, 2, 3, 4]] * 10) | |||||
self.assertEqual(ans["y"].content, [[5, 6]] * 10) | |||||
def test_get_field(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
ans = ds.get_field("x") | |||||
self.assertTrue(isinstance(ans, FieldArray)) | |||||
self.assertEqual(ans.content, [[1, 2, 3, 4]] * 10) | |||||
ans = ds.get_field("y") | |||||
self.assertTrue(isinstance(ans, FieldArray)) | |||||
self.assertEqual(ans.content, [[5, 6]] * 10) | |||||
def test_reader(self): | |||||
# just check that the readers run without raising
ds = DataSet().read_naive("test/data_for_tests/tutorial_sample_dataset.csv") | |||||
self.assertTrue(isinstance(ds, DataSet)) | |||||
self.assertTrue(len(ds) > 0) | |||||
ds = DataSet().read_rawdata("test/data_for_tests/people_daily_raw.txt") | |||||
self.assertTrue(isinstance(ds, DataSet)) | |||||
self.assertTrue(len(ds) > 0) | |||||
ds = DataSet().read_pos("test/data_for_tests/people.txt") | |||||
self.assertTrue(isinstance(ds, DataSet)) | |||||
self.assertTrue(len(ds) > 0) | |||||
class TestDataSetIter(unittest.TestCase): | |||||
def test__repr__(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
for ins in ds:
self.assertEqual(ins.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}")
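The tests above double as a compact reference for the DataSet API. Below is a minimal usage sketch assuming only the calls exercised in these tests (construction from a dict or from Instances, apply, drop, set_input/set_target); the field names are illustrative.

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

# build a DataSet from parallel lists, or grow it instance by instance
ds = DataSet({"words": [["this", "is", "fine"], ["short"]], "label": [1, 0]})
ds.append(Instance(words=["one", "more"], label=1))

# derive a new field from an existing one, then mark fields for the model
ds.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
ds.set_input("words", "seq_len")
ds.set_target("label")

# drop instances that are too short
ds.drop(lambda ins: len(ins["words"]) < 2)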
@@ -0,0 +1,99 @@ | |||||
import unittest | |||||
import numpy as np | |||||
from fastNLP.core.fieldarray import FieldArray | |||||
class TestFieldArray(unittest.TestCase): | |||||
def test(self): | |||||
fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | |||||
self.assertEqual(len(fa), 5) | |||||
fa.append(6) | |||||
self.assertEqual(len(fa), 6) | |||||
self.assertEqual(fa[-1], 6) | |||||
self.assertEqual(fa[0], 1) | |||||
fa[-1] = 60 | |||||
self.assertEqual(fa[-1], 60) | |||||
self.assertEqual(fa.get(0), 1) | |||||
self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray)) | |||||
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) | |||||
def test_type_conversion(self): | |||||
fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True) | |||||
self.assertEqual(fa.pytype, float) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | |||||
fa.append(1.3333) | |||||
self.assertEqual(fa.pytype, float) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True) | |||||
fa.append(10) | |||||
self.assertEqual(fa.pytype, float) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
fa = FieldArray("y", ["a", "b", "c", "d"], is_input=True) | |||||
fa.append("e") | |||||
self.assertEqual(fa.dtype, np.str) | |||||
self.assertEqual(fa.pytype, str) | |||||
def test_support_np_array(self): | |||||
fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=True) | |||||
self.assertEqual(fa.dtype, np.ndarray) | |||||
self.assertEqual(fa.pytype, np.ndarray) | |||||
fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) | |||||
self.assertEqual(fa.dtype, np.ndarray) | |||||
self.assertEqual(fa.pytype, np.ndarray) | |||||
fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True) | |||||
# in this case, pytype is actually a float. We do not care about it. | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
def test_nested_list(self): | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=True) | |||||
self.assertEqual(fa.pytype, float) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
def test_getitem_v1(self): | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||||
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) | |||||
ans = fa[[0, 1]] | |||||
self.assertTrue(isinstance(ans, np.ndarray)) | |||||
self.assertTrue(isinstance(ans[0], np.ndarray)) | |||||
self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5]) | |||||
self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5]) | |||||
self.assertEqual(ans.dtype, np.float64) | |||||
def test_getitem_v2(self): | |||||
x = np.random.rand(10, 5) | |||||
fa = FieldArray("my_field", x, is_input=True) | |||||
indices = [0, 1, 3, 4, 6] | |||||
for a, b in zip(fa[indices], x[indices]): | |||||
self.assertListEqual(a.tolist(), b.tolist()) | |||||
def test_append(self): | |||||
with self.assertRaises(Exception): | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||||
fa.append(0) | |||||
with self.assertRaises(Exception): | |||||
fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True) | |||||
fa.append([1, 2, 3, 4, 5]) | |||||
with self.assertRaises(Exception): | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||||
fa.append([]) | |||||
with self.assertRaises(Exception): | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||||
fa.append(["str", 0, 0, 0, 1.89]) | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||||
fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) | |||||
self.assertEqual(len(fa), 3) | |||||
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) |
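A short sketch of the FieldArray behaviour these tests rely on: the element type is inferred from the content (and re-checked on append), a single index returns the stored Python value, and a list of indices returns a numpy array. The field name below is illustrative.

import numpy as np
from fastNLP.core.fieldarray import FieldArray

fa = FieldArray("score", [1.0, 2.0, 3.0], is_input=True)
assert fa.pytype is float and fa.dtype == np.float64

fa.append(4.0)            # appended values must keep the element type consistent
assert fa[3] == 4.0       # an int index returns the raw value
batch = fa.get([0, 2])    # a list of indices returns a numpy array
assert isinstance(batch, np.ndarray)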
@@ -0,0 +1,35 @@ | |||||
import unittest | |||||
from fastNLP.core.instance import Instance | |||||
class TestCase(unittest.TestCase): | |||||
def test_init(self): | |||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||||
ins = Instance(x=[1, 2, 3], y=[4, 5, 6]) | |||||
self.assertTrue(isinstance(ins.fields, dict)) | |||||
self.assertEqual(ins.fields, fields) | |||||
ins = Instance(**fields) | |||||
self.assertEqual(ins.fields, fields) | |||||
def test_add_field(self): | |||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||||
ins = Instance(**fields) | |||||
ins.add_field("z", [1, 1, 1]) | |||||
fields.update({"z": [1, 1, 1]}) | |||||
self.assertEqual(ins.fields, fields) | |||||
def test_get_item(self): | |||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||||
ins = Instance(**fields) | |||||
self.assertEqual(ins["x"], [1, 2, 3]) | |||||
self.assertEqual(ins["y"], [4, 5, 6]) | |||||
self.assertEqual(ins["z"], [1, 1, 1]) | |||||
def test_repr(self): | |||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||||
ins = Instance(**fields) | |||||
# simple print, that is enough. | |||||
print(ins) |
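For reference, Instance as exercised above is a thin wrapper around a field dict; a minimal sketch with illustrative field names:

from fastNLP.core.instance import Instance

ins = Instance(words=["hello", "world"], label=1)
ins.add_field("seq_len", 2)
assert ins["label"] == 1
assert set(ins.fields.keys()) == {"words", "label", "seq_len"}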
@@ -0,0 +1,87 @@ | |||||
import unittest | |||||
import torch | |||||
import torch.nn.functional as F | |||||
import fastNLP.core.losses as loss | |||||
from fastNLP.core.losses import squash, unpad | |||||
class TestLoss(unittest.TestCase): | |||||
def test_CrossEntropyLoss(self): | |||||
ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth") | |||||
a = torch.randn(3, 5, requires_grad=False) | |||||
b = torch.empty(3, dtype=torch.long).random_(5) | |||||
ans = ce({"my_predict": a}, {"my_truth": b}) | |||||
self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) | |||||
def test_BCELoss(self): | |||||
bce = loss.BCELoss(pred="my_predict", target="my_truth") | |||||
a = torch.sigmoid(torch.randn((3, 5), requires_grad=False)) | |||||
b = torch.randn((3, 5), requires_grad=False) | |||||
ans = bce({"my_predict": a}, {"my_truth": b}) | |||||
self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b)) | |||||
def test_L1Loss(self): | |||||
l1 = loss.L1Loss(pred="my_predict", target="my_truth") | |||||
a = torch.randn(3, 5, requires_grad=False) | |||||
b = torch.randn(3, 5) | |||||
ans = l1({"my_predict": a}, {"my_truth": b}) | |||||
self.assertEqual(ans, torch.nn.functional.l1_loss(a, b)) | |||||
def test_NLLLoss(self): | |||||
l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | |||||
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | |||||
b = torch.tensor([1, 0, 4]) | |||||
ans = l1({"my_predict": a}, {"my_truth": b}) | |||||
self.assertEqual(ans, torch.nn.functional.nll_loss(a, b)) | |||||
class TestLosserError(unittest.TestCase): | |||||
def test_losser1(self): | |||||
# (1) only input, targets passed | |||||
pred_dict = {"pred": torch.zeros(4, 3)} | |||||
target_dict = {'target': torch.zeros(4).long()} | |||||
los = loss.CrossEntropyLoss() | |||||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | |||||
# | |||||
def test_losser2(self): | |||||
# (2) with corrupted size | |||||
pred_dict = {"pred": torch.zeros(16, 3)} | |||||
target_dict = {'target': torch.zeros(16, 3).long()} | |||||
los = loss.CrossEntropyLoss() | |||||
with self.assertRaises(RuntimeError): | |||||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | |||||
def test_losser3(self): | |||||
# (2) with corrupted size | |||||
pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param': 0} | |||||
target_dict = {'target': torch.zeros(16).long()} | |||||
los = loss.CrossEntropyLoss() | |||||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | |||||
def test_check_error(self): | |||||
l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | |||||
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | |||||
b = torch.tensor([1, 0, 4]) | |||||
with self.assertRaises(Exception): | |||||
ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b}) | |||||
with self.assertRaises(Exception): | |||||
ans = l1({"my_predict": a}, {"truth": b, "my": a}) | |||||
class TestLossUtils(unittest.TestCase): | |||||
def test_squash(self): | |||||
a, b = squash(torch.randn(3, 5), torch.randn(3, 5)) | |||||
self.assertEqual(tuple(a.size()), (3, 5)) | |||||
self.assertEqual(tuple(b.size()), (15,)) | |||||
def test_unpad(self): | |||||
a, b = unpad(torch.randn(5, 8, 3), torch.randn(5, 8)) | |||||
self.assertEqual(tuple(a.size()), (5, 8, 3)) | |||||
self.assertEqual(tuple(b.size()), (5, 8)) |
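All of the loss classes tested above share one calling convention: the constructor records which keys to read from the prediction dict and the target dict, and the instance is then called with those two dicts. A minimal sketch assuming exactly that interface (the key names are illustrative):

import torch
from fastNLP.core.losses import CrossEntropyLoss

# map the model output key "logits" and the dataset key "label" onto the loss
ce = CrossEntropyLoss(pred="logits", target="label")

pred_dict = {"logits": torch.randn(8, 5)}
target_dict = {"label": torch.empty(8, dtype=torch.long).random_(5)}

loss = ce(pred_dict, target_dict)   # equivalent to F.cross_entropy(logits, label)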
@@ -1,100 +1,145 @@ | |||||
import os | |||||
import sys | |||||
sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path | |||||
from fastNLP.core import metrics | |||||
# from sklearn import metrics as skmetrics | |||||
import unittest | import unittest | ||||
from numpy import random | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
import numpy as np | |||||
import torch | import torch | ||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP.core.metrics import pred_topk, accuracy_topk | |||||
class TestAccuracyMetric(unittest.TestCase): | |||||
def test_AccuracyMetric1(self): | |||||
# (1) only input, targets passed | |||||
pred_dict = {"pred": torch.zeros(4, 3)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric = AccuracyMetric() | |||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
print(metric.get_metric()) | |||||
def test_AccuracyMetric2(self): | |||||
# (2) with corrupted size | |||||
try: | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric = AccuracyMetric() | |||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
print(metric.get_metric()) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.fail("No exception was raised.")
def test_AccuracyMetric3(self): | |||||
# (3) the second batch is corrupted size | |||||
try: | |||||
metric = AccuracyMetric() | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
print(metric.get_metric()) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.fail("No exception was raised.")
def test_AccuaryMetric4(self): | |||||
# (5) check reset | |||||
metric = AccuracyMetric() | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3) + 1} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 0}) | |||||
def test_AccuaryMetric5(self): | |||||
# (5) check reset | |||||
metric = AccuracyMetric() | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1}) | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3) + 1} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 0.5}) | |||||
def test_AccuaryMetric6(self): | |||||
# (6) check numpy array is not acceptable | |||||
try: | |||||
metric = AccuracyMetric() | |||||
pred_dict = {"pred": np.zeros((4, 3, 2))} | |||||
target_dict = {'target': np.zeros((4, 3))} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.fail("No exception was raised.")
def test_AccuaryMetric7(self): | |||||
# (7) check map, match | |||||
metric = AccuracyMetric(pred='predictions', target='targets') | |||||
pred_dict = {"predictions": torch.zeros(4, 3, 2)} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
def test_AccuaryMetric8(self): | |||||
# (8) check map, does not match. use stop_fast_param to stop fast param map | |||||
try: | |||||
metric = AccuracyMetric(pred='predictions', target='targets') | |||||
pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.fail("No exception was raised.")
def test_AccuaryMetric9(self): | |||||
# (9) check map, include unused | |||||
try: | |||||
metric = AccuracyMetric(pred='prediction', target='targets') | |||||
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.fail("No exception was raised.")
def test_AccuaryMetric10(self): | |||||
# (10) check _fast_metric | |||||
try: | |||||
metric = AccuracyMetric() | |||||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "masks": torch.zeros(4, 3)} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.fail("No exception was raised.")
class TestUsefulFunctions(unittest.TestCase): | |||||
# test a few helper functions in metrics.py that look generally useful
def test_case_1(self): | |||||
# multi-class | |||||
_ = accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3) | |||||
_ = pred_topk(np.random.randint(0, 3, size=(10, 1))) | |||||
def generate_fake_label(low, high, size): | |||||
return random.randint(low, high, size), random.randint(low, high, size) | |||||
class TestEvaluator(unittest.TestCase): | |||||
def test_a(self): | |||||
evaluator = SeqLabelEvaluator() | |||||
pred = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]] | |||||
truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4])}] | |||||
ans = evaluator(pred, truth) | |||||
print(ans) | |||||
def test_b(self): | |||||
evaluator = SeqLabelEvaluator() | |||||
pred = [[1, 2, 3, 4, 5, 0, 0], [1, 2, 3, 4, 5, 0, 0]] | |||||
truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3, 0, 0])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4, 0, 0])}] | |||||
ans = evaluator(pred, truth) | |||||
print(ans) | |||||
class TestMetrics(unittest.TestCase): | |||||
delta = 1e-5 | |||||
# test for binary, multiclass, multilabel | |||||
data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)] | |||||
fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types] | |||||
def test_accuracy_score(self): | |||||
for y_true, y_pred in self.fake_data: | |||||
for normalize in [True, False]: | |||||
for sample_weight in [None, random.rand(y_true.shape[0])]: | |||||
test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | |||||
# ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | |||||
# self.assertAlmostEqual(test, ans, delta=self.delta) | |||||
def test_recall_score(self): | |||||
for y_true, y_pred in self.fake_data: | |||||
# print(y_true.shape) | |||||
labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||||
test = metrics.recall_score(y_true, y_pred, labels=labels, average=None) | |||||
if not isinstance(test, list): | |||||
test = list(test) | |||||
# ans = skmetrics.recall_score(y_true, y_pred,labels=labels, average=None) | |||||
# ans = list(ans) | |||||
# for a, b in zip(test, ans): | |||||
# # print('{}, {}'.format(a, b)) | |||||
# self.assertAlmostEqual(a, b, delta=self.delta) | |||||
# test binary | |||||
y_true, y_pred = generate_fake_label(0, 2, 1000) | |||||
test = metrics.recall_score(y_true, y_pred) | |||||
# ans = skmetrics.recall_score(y_true, y_pred) | |||||
# self.assertAlmostEqual(ans, test, delta=self.delta) | |||||
def test_precision_score(self): | |||||
for y_true, y_pred in self.fake_data: | |||||
# print(y_true.shape) | |||||
labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||||
test = metrics.precision_score(y_true, y_pred, labels=labels, average=None) | |||||
# ans = skmetrics.precision_score(y_true, y_pred,labels=labels, average=None) | |||||
# ans, test = list(ans), list(test) | |||||
# for a, b in zip(test, ans): | |||||
# # print('{}, {}'.format(a, b)) | |||||
# self.assertAlmostEqual(a, b, delta=self.delta) | |||||
# test binary | |||||
y_true, y_pred = generate_fake_label(0, 2, 1000) | |||||
test = metrics.precision_score(y_true, y_pred) | |||||
# ans = skmetrics.precision_score(y_true, y_pred) | |||||
# self.assertAlmostEqual(ans, test, delta=self.delta) | |||||
def test_f1_score(self): | |||||
for y_true, y_pred in self.fake_data: | |||||
# print(y_true.shape) | |||||
labels = list(range(y_true.shape[1])) if len(y_true.shape) >= 2 else None | |||||
test = metrics.f1_score(y_true, y_pred, labels=labels, average=None) | |||||
# ans = skmetrics.f1_score(y_true, y_pred,labels=labels, average=None) | |||||
# ans, test = list(ans), list(test) | |||||
# for a, b in zip(test, ans): | |||||
# # print('{}, {}'.format(a, b)) | |||||
# self.assertAlmostEqual(a, b, delta=self.delta) | |||||
# test binary | |||||
y_true, y_pred = generate_fake_label(0, 2, 1000) | |||||
test = metrics.f1_score(y_true, y_pred) | |||||
# ans = skmetrics.f1_score(y_true, y_pred) | |||||
# self.assertAlmostEqual(ans, test, delta=self.delta) | |||||
if __name__ == '__main__': | |||||
unittest.main() | |||||
# just check that everything runs
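AccuracyMetric follows the same dict-based convention as the losses: each call accumulates statistics for one batch, and get_metric() returns (and by default resets) the running accuracy. A minimal sketch assuming the interface exercised above, with illustrative key names:

import torch
from fastNLP.core.metrics import AccuracyMetric

metric = AccuracyMetric(pred="output", target="label")

# accumulate over two batches: all-zero logits predict class 0 everywhere
metric(pred_dict={"output": torch.zeros(4, 3)}, target_dict={"label": torch.zeros(4)})
metric(pred_dict={"output": torch.zeros(4, 3)}, target_dict={"label": torch.ones(4)})

print(metric.get_metric())   # {'acc': 0.5}; pass reset=False to keep the counts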
@@ -0,0 +1,54 @@ | |||||
import unittest | |||||
import torch | |||||
from fastNLP.core.optimizer import SGD, Adam | |||||
class TestOptim(unittest.TestCase): | |||||
def test_SGD(self): | |||||
optim = SGD(model_params=torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue("lr" in optim.__dict__["settings"]) | |||||
self.assertTrue("momentum" in optim.__dict__["settings"]) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.SGD)) | |||||
optim = SGD(lr=0.001) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.SGD)) | |||||
optim = SGD(lr=0.002, momentum=0.989) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | |||||
self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989) | |||||
optim = SGD(0.001) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.SGD)) | |||||
with self.assertRaises(TypeError): | |||||
_ = SGD("???") | |||||
with self.assertRaises(TypeError): | |||||
_ = SGD(0.001, lr=0.002) | |||||
def test_Adam(self): | |||||
optim = Adam(model_params=torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue("lr" in optim.__dict__["settings"]) | |||||
self.assertTrue("weight_decay" in optim.__dict__["settings"]) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.Adam)) | |||||
optim = Adam(lr=0.001) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.Adam)) | |||||
optim = Adam(lr=0.002, weight_decay=0.989) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | |||||
self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989) | |||||
optim = Adam(0.001) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.Adam)) |
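The optimizer wrappers above are two-step: the constructor only records the settings, and the actual torch optimizer is built later from the model parameters. A minimal sketch:

import torch
from fastNLP.core.optimizer import SGD

model = torch.nn.Linear(10, 3)

optim = SGD(lr=0.01, momentum=0.9)                               # settings recorded
torch_optim = optim.construct_from_pytorch(model.parameters())   # torch.optim.SGD built here
assert isinstance(torch_optim, torch.optim.SGD)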
@@ -1,79 +1,34 @@ | |||||
import os | |||||
import unittest | import unittest | ||||
import numpy as np | |||||
import torch | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.predictor import Predictor | from fastNLP.core.predictor import Predictor | ||||
from fastNLP.core.preprocess import save_pickle | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.loader.base_loader import BaseLoader | |||||
from fastNLP.loader.dataset_loader import convert_seq_dataset | |||||
from fastNLP.models.cnn_text_classification import CNNText | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
class TestPredictor(unittest.TestCase): | |||||
def test_seq_label(self): | |||||
model_args = { | |||||
"vocab_size": 10, | |||||
"word_emb_dim": 100, | |||||
"rnn_hidden_units": 100, | |||||
"num_classes": 5 | |||||
} | |||||
infer_data = [ | |||||
['a', 'b', 'c', 'd', 'e'], | |||||
['a', '@', 'c', 'd', 'e'], | |||||
['a', 'b', '#', 'd', 'e'], | |||||
['a', 'b', 'c', '?', 'e'], | |||||
['a', 'b', 'c', 'd', '$'], | |||||
['!', 'b', 'c', 'd', 'e'] | |||||
] | |||||
vocab = Vocabulary() | |||||
vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
class_vocab = Vocabulary() | |||||
class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} | |||||
from fastNLP.modules.encoder.linear import Linear | |||||
os.system("mkdir save") | |||||
save_pickle(class_vocab, "./save/", "label2id.pkl") | |||||
save_pickle(vocab, "./save/", "word2id.pkl") | |||||
model = CNNText(model_args) | |||||
import fastNLP.core.predictor as pre | |||||
predictor = Predictor("./save/", pre.text_classify_post_processor) | |||||
def prepare_fake_dataset(): | |||||
mean = np.array([-3, -3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
# Load infer data | |||||
infer_data_set = convert_seq_dataset(infer_data) | |||||
infer_data_set.index_field("word_seq", vocab) | |||||
mean = np.array([3, 3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
results = predictor.predict(network=model, data=infer_data_set) | |||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||||
return data_set | |||||
self.assertTrue(isinstance(results, list)) | |||||
self.assertGreater(len(results), 0) | |||||
self.assertEqual(len(results), len(infer_data)) | |||||
for res in results: | |||||
self.assertTrue(isinstance(res, str)) | |||||
self.assertTrue(res in class_vocab.word2idx) | |||||
del model, predictor | |||||
infer_data_set.set_origin_len("word_seq") | |||||
model = SeqLabeling(model_args) | |||||
predictor = Predictor("./save/", pre.seq_label_post_processor) | |||||
results = predictor.predict(network=model, data=infer_data_set) | |||||
self.assertTrue(isinstance(results, list)) | |||||
self.assertEqual(len(results), len(infer_data)) | |||||
for i in range(len(infer_data)): | |||||
res = results[i] | |||||
self.assertTrue(isinstance(res, list)) | |||||
self.assertEqual(len(res), len(infer_data[i])) | |||||
os.system("rm -rf save") | |||||
print("pickle path deleted") | |||||
class TestPredictor2(unittest.TestCase): | |||||
def test_text_classify(self): | |||||
# TODO | |||||
pass | |||||
class TestPredictor(unittest.TestCase): | |||||
def test(self): | |||||
predictor = Predictor() | |||||
model = Linear(2, 1) | |||||
data = prepare_fake_dataset() | |||||
data.set_input("x") | |||||
ans = predictor.predict(model, data) | |||||
self.assertEqual(len(ans), 2000) | |||||
self.assertTrue(isinstance(ans[0], torch.Tensor)) |
@@ -1,30 +1,52 @@ | |||||
import torch | |||||
from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler | |||||
def test_convert_to_torch_tensor(): | |||||
data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]] | |||||
ans = convert_to_torch_tensor(data, False) | |||||
assert isinstance(ans, torch.Tensor) | |||||
assert tuple(ans.shape) == (3, 5) | |||||
def test_sequential_sampler(): | |||||
sampler = SequentialSampler() | |||||
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||||
for idx, i in enumerate(sampler(data)): | |||||
assert idx == i | |||||
def test_random_sampler(): | |||||
sampler = RandomSampler() | |||||
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||||
ans = [data[i] for i in sampler(data)] | |||||
assert len(ans) == len(data) | |||||
for d in ans: | |||||
assert d in data | |||||
import random | |||||
import unittest | |||||
import torch | |||||
if __name__ == "__main__": | |||||
test_sequential_sampler() | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \ | |||||
k_means_1d, k_means_bucketing, simple_sort_bucketing, BucketSampler | |||||
class TestSampler(unittest.TestCase): | |||||
def test_convert_to_torch_tensor(self): | |||||
data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]] | |||||
ans = convert_to_torch_tensor(data, False) | |||||
assert isinstance(ans, torch.Tensor) | |||||
assert tuple(ans.shape) == (3, 5) | |||||
def test_sequential_sampler(self): | |||||
sampler = SequentialSampler() | |||||
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||||
for idx, i in enumerate(sampler(data)): | |||||
assert idx == i | |||||
def test_random_sampler(self): | |||||
sampler = RandomSampler() | |||||
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||||
ans = [data[i] for i in sampler(data)] | |||||
assert len(ans) == len(data) | |||||
for d in ans: | |||||
assert d in data | |||||
def test_k_means(self): | |||||
centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5) | |||||
centroids, assign = list(centroids), list(assign) | |||||
assert len(centroids) == 2 | |||||
assert len(assign) == 10 | |||||
def test_k_means_bucketing(self): | |||||
res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None]) | |||||
assert len(res) == 2 | |||||
def test_simple_sort_bucketing(self): | |||||
_ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10]) | |||||
assert len(_) == 10 | |||||
def test_BucketSampler(self): | |||||
sampler = BucketSampler(num_buckets=3, batch_size=16, seq_lens_field_name="seq_len") | |||||
data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10}) | |||||
data_set.apply(lambda ins: len(ins["x"]), new_field_name="seq_len") | |||||
indices = sampler(data_set) | |||||
self.assertEqual(len(indices), 10) | |||||
# just check that it runs; the sampling result itself is not verified
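The BucketSampler test shows the intended use: compute a sequence-length field, let the sampler group instances of similar length into buckets, and get back a shuffled index order over the whole DataSet. A minimal sketch under the same assumptions:

from fastNLP.core.dataset import DataSet
from fastNLP.core.sampler import BucketSampler

ds = DataSet({"x": [[0] * (i % 7 + 1) for i in range(20)], "y": [0] * 20})
ds.apply(lambda ins: len(ins["x"]), new_field_name="seq_len")

sampler = BucketSampler(num_buckets=4, batch_size=5, seq_lens_field_name="seq_len")
order = sampler(ds)             # one index per instance, grouped by length
assert len(order) == len(ds)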
@@ -1,57 +1,67 @@ | |||||
import os | |||||
import unittest | import unittest | ||||
data_name = "pku_training.utf8" | |||||
pickle_path = "data_for_tests" | |||||
import numpy as np | |||||
import torch.nn.functional as F | |||||
from torch import nn | |||||
import time | |||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
from fastNLP.core.losses import BCELoss | |||||
from fastNLP.core.losses import CrossEntropyLoss | |||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP.core.optimizer import SGD | |||||
from fastNLP.core.tester import Tester | |||||
from fastNLP.models.base_model import NaiveClassifier | |||||
data_name = "pku_training.utf8" | |||||
pickle_path = "data_for_tests" | |||||
def prepare_fake_dataset(): | |||||
mean = np.array([-3, -3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
mean = np.array([3, 3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||||
return data_set | |||||
def prepare_fake_dataset2(*args, size=100): | |||||
ys = np.random.randint(4, size=size, dtype=np.int64)
data = {'y': ys} | |||||
for arg in args: | |||||
data[arg] = np.random.randn(size, 5) | |||||
return DataSet(data=data) | |||||
class TestTester(unittest.TestCase): | class TestTester(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
model_args = { | |||||
"vocab_size": 10, | |||||
"word_emb_dim": 100, | |||||
"rnn_hidden_units": 100, | |||||
"num_classes": 5 | |||||
} | |||||
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | |||||
"save_loss": True, "batch_size": 2, "pickle_path": "./save/", | |||||
"use_cuda": False, "print_every_step": 1, "evaluator": SeqLabelEvaluator()} | |||||
train_data = [ | |||||
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']], | |||||
[['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
] | |||||
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | |||||
data_set = DataSet() | |||||
for example in train_data: | |||||
text, label = example[0], example[1] | |||||
x = TextField(text, False) | |||||
x_len = LabelField(len(text), is_target=False) | |||||
y = TextField(label, is_target=True) | |||||
ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||||
data_set.append(ins) | |||||
data_set.index_field("word_seq", vocab) | |||||
data_set.index_field("truth", label_vocab) | |||||
model = SeqLabeling(model_args) | |||||
tester = SeqLabelTester(**valid_args) | |||||
tester.test(network=model, dev_data=data_set) | |||||
# If this can run, everything is OK. | |||||
os.system("rm -rf save") | |||||
print("pickle path deleted") | |||||
# check that the error message gives the user the right hint
dataset = prepare_fake_dataset2('x1', 'x_unused') | |||||
dataset.rename_field('x_unused', 'x2') | |||||
dataset.set_input('x1', 'x2') | |||||
dataset.set_target('y', 'x1') | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
time.sleep(0.1) | |||||
# loss = F.cross_entropy(x, y) | |||||
return {'preds': x} | |||||
model = Model() | |||||
with self.assertRaises(NameError): | |||||
tester = Tester( | |||||
data=dataset, | |||||
model=model, | |||||
metrics=AccuracyMetric()) | |||||
tester.test() |
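For contrast with the deliberately failing configuration above, a happy-path sketch of the Tester, reusing prepare_fake_dataset and NaiveClassifier from this file and assuming NaiveClassifier exposes its output under the key "predict" as in the trainer tests:

data_set = prepare_fake_dataset()
data_set.set_input("x")
data_set.set_target("y")

model = NaiveClassifier(2, 1)
tester = Tester(data=data_set,
                model=model,
                metrics=AccuracyMetric(pred="predict", target="y"))
tester.test()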
@@ -1,57 +1,242 @@ | |||||
import os | |||||
import unittest | import unittest | ||||
import numpy as np | |||||
import torch.nn.functional as F | |||||
from torch import nn | |||||
import time | |||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.loss import Loss | |||||
from fastNLP.core.optimizer import Optimizer | |||||
from fastNLP.core.trainer import SeqLabelTrainer | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | |||||
class TestTrainer(unittest.TestCase): | |||||
def test_case_1(self): | |||||
args = {"epochs": 3, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | |||||
"save_best_dev": True, "model_name": "default_model_name.pkl", | |||||
"loss": Loss("cross_entropy"), | |||||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | |||||
"vocab_size": 10, | |||||
"word_emb_dim": 100, | |||||
"rnn_hidden_units": 100, | |||||
"num_classes": 5, | |||||
"evaluator": SeqLabelEvaluator() | |||||
} | |||||
trainer = SeqLabelTrainer(**args) | |||||
train_data = [ | |||||
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
[['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']], | |||||
[['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
] | |||||
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | |||||
data_set = DataSet() | |||||
for example in train_data: | |||||
text, label = example[0], example[1] | |||||
x = TextField(text, False) | |||||
x_len = LabelField(len(text), is_target=False) | |||||
y = TextField(label, is_target=False) | |||||
ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||||
data_set.append(ins) | |||||
data_set.index_field("word_seq", vocab) | |||||
data_set.index_field("truth", label_vocab) | |||||
model = SeqLabeling(args) | |||||
trainer.train(network=model, train_data=data_set, dev_data=data_set) | |||||
# If this can run, everything is OK. | |||||
os.system("rm -rf save") | |||||
print("pickle path deleted") | |||||
from fastNLP.core.losses import BCELoss | |||||
from fastNLP.core.losses import CrossEntropyLoss | |||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP.core.optimizer import SGD | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.models.base_model import NaiveClassifier | |||||
def prepare_fake_dataset(): | |||||
mean = np.array([-3, -3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
mean = np.array([3, 3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||||
return data_set | |||||
def prepare_fake_dataset2(*args, size=100): | |||||
ys = np.random.randint(4, size=size, dtype=np.int64)
data = {'y': ys} | |||||
for arg in args: | |||||
data[arg] = np.random.randn(size, 5) | |||||
return DataSet(data=data) | |||||
class TrainerTestGround(unittest.TestCase): | |||||
def test_case(self): | |||||
data_set = prepare_fake_dataset() | |||||
data_set.set_input("x", flag=True) | |||||
data_set.set_target("y", flag=True) | |||||
train_set, dev_set = data_set.split(0.3) | |||||
model = NaiveClassifier(2, 1) | |||||
trainer = Trainer(train_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
n_epochs=10, | |||||
batch_size=32, | |||||
print_every=50, | |||||
validate_every=-1, | |||||
dev_data=dev_set, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=True, | |||||
save_path=None) | |||||
trainer.train() | |||||
""" | |||||
# should run without error
""" | |||||
def test_trainer_suggestion1(self): | |||||
# check that the error message gives the user the right hint
# the data required by forward() is not provided here; the Trainer should tell the user how to set it up
dataset = prepare_fake_dataset2('x') | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'loss': loss} | |||||
model = Model() | |||||
with self.assertRaises(NameError): | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model | |||||
) | |||||
""" | |||||
# the expected error message:
NameError: | |||||
The following problems occurred when calling Model.forward(self, x1, x2, y) | |||||
missing param: ['y', 'x1', 'x2'] | |||||
Suggestion: (1). You might need to set ['y'] as input. | |||||
(2). You need to provide ['x1', 'x2'] in DataSet and set it as input. | |||||
""" | |||||
def test_trainer_suggestion2(self): | |||||
# check that the error message gives the user the right hint
# here the data required by forward() is provided, so training should run
dataset = prepare_fake_dataset2('x1', 'x2') | |||||
dataset.set_input('x1', 'x2', 'y', flag=True) | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'loss': loss} | |||||
model = Model() | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
trainer.train() | |||||
""" | |||||
# should run without error
""" | |||||
def test_trainer_suggestion3(self): | |||||
# check that the error message gives the user the right hint
# the data required by forward() is provided, but forward() does not return the key 'loss'
dataset = prepare_fake_dataset2('x1', 'x2') | |||||
dataset.set_input('x1', 'x2', 'y', flag=True) | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'wrong_loss_key': loss} | |||||
model = Model() | |||||
with self.assertRaises(NameError): | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
trainer.train() | |||||
def test_trainer_suggestion4(self): | |||||
# check that the error message gives the user the right hint
# the data required by forward() is provided; check that unused fields are reported correctly
dataset = prepare_fake_dataset2('x1', 'x2') | |||||
dataset.set_input('x1', 'x2', 'y', flag=True) | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'losses': loss} | |||||
model = Model() | |||||
with self.assertRaises(NameError): | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
def test_trainer_suggestion5(self): | |||||
# check that the error message gives the user the right hint
# extra parameters are passed in to create duplicates, but since y is never used no error is actually raised
dataset = prepare_fake_dataset2('x1', 'x_unused') | |||||
dataset.rename_field('x_unused', 'x2') | |||||
dataset.set_input('x1', 'x2', 'y') | |||||
dataset.set_target('y') | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'loss': loss} | |||||
model = Model() | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
def test_trainer_suggestion6(self): | |||||
# check that the error message gives the user the right hint
# extra parameters are passed in to create duplicates
dataset = prepare_fake_dataset2('x1', 'x_unused') | |||||
dataset.rename_field('x_unused', 'x2') | |||||
dataset.set_input('x1', 'x2') | |||||
dataset.set_target('y', 'x1') | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
time.sleep(0.1) | |||||
# loss = F.cross_entropy(x, y) | |||||
return {'preds': x} | |||||
model = Model() | |||||
with self.assertRaises(NameError): | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
dev_data=dataset, | |||||
loss=CrossEntropyLoss(), | |||||
metrics=AccuracyMetric(), | |||||
use_tqdm=False, | |||||
print_every=2) | |||||
def test_case2(self): | |||||
# check the behaviour when the metrics are set up incorrectly
data_set = prepare_fake_dataset2('x1', 'x2') |
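The suggestion tests above all check the same contract: every argument of the model's forward must be available as an input field of the DataSet, and during training the returned dict must contain the key 'loss' (unless an explicit loss is passed to the Trainer). A minimal compliant setup, sketched under those assumptions and mirroring test_trainer_suggestion2:

import torch.nn.functional as F
from torch import nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 4)

    def forward(self, x1, y):
        # argument names must match input fields; the returned dict must carry 'loss'
        return {'loss': F.cross_entropy(self.fc(x1), y)}

dataset = prepare_fake_dataset2('x1')
dataset.set_input('x1', 'y')
trainer = Trainer(train_data=dataset, model=TinyModel(), use_tqdm=False, print_every=2)
trainer.train()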
@@ -1,31 +0,0 @@ | |||||
import unittest | |||||
from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX | |||||
class TestVocabulary(unittest.TestCase): | |||||
def test_vocab(self): | |||||
import _pickle as pickle | |||||
import os | |||||
vocab = Vocabulary() | |||||
filename = 'vocab' | |||||
vocab.update(filename) | |||||
vocab.update([filename, ['a'], [['b']], ['c']]) | |||||
idx = vocab[filename] | |||||
before_pic = (vocab.to_word(idx), vocab[filename]) | |||||
with open(filename, 'wb') as f: | |||||
pickle.dump(vocab, f) | |||||
with open(filename, 'rb') as f: | |||||
vocab = pickle.load(f) | |||||
os.remove(filename) | |||||
vocab.build_reverse_vocab() | |||||
after_pic = (vocab.to_word(idx), vocab[filename]) | |||||
TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8} | |||||
TRUE_DICT.update(DEFAULT_WORD_TO_INDEX) | |||||
TRUE_IDXDICT = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'} | |||||
self.assertEqual(before_pic, after_pic) | |||||
self.assertDictEqual(TRUE_DICT, vocab.word2idx) | |||||
self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word) | |||||
if __name__ == '__main__': | |||||
unittest.main() |
@@ -0,0 +1,88 @@ | |||||
import unittest | |||||
from collections import Counter | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | |||||
"works", "well", "in", "most", "cases", "scales", "well"] | |||||
counter = Counter(text) | |||||
class TestAdd(unittest.TestCase): | |||||
def test_add(self): | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
for word in text: | |||||
vocab.add(word) | |||||
self.assertEqual(vocab.word_count, counter) | |||||
def test_add_word(self): | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
for word in text: | |||||
vocab.add_word(word) | |||||
self.assertEqual(vocab.word_count, counter) | |||||
def test_add_word_lst(self): | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.add_word_lst(text) | |||||
self.assertEqual(vocab.word_count, counter) | |||||
def test_update(self): | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.update(text) | |||||
self.assertEqual(vocab.word_count, counter) | |||||
class TestIndexing(unittest.TestCase): | |||||
def test_len(self): | |||||
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||||
vocab.update(text) | |||||
self.assertEqual(len(vocab), len(counter)) | |||||
def test_contains(self): | |||||
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||||
vocab.update(text) | |||||
self.assertTrue(text[-1] in vocab) | |||||
self.assertFalse("~!@#" in vocab) | |||||
self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1])) | |||||
self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | |||||
def test_index(self): | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.update(text) | |||||
res = [vocab[w] for w in set(text)] | |||||
self.assertEqual(len(res), len(set(res))) | |||||
res = [vocab.to_index(w) for w in set(text)] | |||||
self.assertEqual(len(res), len(set(res))) | |||||
def test_to_word(self): | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.update(text) | |||||
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | |||||
class TestOther(unittest.TestCase): | |||||
def test_additional_update(self): | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.update(text) | |||||
_ = vocab["well"] | |||||
self.assertEqual(vocab.rebuild, False) | |||||
vocab.add("hahaha") | |||||
self.assertEqual(vocab.rebuild, True) | |||||
_ = vocab["hahaha"] | |||||
self.assertEqual(vocab.rebuild, False) | |||||
self.assertTrue("hahaha" in vocab) | |||||
def test_warning(self): | |||||
vocab = Vocabulary(max_size=len(set(text)), min_freq=None) | |||||
vocab.update(text) | |||||
self.assertEqual(vocab.rebuild, True) | |||||
print(len(vocab)) | |||||
self.assertEqual(vocab.rebuild, False) | |||||
vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"]) | |||||
# this will print a warning | |||||
self.assertEqual(vocab.rebuild, True) |
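Taken together, the tests above describe a lazily built Vocabulary: words are counted as they are added, and the word-to-index mapping is only (re)built the first time it is queried. A minimal sketch assuming the methods exercised above:

from fastNLP.core.vocabulary import Vocabulary

vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(["FastNLP", "works", "well", "works"])

idx = vocab.to_index("works")        # first lookup triggers building the index
assert vocab.to_word(idx) == "works"
assert vocab.has_word("FastNLP")

vocab.add("later")                   # adding afterwards marks the index for rebuild
assert vocab.rebuild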
@@ -1,12 +1,6 @@ | |||||
the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581 | the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581 | ||||
, 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -0.42852 -0.55641 -0.364 -0.23938 0.13001 -0.063734 -0.39575 -0.48162 0.23291 0.090201 -0.13324 0.078639 -0.41634 -0.15428 0.10068 0.48891 0.31226 -0.1252 -0.037512 -1.5179 0.12612 -0.02442 -0.042961 -0.28351 3.5416 -0.11956 -0.014533 -0.1499 0.21864 -0.33412 -0.13872 0.31806 0.70358 0.44858 -0.080262 0.63003 0.32111 -0.46765 0.22786 0.36034 -0.37818 -0.56657 0.044691 0.30392 | |||||
. 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -0.43478 -0.31086 -0.44999 -0.29486 0.16608 0.11963 -0.41328 -0.42353 0.59868 0.28825 -0.11547 -0.041848 -0.67989 -0.25063 0.18472 0.086876 0.46582 0.015035 0.043474 -1.4671 -0.30384 -0.023441 0.30589 -0.21785 3.746 0.0042284 -0.18436 -0.46209 0.098329 -0.11907 0.23919 0.1161 0.41705 0.056763 -6.3681e-05 0.068987 0.087939 -0.10285 -0.13931 0.22314 -0.080803 -0.35652 0.016413 0.10216 | |||||
of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375 | of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375 | ||||
to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044 | to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044 | ||||
and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097 | and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097 | ||||
in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285 | in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285 | ||||
a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796 | |||||
" 0.25769 0.45629 -0.76974 -0.37679 0.59272 -0.063527 0.20545 -0.57385 -0.29009 -0.13662 0.32728 1.4719 -0.73681 -0.12036 0.71354 -0.46098 0.65248 0.48887 -0.51558 0.039951 -0.34307 -0.014087 0.86488 0.3546 0.7999 -1.4995 -1.8153 0.41128 0.23921 -0.43139 3.6623 -0.79834 -0.54538 0.16943 -0.82017 -0.3461 0.69495 -1.2256 -0.17992 -0.057474 0.030498 -0.39543 -0.38515 -1.0002 0.087599 -0.31009 -0.34677 -0.31438 0.75004 0.97065 | |||||
's 0.23727 0.40478 -0.20547 0.58805 0.65533 0.32867 -0.81964 -0.23236 0.27428 0.24265 0.054992 0.16296 -1.2555 -0.086437 0.44536 0.096561 -0.16519 0.058378 -0.38598 0.086977 0.0033869 0.55095 -0.77697 -0.62096 0.092948 -2.5685 -0.67739 0.10151 -0.48643 -0.057805 3.1859 -0.017554 -0.16138 0.055486 -0.25885 -0.33938 -0.19928 0.26049 0.10478 -0.55934 -0.12342 0.65961 -0.51802 -0.82995 -0.082739 0.28155 -0.423 -0.27378 -0.007901 -0.030231 | |||||
a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796 |
@@ -0,0 +1,77 @@ | |||||
A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . 1 | |||||
This quiet , introspective and entertaining independent is worth seeking . 4 | |||||
Even fans of Ismail Merchant 's work , I suspect , would have a hard time sitting through this one . 1 | |||||
A positively thrilling combination of ethnography and all the intrigue , betrayal , deceit and murder of a Shakespearean tragedy or a juicy soap opera . 3 | |||||
Aggressive self-glorification and a manipulative whitewash . 1 | |||||
A comedy-drama of nearly epic proportions rooted in a sincere performance by the title character undergoing midlife crisis . 4 | |||||
Narratively , Trouble Every Day is a plodding mess . 1 | |||||
The Importance of Being Earnest , so thick with wit it plays like a reading from Bartlett 's Familiar Quotations 3 | |||||
But it does n't leave you with much . 1 | |||||
You could hate it for the same reason . 1 | |||||
There 's little to recommend Snow Dogs , unless one considers cliched dialogue and perverse escapism a source of high hilarity . 1 | |||||
Kung Pow is Oedekerk 's realization of his childhood dream to be in a martial-arts flick , and proves that sometimes the dreams of youth should remain just that . 1 | |||||
The performances are an absolute joy . 4 | |||||
Fresnadillo has something serious to say about the ways in which extravagant chance can distort our perspective and throw us off the path of good sense . 3 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 | |||||
While The Importance of Being Earnest offers opportunities for occasional smiles and chuckles , it does n't give us a reason to be in the theater beyond Wilde 's wit and the actors ' performances . 1 | |||||
The latest vapid actor 's exercise to appropriate the structure of Arthur Schnitzler 's Reigen . 1 | |||||
More vaudeville show than well-constructed narrative , but on those terms it 's inoffensive and actually rather sweet . 2 | |||||
Nothing more than a run-of-the-mill action flick . 2 | |||||
Hampered -- no , paralyzed -- by a self-indulgent script ... that aims for poetry and ends up sounding like satire . 0 | |||||
Ice Age is the first computer-generated feature cartoon to feel like other movies , and that makes for some glacial pacing early on . 2 | |||||
There 's very little sense to what 's going on here , but the makers serve up the cliches with considerable dash . 2 | |||||
Cattaneo should have followed the runaway success of his first film , The Full Monty , with something different . 2 | |||||
They 're the unnamed , easily substitutable forces that serve as whatever terror the heroes of horror movies try to avoid . 1 | |||||
It almost feels as if the movie is more interested in entertaining itself than in amusing us . 1 | |||||
The movie 's progression into rambling incoherence gives new meaning to the phrase ` fatal script error . ' 0 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 |