@@ -1,9 +1,6 @@ | |||
language: python | |||
python: | |||
- "3.6" | |||
env: | |||
- TRAVIS=1 | |||
# command to install dependencies | |||
install: | |||
- pip install --quiet -r requirements.txt | |||
@@ -11,7 +8,7 @@ install: | |||
- pip install pytest-cov | |||
# command to run tests | |||
script: | |||
- pytest --cov=./ test/ | |||
- pytest --cov=fastNLP test/ | |||
after_success: | |||
- bash <(curl -s https://codecov.io/bash) |
@@ -23,6 +23,13 @@ def _colored_string(string: str, color: str or int) -> str: | |||
return "\033[%dm%s\033[0m" % (color, string) | |||
def gr(string, flag): | |||
if flag: | |||
return _colored_string(string, "green") | |||
else: | |||
return _colored_string(string, "red") | |||
def find_all_modules(): | |||
modules = {} | |||
children = {} | |||
@@ -66,31 +73,75 @@ def create_rst_file(modules, name, children): | |||
fout.write(t + "\n") | |||
fout.write("\n") | |||
fout.write(".. automodule:: " + name + "\n") | |||
if len(m.__all__) > 0: | |||
if name != "fastNLP.core" and len(m.__all__) > 0: | |||
fout.write(" :members: " + ", ".join(m.__all__) + "\n") | |||
fout.write(" :inherited-members:\n") | |||
short = name[len("fastNLP."):] | |||
if not (short.startswith('models') or short.startswith('modules') or short.startswith('embeddings')): | |||
fout.write(" :inherited-members:\n") | |||
fout.write("\n") | |||
if name in children: | |||
fout.write("子模块\n------\n\n.. toctree::\n\n") | |||
fout.write("子模块\n------\n\n.. toctree::\n :maxdepth: 1\n\n") | |||
for module in children[name]: | |||
fout.write(" " + module + "\n") | |||
def check_file(m, name): | |||
names = name.split('.') | |||
test_name = "test." + ".".join(names[1:-1]) + ".test_" + names[-1] | |||
try: | |||
__import__(test_name) | |||
tm = sys.modules[test_name] | |||
except ModuleNotFoundError: | |||
tm = None | |||
tested = tm is not None | |||
funcs = {} | |||
classes = {} | |||
for item, obj in inspect.getmembers(m): | |||
if inspect.isclass(obj) and obj.__module__ == name: | |||
print(obj) | |||
if inspect.isfunction(obj) and obj.__module__ == name: | |||
print("FUNC", obj) | |||
if inspect.isclass(obj) and obj.__module__ == name and not obj.__name__.startswith('_'): | |||
this = (obj.__doc__ is not None, tested and obj.__name__ in dir(tm), {}) | |||
for i in dir(obj): | |||
func = getattr(obj, i) | |||
if inspect.isfunction(func) and not i.startswith('_'): | |||
this[2][i] = (func.__doc__ is not None, False) | |||
classes[obj.__name__] = this | |||
if inspect.isfunction(obj) and obj.__module__ == name and not obj.__name__.startswith('_'): | |||
this = (obj.__doc__ is not None, tested and obj.__name__ in dir(tm)) # docs | |||
funcs[obj.__name__] = this | |||
return funcs, classes | |||
def check_files(modules): | |||
def check_files(modules, out=None): | |||
for name in sorted(modules.keys()): | |||
if name == 'fastNLP.core.utils': | |||
check_file(modules[name], name) | |||
print(name, file=out) | |||
funcs, classes = check_file(modules[name], name) | |||
if out is None: | |||
for f in funcs: | |||
print("%-30s \t %s \t %s" % (f, gr("文档", funcs[f][0]), gr("测试", funcs[f][1]))) | |||
for c in classes: | |||
print("%-30s \t %s \t %s" % (c, gr("文档", classes[c][0]), gr("测试", classes[c][1]))) | |||
methods = classes[c][2] | |||
for f in methods: | |||
print(" %-28s \t %s" % (f, gr("文档", methods[f][0]))) | |||
else: | |||
for f in funcs: | |||
if not funcs[f][0]: | |||
print("缺少文档 %s" % (f), file=out) | |||
if not funcs[f][1]: | |||
print("缺少测试 %s" % (f), file=out) | |||
for c in classes: | |||
if not classes[c][0]: | |||
print("缺少文档 %s" % (c), file=out) | |||
if not classes[c][1]: | |||
print("缺少测试 %s" % (c), file=out) | |||
methods = classes[c][2] | |||
for f in methods: | |||
if not methods[f][0]: | |||
print("缺少文档 %s" % (c + "." + f), file=out) | |||
print(file=out) | |||
def main(): | |||
sys.path.append("..") | |||
print(_colored_string('Getting modules...', "Blue")) | |||
modules, to_doc, children = find_all_modules() | |||
print(_colored_string('Done!', "Green")) | |||
@@ -99,7 +150,7 @@ def main(): | |||
create_rst_file(modules, name, children) | |||
print(_colored_string('Done!', "Green")) | |||
print(_colored_string('Checking all files...', "Blue")) | |||
check_files(modules) | |||
check_files(modules, out=open("results.txt", "w")) | |||
print(_colored_string('Done!', "Green")) | |||
@@ -168,10 +168,12 @@ texinfo_documents = [ | |||
# -- Extension configuration ------------------------------------------------- | |||
def maybe_skip_member(app, what, name, obj, skip, options): | |||
if name.startswith("_"): | |||
return True | |||
if obj.__doc__ is None: | |||
return True | |||
if name == "__init__": | |||
return False | |||
if name.startswith("_"): | |||
return True | |||
return False | |||
@@ -2,13 +2,12 @@ fastNLP.core | |||
============ | |||
.. automodule:: fastNLP.core | |||
:members: DataSet, Instance, FieldArray, Padder, AutoPadder, EngChar2DPadder, Vocabulary, DataSetIter, BatchIter, TorchLoaderIter, Const, Tester, Trainer, cache_results, seq_len_to_mask, get_seq_len, logger, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, TesterCallback, CallbackException, EarlyStopError, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, SequentialSampler, BucketSampler, RandomSampler, Sampler | |||
:inherited-members: | |||
子模块 | |||
------ | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.core.batch | |||
fastNLP.core.callback | |||
@@ -3,5 +3,4 @@ fastNLP.embeddings.bert_embedding | |||
.. automodule:: fastNLP.embeddings.bert_embedding | |||
:members: BertEmbedding, BertWordPieceEncoder | |||
:inherited-members: | |||
@@ -3,5 +3,4 @@ fastNLP.embeddings.char_embedding | |||
.. automodule:: fastNLP.embeddings.char_embedding | |||
:members: CNNCharEmbedding, LSTMCharEmbedding | |||
:inherited-members: | |||
@@ -3,5 +3,4 @@ fastNLP.embeddings.contextual_embedding | |||
.. automodule:: fastNLP.embeddings.contextual_embedding | |||
:members: ContextualEmbedding | |||
:inherited-members: | |||
@@ -3,5 +3,4 @@ fastNLP.embeddings.elmo_embedding | |||
.. automodule:: fastNLP.embeddings.elmo_embedding | |||
:members: ElmoEmbedding | |||
:inherited-members: | |||
@@ -3,5 +3,4 @@ fastNLP.embeddings.embedding | |||
.. automodule:: fastNLP.embeddings.embedding | |||
:members: Embedding, TokenEmbedding | |||
:inherited-members: | |||
@@ -3,12 +3,12 @@ fastNLP.embeddings | |||
.. automodule:: fastNLP.embeddings | |||
:members: Embedding, TokenEmbedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, BertWordPieceEncoder, StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding, get_embeddings | |||
:inherited-members: | |||
子模块 | |||
------ | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.embeddings.bert_embedding | |||
fastNLP.embeddings.char_embedding | |||
@@ -3,5 +3,4 @@ fastNLP.embeddings.stack_embedding | |||
.. automodule:: fastNLP.embeddings.stack_embedding | |||
:members: StackEmbedding | |||
:inherited-members: | |||
@@ -3,5 +3,4 @@ fastNLP.embeddings.static_embedding | |||
.. automodule:: fastNLP.embeddings.static_embedding | |||
:members: StaticEmbedding | |||
:inherited-members: | |||
@@ -3,5 +3,4 @@ fastNLP.embeddings.utils | |||
.. automodule:: fastNLP.embeddings.utils | |||
:members: get_embeddings | |||
:inherited-members: | |||
@@ -1,6 +0,0 @@ | |||
fastNLP.io.dataset_loader | |||
========================= | |||
.. automodule:: fastNLP.io.dataset_loader | |||
:members: CSVLoader, JsonLoader | |||
@@ -2,6 +2,6 @@ fastNLP.io.loader | |||
================= | |||
.. automodule:: fastNLP.io.loader | |||
:members: Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader | |||
:members: Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.io.pipe | |||
=============== | |||
.. automodule:: fastNLP.io.pipe | |||
:members: Pipe, CWSPipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe, Conll2003Pipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe | |||
:members: Pipe, CWSPipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe, Conll2003Pipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe | |||
:inherited-members: | |||
@@ -2,13 +2,14 @@ fastNLP.io | |||
========== | |||
.. automodule:: fastNLP.io | |||
:members: DataBundle, EmbedLoader, Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, WeiboNERLoader, PeopleDailyNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, Pipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, Conll2003Pipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, CWSPipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, ModelLoader, ModelSaver | |||
:members: DataBundle, EmbedLoader, Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, WeiboNERLoader, PeopleDailyNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, Pipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, Conll2003Pipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, CWSPipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, ModelLoader, ModelSaver | |||
:inherited-members: | |||
子模块 | |||
------ | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.io.data_bundle | |||
fastNLP.io.embed_loader | |||
@@ -0,0 +1,6 @@ | |||
fastNLP.models.bert | |||
=================== | |||
.. automodule:: fastNLP.models.bert | |||
:members: BertForSequenceClassification, BertForSentenceMatching, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering | |||
@@ -3,5 +3,4 @@ fastNLP.models.biaffine_parser | |||
.. automodule:: fastNLP.models.biaffine_parser | |||
:members: BiaffineParser, GraphParser | |||
:inherited-members: | |||
@@ -3,5 +3,4 @@ fastNLP.models.cnn_text_classification | |||
.. automodule:: fastNLP.models.cnn_text_classification | |||
:members: CNNText | |||
:inherited-members: | |||
@@ -2,14 +2,15 @@ fastNLP.models | |||
============== | |||
.. automodule:: fastNLP.models | |||
:members: CNNText, SeqLabeling, AdvSeqLabel, ESIM, StarTransEnc, STSeqLabel, STNLICls, STSeqCls, BiaffineParser, GraphParser | |||
:inherited-members: | |||
:members: CNNText, SeqLabeling, AdvSeqLabel, ESIM, StarTransEnc, STSeqLabel, STNLICls, STSeqCls, BiaffineParser, GraphParser, BertForSequenceClassification, BertForSentenceMatching, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering | |||
子模块 | |||
------ | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.models.bert | |||
fastNLP.models.biaffine_parser | |||
fastNLP.models.cnn_text_classification | |||
fastNLP.models.sequence_labeling | |||
@@ -2,6 +2,5 @@ fastNLP.models.sequence_labeling | |||
================================ | |||
.. automodule:: fastNLP.models.sequence_labeling | |||
:members: SeqLabeling, AdvSeqLabel | |||
:inherited-members: | |||
:members: SeqLabeling, AdvSeqLabel, BiLSTMCRF | |||
@@ -3,5 +3,4 @@ fastNLP.models.snli | |||
.. automodule:: fastNLP.models.snli | |||
:members: ESIM | |||
:inherited-members: | |||
@@ -3,5 +3,4 @@ fastNLP.models.star_transformer | |||
.. automodule:: fastNLP.models.star_transformer | |||
:members: StarTransEnc, STNLICls, STSeqCls, STSeqLabel | |||
:inherited-members: | |||
@@ -3,5 +3,4 @@ fastNLP.modules.decoder | |||
.. automodule:: fastNLP.modules.decoder | |||
:members: MLP, ConditionalRandomField, viterbi_decode, allowed_transitions | |||
:inherited-members: | |||
@@ -3,5 +3,4 @@ fastNLP.modules.encoder | |||
.. automodule:: fastNLP.modules.encoder | |||
:members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention | |||
:inherited-members: | |||
@@ -3,12 +3,12 @@ fastNLP.modules | |||
.. automodule:: fastNLP.modules | |||
:members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention, MLP, ConditionalRandomField, viterbi_decode, allowed_transitions, TimestepDropout | |||
:inherited-members: | |||
子模块 | |||
------ | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.modules.decoder | |||
fastNLP.modules.encoder | |||
@@ -3,5 +3,4 @@ fastNLP.modules.utils | |||
.. automodule:: fastNLP.modules.utils | |||
:members: initial_parameter, summary | |||
:inherited-members: | |||
@@ -9,6 +9,7 @@ fastNLP | |||
------ | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.core | |||
fastNLP.embeddings | |||
@@ -23,7 +23,7 @@ Callback的构建和使用 | |||
class LRDecay(fastNLP.Callback): | |||
def __init__(self): | |||
super(MyCallback, self).__init__() | |||
super(LRDecay, self).__init__() | |||
self.base_lrs = [] | |||
self.delta = [] | |||
@@ -1,21 +1,20 @@ | |||
============================== | |||
使用DataSet预处理文本 | |||
DataSet | |||
============================== | |||
:class:`~fastNLP.DataSet` 是fastNLP中用于承载数据的容器。可以将DataSet看做是一个表格, | |||
每一行是一个sample (在fastNLP中被称为 :mod:`~fastNLP.core.instance` ), | |||
每一列是一个feature (在fastNLP中称为 :mod:`~fastNLP.core.field` )。 | |||
:class:`~fastNLP.DataSet` 是fastNLP用于承载数据的类,一般训练集、验证集和测试集会被加载为三个单独的:class:`~fastNLP.DataSet`对象。 | |||
:class:`~fastNLP.DataSet`中的数据组织形式类似一个表格,比如下面 :class:`~fastNLP.DataSet` 一共有3列,列在fastNLP中被称为field。 | |||
.. csv-table:: | |||
:header: "sentence", "words", "seq_len" | |||
:header: "raw_chars", "chars", "seq_len" | |||
"This is the first instance .", "[This, is, the, first, instance, .]", 6 | |||
"Second instance .", "[Second, instance, .]", 3 | |||
"历任公司副总经理、总工程师,", "[历 任 公 司 副 总 经 理 、 总 工 程 师 ,]", 6 | |||
"Third instance .", "[Third, instance, .]", 3 | |||
"...", "[...]", "..." | |||
上面是一个样例数据中 DataSet 的存储结构。其中它的每一行是一个 :class:`~fastNLP.Instance` 对象; 每一列是一个 :class:`~fastNLP.FieldArray` 对象。 | |||
每一行是一个instance (在fastNLP中被称为 :mod:`~fastNLP.core.Instance` ), | |||
每一列是一个field (在fastNLP中称为 :mod:`~fastNLP.core.FieldArray` )。 | |||
----------------------------- | |||
数据集构建和删除 | |||
@@ -26,11 +25,23 @@ | |||
.. code-block:: python | |||
from fastNLP import DataSet | |||
data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."], | |||
data = {'raw_words':["This is the first instance .", "Second instance .", "Third instance ."], | |||
'words': [['this', 'is', 'the', 'first', 'instance', '.'], ['Second', 'instance', '.'], ['Third', 'instance', '.']], | |||
'seq_len': [6, 3, 3]} | |||
dataset = DataSet(data) | |||
# 传入的dict的每个key的value应该为具有相同长度的list | |||
print(dataset) | |||
输出为:: | |||
+------------------------------+------------------------------------------------+---------+ | |||
| raw_words | words | seq_len | | |||
+------------------------------+------------------------------------------------+---------+ | |||
| This is the first instance . | ['this', 'is', 'the', 'first', 'instance', ... | 6 | | |||
| Second instance . | ['Second', 'instance', '.'] | 3 | | |||
| Third instance . | ['Third', 'instance', '.'] | 3 | | |||
+------------------------------+------------------------------------------------+---------+ | |||
我们还可以使用 :func:`~fastNLP.DataSet.append` 方法向数据集内增加数据 | |||
@@ -39,7 +50,7 @@ | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
dataset = DataSet() | |||
instance = Instance(sentence="This is the first instance", | |||
instance = Instance(raw_words="This is the first instance", | |||
words=['this', 'is', 'the', 'first', 'instance', '.'], | |||
seq_len=6) | |||
dataset.append(instance) | |||
@@ -52,10 +63,10 @@ | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
dataset = DataSet([ | |||
Instance(sentence="This is the first instance", | |||
Instance(raw_words="This is the first instance", | |||
words=['this', 'is', 'the', 'first', 'instance', '.'], | |||
seq_len=6), | |||
Instance(sentence="Second instance .", | |||
Instance(raw_words="Second instance .", | |||
words=['Second', 'instance', '.'], | |||
seq_len=3) | |||
]) | |||
@@ -106,24 +117,49 @@ FastNLP 同样提供了多种删除数据的方法 :func:`~fastNLP.DataSet.drop` | |||
.. code-block:: python | |||
from fastNLP import DataSet | |||
data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."]} | |||
data = {'raw_words':["This is the first instance .", "Second instance .", "Third instance ."]} | |||
dataset = DataSet(data) | |||
# 将句子分成单词形式, 详见DataSet.apply()方法 | |||
dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words') | |||
dataset.apply(lambda ins: ins['raw_words'].split(), new_field_name='words') | |||
# 或使用DataSet.apply_field() | |||
dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words') | |||
dataset.apply_field(lambda sent:sent.split(), field_name='raw_words', new_field_name='words') | |||
# 除了匿名函数,也可以定义函数传递进去 | |||
def get_words(instance): | |||
sentence = instance['sentence'] | |||
sentence = instance['raw_words'] | |||
words = sentence.split() | |||
return words | |||
dataset.apply(get_words, new_field_name='words') | |||
除了手动处理数据集之外,你还可以使用 fastNLP 提供的各种 :class:`~fastNLP.io.base_loader.DataSetLoader` 来进行数据处理。 | |||
详细请参考这篇教程 :doc:`使用DataSetLoader加载数据集 </tutorials/tutorial_2_load_dataset>` 。 | |||
除了手动处理数据集之外,你还可以使用 fastNLP 提供的各种 :class:`~fastNLP.io.Loader`和:class:`~fastNLP.io.Pipe` 来进行数据处理。 | |||
详细请参考这篇教程 :doc:`使用Loader和Pipe处理数据 </tutorials/tutorial_2_load_dataset>` 。 | |||
----------------------------- | |||
fastNLP中field的命名习惯 | |||
----------------------------- | |||
在英文任务中,fastNLP常用的field名称有: | |||
- raw_words: 表示的是原始的str。例如"This is a demo sentence ."。存在多个raw_words的情况,例如matching任务,它们会被定义为 | |||
raw_words0, raw_words1。但在conll格式下,raw_words列也可能为["This", "is", "a", "demo", "sentence", "."]的形式。 | |||
- words: 表示的是已经tokenize后的词语。例如["This", "is", "a", "demo", "sentence"], 但由于str并不能直接被神经网络所使用, | |||
所以words中的内容往往被转换为int,如[3, 10, 4, 2, 7, ...]等。多列words的情况,会被命名为words0, words1 | |||
- target: 表示目标值。分类场景下,只有一个值;序列标注场景下是一个序列。 | |||
- seq_len: 一般用于表示words列的长度 | |||
在中文任务中,fastNLP常用的field名称有: | |||
- raw_chars: 表示的是原始的连续汉字序列。例如"这是一个示例。" | |||
- chars: 表示已经切分为单独的汉字的序列。例如["这", "是", "一", "个", "示", "例", "。"]。但由于神经网络不能识别汉字,所以一般 | |||
该列会被转为int形式,如[3, 4, 5, 6, ...]。 | |||
- raw_words: 如果原始汉字序列中已经包含了词语的边界,则该列称为raw_words。如"上海 浦东 开发 与 法制 建设 同步"。 | |||
- words: 表示单独的汉字词语序列。例如["上海", "", "浦东", "开发", "与", "法制", "建设", ...]或[2, 3, 4, ...] | |||
- target: 表示目标值。分类场景下,只有一个值;序列标注场景下是一个序列。 | |||
- seq_len: 表示输入序列的长度 | |||
# TODO 这一段移动到datasetiter那里 | |||
----------------------------- | |||
DataSet与pad | |||
@@ -1,150 +0,0 @@ | |||
======================================= | |||
使用Loader和Pipe加载并处理数据集 | |||
======================================= | |||
这一部分是一个关于如何加载数据集的教程 | |||
教程目录: | |||
- `Part I: 数据集容器DataBundle`_ | |||
- `Part II: 加载数据集的基类Loader`_ | |||
- `Part III: 不同格式类型的基础Loader`_ | |||
- `Part IV: 使用Pipe对数据集进行预处理`_ | |||
- `Part V: fastNLP封装好的Loader和Pipe`_ | |||
------------------------------------ | |||
Part I: 数据集容器DataBundle | |||
------------------------------------ | |||
在fastNLP中,我们使用 :class:`~fastNLP.io.data_bundle.DataBundle` 来存储数据集信息。 | |||
:class:`~fastNLP.io.data_bundle.DataBundle` 类包含了两个重要内容: `datasets` 和 `vocabs` 。 | |||
`datasets` 是一个 `key` 为数据集名称(如 `train` , `dev` ,和 `test` 等), `value` 为 :class:`~fastNLP.DataSet` 的字典。 | |||
`vocabs` 是一个 `key` 为词表名称(如 :attr:`fastNLP.Const.INPUT` 表示输入文本的词表名称, :attr:`fastNLP.Const.TARGET` 表示目标 | |||
的真实标签词表的名称,等等), `value` 为词表内容( :class:`~fastNLP.Vocabulary` )的字典。 | |||
------------------------------------- | |||
Part II: 加载数据集的基类Loader | |||
------------------------------------- | |||
在fastNLP中,我们采用 :class:`~fastNLP.io.loader.Loader` 来作为加载数据集的基类。 | |||
:class:`~fastNLP.io.loader.Loader` 定义了各种Loader所需的API接口,开发者应该继承它实现各种的Loader。 | |||
在各种数据集的Loader当中,至少应该编写如下内容: | |||
- _load 函数:从一个数据文件中读取数据,返回一个 :class:`~fastNLP.DataSet` | |||
- load 函数:从文件或者文件夹中读取数据并组装成 :class:`~fastNLP.io.data_bundle.DataBundle` | |||
Loader的load函数返回的 :class:`~fastNLP.io.data_bundle.DataBundle` 里面包含了数据集的原始数据。 | |||
-------------------------------------------------------- | |||
Part III: 不同格式类型的基础Loader | |||
-------------------------------------------------------- | |||
:class:`~fastNLP.io.loader.CSVLoader` | |||
读取CSV类型的数据集文件。例子如下: | |||
.. code-block:: python | |||
from fastNLP.io.loader import CSVLoader | |||
data_set_loader = CSVLoader( | |||
headers=('words', 'target'), sep='\t' | |||
) | |||
# 表示将CSV文件中每一行的第一项填入'words' field,第二项填入'target' field。 | |||
# 其中每两项之间由'\t'分割开来 | |||
data_set = data_set_loader._load('path/to/your/file') | |||
数据集内容样例如下 :: | |||
But it does not leave you with much . 1 | |||
You could hate it for the same reason . 1 | |||
The performances are an absolute joy . 4 | |||
:class:`~fastNLP.io.loader.JsonLoader` | |||
读取Json类型的数据集文件,数据必须按行存储,每行是一个包含各类属性的Json对象。例子如下: | |||
.. code-block:: python | |||
from fastNLP.io.loader import JsonLoader | |||
oader = JsonLoader( | |||
fields={'sentence1': 'words1', 'sentence2': 'words2', 'gold_label': 'target'} | |||
) | |||
# 表示将Json对象中'sentence1'、'sentence2'和'gold_label'对应的值赋给'words1'、'words2'、'target'这三个fields | |||
data_set = loader._load('path/to/your/file') | |||
数据集内容样例如下 :: | |||
{"annotator_labels": ["neutral"], "captionID": "3416050480.jpg#4", "gold_label": "neutral", "pairID": "3416050480.jpg#4r1n", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is training his horse for a competition.", "sentence2_binary_parse": "( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (VP (VBG training) (NP (PRP$ his) (NN horse)) (PP (IN for) (NP (DT a) (NN competition))))) (. .)))"} | |||
{"annotator_labels": ["contradiction"], "captionID": "3416050480.jpg#4", "gold_label": "contradiction", "pairID": "3416050480.jpg#4r1c", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is at a diner, ordering an omelette.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is ( at ( a diner ) ) ) , ) ( ordering ( an omelette ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (PP (IN at) (NP (DT a) (NN diner))) (, ,) (S (VP (VBG ordering) (NP (DT an) (NN omelette))))) (. .)))"} | |||
{"annotator_labels": ["entailment"], "captionID": "3416050480.jpg#4", "gold_label": "entailment", "pairID": "3416050480.jpg#4r1e", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is outdoors, on a horse.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is outdoors ) , ) ( on ( a horse ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (ADVP (RB outdoors)) (, ,) (PP (IN on) (NP (DT a) (NN horse)))) (. .)))"} | |||
------------------------------------------ | |||
Part IV: 使用Pipe对数据集进行预处理 | |||
------------------------------------------ | |||
在fastNLP中,我们采用 :class:`~fastNLP.io.pipe.Pipe` 来作为加载数据集的基类。 | |||
:class:`~fastNLP.io.pipe.Pipe` 定义了各种Pipe所需的API接口,开发者应该继承它实现各种的Pipe。 | |||
在各种数据集的Pipe当中,至少应该编写如下内容: | |||
- process 函数:对输入的 :class:`~fastNLP.io.data_bundle.DataBundle` 进行处理(如构建词表、 | |||
将dataset的文本内容转成index等等),然后返回该 :class:`~fastNLP.io.data_bundle.DataBundle` | |||
- process_from_file 函数:输入数据集所在文件夹,读取内容并组装成 :class:`~fastNLP.io.data_bundle.DataBundle` , | |||
然后调用相对应的process函数对数据进行预处理 | |||
以SNLI数据集为例,写一个自定义Pipe的例子如下: | |||
.. code-block:: python | |||
from fastNLP.io.loader import SNLILoader | |||
from fastNLP.io.pipe import MatchingPipe | |||
class MySNLIPipe(MatchingPipe): | |||
def process(self, data_bundle): | |||
data_bundle = super(MySNLIPipe, self).process(data_bundle) | |||
# MatchingPipe类里封装了一个关于matching任务的process函数,可以直接继承使用 | |||
# 如果有需要进行额外的预处理操作可以在这里加入您的代码 | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
data_bundle = SNLILoader().load(paths) # 使用SNLILoader读取原始数据集 | |||
# SNLILoader的load函数中,paths如果为None则会自动下载 | |||
return self.process(data_bundle) # 调用相对应的process函数对data_bundle进行处理 | |||
调用Pipe示例: | |||
.. code-block:: python | |||
from fastNLP.io.pipe import SNLIBertPipe | |||
data_bundle = SNLIBertPipe(lower=True, tokenizer=arg.tokenizer).process_from_file() | |||
print(data_bundle) | |||
输出的内容是:: | |||
In total 3 datasets: | |||
train has 549367 instances. | |||
dev has 9842 instances. | |||
test has 9824 instances. | |||
In total 2 vocabs: | |||
words has 34184 entries. | |||
target has 3 entries. | |||
这里表示一共有3个数据集和2个词表。其中: | |||
- 3个数据集分别为train、dev、test数据集,分别有549367、9842、9824个instance | |||
- 2个词表分别为words词表与target词表。其中words词表为句子文本所构建的词表,一共有34184个单词; | |||
target词表为目标标签所构建的词表,一共有3种标签。(注:如果有多个输入,则句子文本所构建的词表将 | |||
会被命名为words1以对应相对应的列名) | |||
------------------------------------------ | |||
Part V: fastNLP封装好的Loader和Pipe | |||
------------------------------------------ | |||
fastNLP封装了多种任务/数据集的Loader和Pipe并提供自动下载功能,具体参见文档 | |||
`fastNLP可加载的embedding与数据集 <https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0>`_ | |||
@@ -0,0 +1,131 @@ | |||
============================== | |||
Vocabulary | |||
============================== | |||
:class:`~fastNLP.Vocabulary`是包含字或词与index关系的类,用于将文本转换为index。 | |||
----------------------------- | |||
构建Vocabulary | |||
----------------------------- | |||
.. code-block:: python | |||
from fastNLP import Vocabulary | |||
vocab = Vocabulary() | |||
vocab.add_word_lst(['复', '旦', '大', '学']) # 加入新的字 | |||
vocab.add_word('上海') # `上海`会作为一个整体 | |||
vocab.to_index('复') # 应该会为3 | |||
vocab.to_index('我') # 会输出1,Vocabulary中默认pad的index为0, unk(没有找到的词)的index为1 | |||
# 在构建target的Vocabulary时,词表中应该用不上pad和unk,可以通过以下的初始化 | |||
vocab = Vocabulary(unknown=None, pad=None) | |||
vocab.add_word_lst(['positive', 'negative']) | |||
vocab.to_index('positive') # 输出0 | |||
vocab.to_index('neutral') # 会报错 | |||
除了通过以上的方式建立词表,Vocabulary还可以通过使用下面的函数直从 :class:`~fastNLP.DataSet` 中的某一列建立词表以及将该列转换为index | |||
.. code-block:: python | |||
from fastNLP import Vocabulary | |||
from fastNLP import DataSet | |||
dataset = DataSet({'chars': [ | |||
['今', '天', '天', '气', '很', '好', '。'], | |||
['被', '这', '部', '电', '影', '浪', '费', '了', '两', '个', '小', '时', '。'] | |||
], | |||
'target': ['neutral', 'negative'] | |||
}) | |||
vocab = Vocabulary() | |||
vocab.from_dataset(dataset, field_name='chars') | |||
vocab.index_dataset(dataset, field_name='chars') | |||
target_vocab = Vocabulary(padding=None, unknown=None) | |||
target_vocab.from_dataset(dataset, field_name='target') | |||
target_vocab.index_dataset(dataset, field_name='target') | |||
print(dataset) | |||
输出内容为:: | |||
+---------------------------------------------------+--------+ | |||
| chars | target | | |||
+---------------------------------------------------+--------+ | |||
| [4, 2, 2, 5, 6, 7, 3] | 0 | | |||
| [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 3] | 1 | | |||
+---------------------------------------------------+--------+ | |||
----------------------------- | |||
一些使用tips | |||
----------------------------- | |||
在通过使用from_dataset()函数在DataSet上建立词表时,将测试集和验证集放入参数no_create_entry_dataset中,如下所示 | |||
.. code-block:: python | |||
from fastNLP import Vocabulary | |||
from fastNLP import DataSet | |||
tr_data = DataSet({'chars': [ | |||
['今', '天', '心', '情', '很', '好', '。'], | |||
['被', '这', '部', '电', '影', '浪', '费', '了', '两', '个', '小', '时', '。'] | |||
], | |||
'target': ['positive', 'negative'] | |||
}) | |||
dev_data = DataSet({'chars': [ | |||
['住', '宿', '条', '件', '还', '不', '错'], | |||
['糟', '糕', '的', '天', '气', ',', '无', '法', '出', '行', '。'] | |||
], | |||
'target': ['positive', 'negative'] | |||
}) | |||
vocab = Vocabulary() | |||
# 将验证集或者测试集在建立词表是放入no_create_entry_dataset这个参数中。 | |||
vocab.from_dataset(tr_data, field_name='chars', no_create_entry_dataset=[dev_data]) | |||
:class:`~fastNLP.Vocabulary` 中的`no_create_entry`, 建议在添加来自于测试集和验证集的词的时候将该参数置为True, 或将验证集和测试集 | |||
传入`no_create_entry_dataset`参数。它们的意义是在接下来的模型会使用pretrain的embedding(包括glove, word2vec, elmo与bert)且会finetune的 | |||
情况下,如果仅使用来自于train的数据建立vocabulary,会导致只出现在test与dev中的词语无法充分利用到来自于预训练embedding的信息(因为他们 | |||
会被认为是unk),所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。通过与fastNLP中的各种Embedding配合使用,会有如下的效果, | |||
如果一个词出现在了train中,但是没在预训练模型中,embedding会为随机初始化,且它单独的一个vector,如果finetune embedding的话, | |||
这个词在更新之后可能会有更好的表示; 而如果这个词仅出现在了dev或test中,那么就不能为它们单独建立vector,而应该让它指向unk这个vector的 | |||
值(当unk的值更新时,这个词也使用的是更新之后的vector)。所以被认为是no_create_entry的token,将首先从预训练的词表中寻找它的表示,如 | |||
果找到了,就使用该表示; 如果没有找到,则认为该词的表示应该为unk的表示。 | |||
下面我们结合部分:code:`~fastNLP.embeddings.StaticEmbedding`的例子来说明下该值造成的影响,如果您对 | |||
:code:`~fastNLP.embeddings.StaticEmbedding`不太了解,您可以先参考\{Embedding教程的引用}部分再来阅读该部分 | |||
.. code-block:: python | |||
import torch | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP import Vocabulary | |||
vocab = Vocabulary() | |||
vocab.add_word('train') | |||
vocab.add_word('only_in_train') # 仅在train出现,但肯定在预训练词表中不存在 | |||
vocab.add_word('test', no_create_entry=True) # 该词只在dev或test中出现 | |||
vocab.add_word('only_in_test', no_create_entry=True) # 这个词肯定在预训练中找不到 | |||
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d') | |||
print(embed(torch.LongTensor([vocab.to_index('train')]))) | |||
print(embed(torch.LongTensor([vocab.to_index('only_in_train')]))) | |||
print(embed(torch.LongTensor([vocab.to_index('test')]))) | |||
print(embed(torch.LongTensor([vocab.to_index('only_in_test')]))) | |||
print(embed(torch.LongTensor([vocab.unknown_idx]))) | |||
输出结果(只截取了部分vector):: | |||
tensor([[ 0.9497, 0.3433, 0.8450, -0.8852, ...]], grad_fn=<EmbeddingBackward>) # train | |||
tensor([[ 0.0540, -0.0557, -0.0514, -0.1688, ...]], grad_fn=<EmbeddingBackward>) # only_in_train | |||
tensor([[ 0.1318, -0.2552, -0.0679, 0.2619, ...]], grad_fn=<EmbeddingBackward>) # test | |||
tensor([[0., 0., 0., 0., 0., ...]], grad_fn=<EmbeddingBackward>) # only_in_test | |||
tensor([[0., 0., 0., 0., 0., ...]], grad_fn=<EmbeddingBackward>) # unk | |||
首先train和test都能够从预训练中找到对应的vector,所以它们是各自的vector表示; only_in_train在预训练中找不到,StaticEmbedding为它 | |||
新建了一个entry,所以它有一个单独的vector; 而only_in_dev在预训练中找不到被指向了unk的值(fastNLP用零向量初始化unk),与最后一行unk的 | |||
表示相同。 |
@@ -7,161 +7,446 @@ | |||
教程目录: | |||
- `Part I: embedding介绍`_ | |||
- `Part II: 使用随机初始化的embedding`_ | |||
- `Part III: 使用预训练的静态embedding`_ | |||
- `Part IV: 使用预训练的Contextual Embedding(ELMo & BERT)`_ | |||
- `Part V: 使用character-level的embedding`_ | |||
- `Part VI: 叠加使用多个embedding`_ | |||
- `Part VII: fastNLP支持的预训练Embedding`_ | |||
- `Part II: 使用预训练的静态embedding`_ | |||
- `Part III: 使用随机初始化的embedding`_ | |||
- `Part IV: ELMo Embedding`_ | |||
- `Part V: Bert Embedding`_ | |||
- `Part VI: 使用character-level的embedding`_ | |||
- `Part VII: 叠加使用多个embedding`_ | |||
- `Part VIII: Embedding的其它说明`_ | |||
- `Part IX: StaticEmbedding的使用建议`_ | |||
--------------------------------------- | |||
Part I: embedding介绍 | |||
--------------------------------------- | |||
与torch.nn.Embedding类似,fastNLP的embedding接受的输入是一个被index好的序列,输出的内容是这个序列的embedding结果。 | |||
fastNLP的embedding包括了预训练embedding和随机初始化embedding。 | |||
Embedding是一种词嵌入技术,可以将字或者词转换为实向量。目前使用较多的预训练词嵌入有word2vec, fasttext, glove, character embedding, | |||
elmo以及bert。 | |||
但使用这些词嵌入方式的时候都需要做一些加载上的处理,比如预训练的word2vec, fasttext以及glove都有着超过几十万个词语的表示,但一般任务大概 | |||
只会用到其中几万个词,如果直接加载所有的词汇,会导致内存占用变大以及运行速度变慢,需要从预训练文件中抽取本次实验的用到的词汇;而对于英文的 | |||
elmo和character embedding, 需要将word拆分成character才能使用;Bert的使用更是涉及到了Byte pair encoding(BPE)相关的内容。为了方便 | |||
大家的使用,fastNLP通过:class:`~fastNLP.Vocabulary`统一了不同embedding的使用。下面我们将讲述一些例子来说明一下 | |||
--------------------------------------- | |||
Part II: 使用随机初始化的embedding | |||
Part II: 使用预训练的静态embedding | |||
--------------------------------------- | |||
使用随机初始化的embedding参见 :class:`~fastNLP.embeddings.embedding.Embedding` 。 | |||
可以传入词表大小和embedding维度: | |||
在fastNLP中,加载预训练的word2vec, glove以及fasttext都使用的是 :class:`~fastNLP.embeddings.StaticEmbedding`。另外,为了方便大家的 | |||
使用,fastNLP提供了多种静态词向量的自动下载并缓存(默认缓存到~/.fastNLP/embeddings文件夹下)的功能,支持自动下载的预训练向量可以在 | |||
`<https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0>`_ | |||
查看。 | |||
.. code-block:: python | |||
from fastNLP import Embedding | |||
embed = Embedding(10000, 50) | |||
import torch | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP import Vocabulary | |||
也可以传入一个初始化的参数矩阵: | |||
vocab = Vocabulary() | |||
vocab.add_word_lst("this is a demo .".split()) | |||
.. code-block:: python | |||
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', requires_grad=True) | |||
words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]]) | |||
print(embed(words).size()) | |||
from fastNLP import Embedding | |||
embed = Embedding(init_embed) | |||
输出为:: | |||
其中的init_embed可以是torch.FloatTensor、torch.nn.Embedding或者numpy.ndarray。 | |||
torch.Size([1, 5, 50]) | |||
fastNLP的StaticEmbedding在初始化之后,就和pytorch中的Embedding是类似的了。:class:`~fastNLP.embeddings.StaticEmbedding`的初始化 | |||
主要是从model_dir_or_name提供的词向量中抽取出:class:`~fastNLP.Vocabulary`中词语的vector。 | |||
除了可以通过使用预先提供的Embedding,:class:`~fastNLP.embeddings.StaticEmbedding`也支持加载本地的预训练词向量,glove, word2vec以及 | |||
fasttext格式的。通过将model_dir_or_name修改为本地的embedding文件路径,即可使用本地的embedding。 | |||
--------------------------------------- | |||
Part III: 使用预训练的静态embedding | |||
Part III: 使用随机初始化的embedding | |||
--------------------------------------- | |||
在使用预训练的embedding之前,需要根据数据集的内容构建一个词表 :class:`~fastNLP.core.vocabulary.Vocabulary` ,在 | |||
预训练embedding类初始化的时候需要将这个词表作为参数传入。 | |||
在fastNLP中,我们提供了 :class:`~fastNLP.embeddings.StaticEmbedding` 这一个类。 | |||
通过 :class:`~fastNLP.embeddings.StaticEmbedding` 可以加载预训练好的静态 | |||
Embedding,例子如下: | |||
有时候需要使用随机初始化的Embedding,也可以通过使用 :class:`~fastNLP.embeddings.StaticEmbedding`获得。只需要将model_dir_or_name | |||
置为None,且传入embedding_dim,如下例所示 | |||
.. code-block:: python | |||
from fastNLP import StaticEmbedding | |||
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True) | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP import Vocabulary | |||
vocab为根据数据集构建的词表,model_dir_or_name可以是一个路径,也可以是embedding模型的名称: | |||
vocab = Vocabulary() | |||
vocab.add_word_lst("this is a demo .".split()) | |||
1 如果传入的是路径,那么fastNLP将会根据该路径来读取预训练的权重文件并将embedding加载进来(glove | |||
和word2vec类型的权重文件都支持) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=30) | |||
2 如果传入的是模型名称,那么fastNLP将会根据名称查找embedding模型,如果在cache目录下找到模型则会 | |||
自动加载;如果找不到则会自动下载到cache目录。默认的cache目录为 `~/.fastNLP` 文件夹。可以通过环境 | |||
变量 ``FASTNLP_CACHE_DIR`` 来自定义cache目录,如:: | |||
words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]]) | |||
print(embed(words).size()) | |||
$ FASTNLP_CACHE_DIR=~/fastnlp_cache_dir python your_python_file.py | |||
输出为:: | |||
torch.Size([1, 5, 30]) | |||
这个命令表示fastNLP将会在 `~/fastnlp_cache_dir` 这个目录下寻找模型,找不到则会自动将模型下载到这个目录 | |||
----------------------------------------------------------- | |||
Part IV: 使用预训练的Contextual Embedding(ELMo & BERT) | |||
Part IV: ELMo Embedding | |||
----------------------------------------------------------- | |||
在fastNLP中,我们提供了ELMo和BERT的embedding: :class:`~fastNLP.embeddings.ElmoEmbedding` | |||
和 :class:`~fastNLP.embeddings.BertEmbedding` 。 | |||
和 :class:`~fastNLP.embeddings.BertEmbedding` 。可自动下载的ElmoEmbedding可以 | |||
从`<https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0>`_找到。 | |||
与静态embedding类似,ELMo的使用方法如下: | |||
.. code-block:: python | |||
from fastNLP import ElmoEmbedding | |||
embed = ElmoEmbedding(vocab, model_dir_or_name='small', requires_grad=False) | |||
from fastNLP.embeddings import ElmoEmbedding | |||
from fastNLP import Vocabulary | |||
vocab = Vocabulary() | |||
vocab.add_word_lst("this is a demo .".split()) | |||
embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=False) | |||
words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]]) | |||
print(embed(words).size()) | |||
输出为:: | |||
torch.Size([1, 5, 256]) | |||
也可以输出多层的ELMo结果,fastNLP将在不同层的结果在最后一维上拼接,下面的代码需要在上面的代码执行结束之后执行 | |||
.. code-block:: python | |||
embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=False, layers='1,2') | |||
print(embed(words).size()) | |||
输出为:: | |||
torch.Size([1, 5, 512]) | |||
另外,根据`<https://arxiv.org/abs/1802.05365>`_,不同层之间使用可学习的权重可以使得ELMo的效果更好,在fastNLP中可以通过以下的初始化 | |||
实现3层输出的结果通过可学习的权重进行加法融合。 | |||
.. code-block:: python | |||
embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=True, layers='mix') | |||
print(embed(words).size()) | |||
输出为:: | |||
torch.Size([1, 5, 256]) | |||
----------------------------------------------------------- | |||
Part V: Bert Embedding | |||
----------------------------------------------------------- | |||
BERT-embedding的使用方法如下: | |||
虽然Bert并不算严格意义上的Embedding,但通过将Bert封装成Embedding的形式将极大减轻使用的复杂程度。可自动下载的Bert Embedding可以 | |||
从`<https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0>`_找到。我们将使用下面的例子讲述一下 | |||
BertEmbedding的使用 | |||
.. code-block:: python | |||
from fastNLP import BertEmbedding | |||
embed = BertEmbedding( | |||
vocab, model_dir_or_name='en-base-cased', requires_grad=False, layers='4,-2,-1' | |||
) | |||
from fastNLP.embeddings import BertEmbedding | |||
from fastNLP import Vocabulary | |||
vocab = Vocabulary() | |||
vocab.add_word_lst("this is a demo .".split()) | |||
embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased') | |||
words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]]) | |||
print(embed(words).size()) | |||
输出为:: | |||
torch.Size([1, 5, 768]) | |||
可以通过申明使用指定层数的output也可以使用多层的output,下面的代码需要在上面的代码执行结束之后执行 | |||
.. code-block:: python | |||
# 使用后面两层的输出 | |||
embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='10,11') | |||
print(embed(words).size()) # 结果将是在最后一维做拼接 | |||
输出为:: | |||
torch.Size([1, 5, 1536]) | |||
在Bert中还存在两个特殊的字符[CLS]和[SEP],默认情况下这两个字符是自动加入并且在计算结束之后会自动删除,以使得输入的序列长度和输出的序列 | |||
长度是一致的,但是有些分类的情况,必须需要使用[CLS]的表示,这种情况可以通过在初始化时申明一下需要保留[CLS]的表示,如下例所示 | |||
.. code-block:: python | |||
embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', include_cls_sep=True) | |||
print(embed(words).size()) # 结果将在序列维度上增加2 | |||
# 取出句子的cls表示 | |||
cls_reps = embed(words)[:, 0] # shape: [batch_size, 768] | |||
输出为:: | |||
torch.Size([1, 7, 768]) | |||
在英文Bert模型中,一个英文单词可能会被切分为多个subword,例如"fairness"会被拆分为["fair", "##ness"],这样一个word对应的将有两个输出, | |||
:class:`~fastNLP.embeddings.BertEmbedding`会使用pooling方法将一个word的subword的表示合并成一个vector,通过pool_method可以控制 | |||
该pooling方法,支持的有"first"(即使用fair的表示作为fairness的表示), "last"(使用##ness的表示作为fairness的表示), "max"(对fair和 | |||
##ness在每一维上做max),"avg"(对fair和##ness每一维做average)。 | |||
.. code-block:: python | |||
embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', pool_method='max') | |||
print(embed(words).size()) | |||
输出为:: | |||
torch.Size([1, 5, 768]) | |||
另外,根据`<https://arxiv.org/abs/1810.04805>`_ ,Bert的还存在一种用法,句子之间通过[SEP]拼接起来,前一句话的token embedding为0, | |||
后一句话的token embedding为1。BertEmbedding能够自动识别句子中间的[SEP]来正确设置对应的token_type_id的。 | |||
.. code-block:: python | |||
vocab = Vocabulary() | |||
vocab.add_word_lst("this is a demo . [SEP] another sentence .".split()) | |||
embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', pool_method='max') | |||
words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo . [SEP] another sentence .".split()]]) | |||
print(embed(words).size()) | |||
输出为:: | |||
其中layers变量表示需要取哪几层的encode结果。 | |||
torch.Size([1, 9, 768]) | |||
在多个[SEP]的情况下,将会使token_type_id不断0,1循环。比如"first sentence [SEP] second sentence [SEP] third sentence", 它们的 | |||
token_type_id将是[0, 0, 0, 1, 1, 1, 0, 0]。但请注意[SEP]一定要大写的,不能是[sep],否则无法识别。 | |||
更多:class:`~fastNLP.embedding.BertEmbedding`的使用,请参考\ref{找人写一篇BertEmbedding的使用教程} | |||
----------------------------------------------------- | |||
Part V: 使用character-level的embedding | |||
Part VI: 使用character-level的embedding | |||
----------------------------------------------------- | |||
除了预训练的embedding以外,fastNLP还提供了CharEmbedding: :class:`~fastNLP.embeddings.CNNCharEmbedding` 和 | |||
:class:`~fastNLP.embeddings.LSTMCharEmbedding` 。 | |||
除了预训练的embedding以外,fastNLP还提供了两种Character Embedding: :class:`~fastNLP.embeddings.CNNCharEmbedding` 和 | |||
:class:`~fastNLP.embeddings.LSTMCharEmbedding` 。一般在使用character embedding时,需要在预处理的时候将word拆分成character,这 | |||
会使得预处理过程变得非常繁琐。在fastNLP中,使用character embedding也只需要传入:class:`~fastNLP.Vocabulary`即可,而且该 | |||
Vocabulary与其它Embedding使用的Vocabulary是一致的,如下面的例子所示 | |||
CNNCharEmbedding的使用例子如下: | |||
.. code-block:: python | |||
from fastNLP import CNNCharEmbedding | |||
embed = CNNCharEmbedding(vocab, embed_size=100, char_emb_size=50) | |||
from fastNLP.embeddings import CNNCharEmbedding | |||
from fastNLP import Vocabulary | |||
vocab = Vocabulary() | |||
vocab.add_word_lst("this is a demo .".split()) | |||
# character的embedding维度大小为50,返回的embedding结果维度大小为64。 | |||
embed = CNNCharEmbedding(vocab, embed_size=64, char_emb_size=50) | |||
words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]]) | |||
print(embed(words).size()) | |||
这表示这个CNNCharEmbedding当中character的embedding维度大小为50,返回的embedding结果维度大小为100。 | |||
输出为:: | |||
torch.Size([1, 5, 64]) | |||
与CNNCharEmbedding类似,LSTMCharEmbedding的使用例子如下: | |||
.. code-block:: python | |||
from fastNLP import LSTMCharEmbedding | |||
embed = LSTMCharEmbedding(vocab, embed_size=100, char_emb_size=50) | |||
from fastNLP.embeddings import LSTMCharEmbeddding | |||
from fastNLP import Vocabulary | |||
vocab = Vocabulary() | |||
vocab.add_word_lst("this is a demo .".split()) | |||
这表示这个LSTMCharEmbedding当中character的embedding维度大小为50,返回的embedding结果维度大小为100。 | |||
# character的embedding维度大小为50,返回的embedding结果维度大小为64。 | |||
embed = LSTMCharEmbeddding(vocab, embed_size=64, char_emb_size=50) | |||
words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]]) | |||
print(embed(words).size()) | |||
输出为:: | |||
torch.Size([1, 5, 64]) | |||
----------------------------------------------------- | |||
Part VI: 叠加使用多个embedding | |||
Part VII: 叠加使用多个embedding | |||
----------------------------------------------------- | |||
在fastNLP中,我们使用 :class:`~fastNLP.embeddings.StackEmbedding` 来叠加多个embedding | |||
单独使用Character Embedding往往效果并不是很好,需要同时结合word embedding。在fastNLP中可以通过:class:`~fastNLP.embeddings.StackEmbedding` | |||
来叠加embedding,具体的例子如下所示 | |||
.. code-block:: python | |||
from fastNLP.embeddings import StaticEmbedding, StackEmbedding, CNNCharEmbedding | |||
from fastNLP import Vocabulary | |||
vocab = Vocabulary() | |||
vocab.add_word_lst("this is a demo .".split()) | |||
word_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d') | |||
char_embed = CNNCharEmbedding(vocab, embed_size=64, char_emb_size=50) | |||
embed = StackEmbedding([word_embed, char_embed]) | |||
words = torch.LongTensor([[vocab.to_index(word) for word in "this is a demo .".split()]]) | |||
print(embed(words).size()) # 输出embedding的维度为50+64=114 | |||
输出为:: | |||
torch.Size([1, 5, 114]) | |||
例子如下: | |||
:class:`~fastNLP.embeddings.StaticEmbedding`, :class:`~fastNLP.embeddings.ElmoEmbedding`, | |||
:class:`~fastNLP.embeddings.CNNCharEmbedding`, :class:`~fastNLP.embeddings.BertEmbedding`等都可以互相拼接。 | |||
:class:`~fastNLP.embeddings.StackEmbedding`的使用也是和其它Embedding是一致的,即输出index返回对应的表示。但能够拼接起来的Embedding | |||
必须使用同样的:class:`~fastNLP.Vocabulary`,因为只有使用同样的:class:`~fastNLP.Vocabulary`才能保证同一个index指向的是同一个词或字 | |||
----------------------------------------------------------- | |||
Part VIII: Embedding的其它说明 | |||
----------------------------------------------------------- | |||
(1) 获取各种Embedding的dimension | |||
.. code-block:: python | |||
from fastNLP import StaticEmbedding, StackEmbedding | |||
embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True) | |||
embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True) | |||
from fastNLP.embeddings import * | |||
stack_embed = StackEmbedding([embed_1, embed_2]) | |||
vocab = Vocabulary() | |||
vocab.add_word_lst("this is a demo .".split()) | |||
StackEmbedding会把多个embedding的结果拼接起来,如上面例子的stack_embed返回的embedding维度为350维。 | |||
static_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d') | |||
print(static_embed.embedding_dim) # 50 | |||
char_embed = CNNCharEmbedding(vocab, embed_size=30) | |||
print(char_embed.embedding_dim) # 30 | |||
elmo_embed_1 = ElmoEmbedding(vocab, model_dir_or_name='en-small', layers='2') | |||
print(elmo_embed_1.embedding_dim) # 256 | |||
elmo_embed_2 = ElmoEmbedding(vocab, model_dir_or_name='en-small', layers='1,2') | |||
print(elmo_embed_2.embedding_dim) # 512 | |||
bert_embed_1 = BertEmbedding(vocab, layers='-1', model_dir_or_name='en-base-cased') | |||
print(bert_embed_1.embedding_dim) # 768 | |||
bert_embed_2 = BertEmbedding(vocab, layers='2,-1', model_dir_or_name='en-base-cased') | |||
print(bert_embed_2.embedding_dim) # 1536 | |||
stack_embed = StackEmbedding([static_embed, char_embed]) | |||
print(stack_embed.embedding_dim) # 80 | |||
除此以外,还可以把静态embedding跟上下文相关的embedding拼接起来: | |||
(2) 设置Embedding的权重是否更新 | |||
.. code-block:: python | |||
from fastNLP import StaticEmbedding, StackEmbedding, ElmoEmbedding | |||
elmo_embedding = ElmoEmbedding(vocab, model_dir_or_name='medium', layers='0,1,2', requires_grad=False) | |||
glove_embedding = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True) | |||
from fastNLP.embeddings import * | |||
vocab = Vocabulary() | |||
vocab.add_word_lst("this is a demo .".split()) | |||
embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased') | |||
embed.requires_grad = False # BertEmbedding不更新 | |||
(3) 各种Embedding中word_dropout与dropout的说明 | |||
fastNLP中所有的Embedding都支持传入word_dropout和dropout参数,word_dropout指示的是以多大概率将输入的word置为unk的index,这样既可以 | |||
是的unk得到训练,也可以有一定的regularize效果; dropout参数是在获取到word的表示之后,以多大概率将一些维度的表示置为0。 | |||
如果使用:class:`~fastNLP.embeddings.StackEmbedding`且需要用到word_dropout,建议将word_dropout设置在:class:`~fastNLP.embeddings.StackEmbedding`。 | |||
----------------------------------------------------------- | |||
Part IX: StaticEmbedding的使用建议 | |||
----------------------------------------------------------- | |||
在英文的命名实体识别(NER)任务中,由`<http://xxx.itp.ac.cn/pdf/1511.08308.pdf>`_ 指出,同时使用cnn character embedding和word embedding | |||
会使得NER的效果有比较大的提升。正如你在\ref{引用第七节}看到的那样,fastNLP支持将:class:`~fastNLP.embeddings.CNNCharacterEmbedding` | |||
与:class:`~fastNLP.embeddings.StaticEmbedding`拼成一个:class:`~fastNLP.embeddings.StackEmbedding`。如果通过这种方式使用,需要 | |||
在预处理文本时,不要将词汇小写化(因为Character Embedding需要利用词语中的大小写信息)且不要将出现频次低于某个阈值的word设置为unk(因为 | |||
Character embedding需要利用字形信息);但:class:`~fastNLP.embeddings.StaticEmbedding`使用的某些预训练词嵌入的词汇表中只有小写的词 | |||
语, 且某些低频词并未在预训练中出现需要被剔除。即(1) character embedding需要保留大小写,而某些static embedding不需要保留大小写。(2) | |||
character embedding需要保留所有的字形, 而static embedding需要设置一个最低阈值以学到更好的表示。 | |||
(1) fastNLP如何解决关于大小写的问题 | |||
fastNLP通过在:class:`~fastNLP.embeddings.StaticEmbedding`增加了一个lower参数解决该问题。如下面的例子所示 | |||
.. code-block:: python | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP import Vocabulary | |||
vocab = Vocabulary().add_word_lst("The the a A".split()) | |||
# 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的 | |||
embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5) | |||
print(embed(torch.LongTensor([vocab.to_index('The')]))) | |||
print(embed(torch.LongTensor([vocab.to_index('the')]))) | |||
输出为:: | |||
tensor([[-0.4685, 0.4572, 0.5159, -0.2618, -0.6871]], grad_fn=<EmbeddingBackward>) | |||
tensor([[ 0.2615, 0.1490, -0.2491, 0.4009, -0.3842]], grad_fn=<EmbeddingBackward>) | |||
可以看到"The"与"the"的vector是不一致的。但如果我们在初始化:class:`~fastNLP.embeddings.StaticEmbedding`将lower设置为True,效果将 | |||
如下所示 | |||
.. code-block:: python | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP import Vocabulary | |||
vocab = Vocabulary().add_word_lst("The the a A".split()) | |||
# 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的 | |||
embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, lower=True) | |||
print(embed(torch.LongTensor([vocab.to_index('The')]))) | |||
print(embed(torch.LongTensor([vocab.to_index('the')]))) | |||
输出为:: | |||
tensor([[-0.2237, 0.6825, -0.3459, -0.1795, 0.7516]], grad_fn=<EmbeddingBackward>) | |||
tensor([[-0.2237, 0.6825, -0.3459, -0.1795, 0.7516]], grad_fn=<EmbeddingBackward>) | |||
可以看到"The"与"the"的vector是一致的。他们实际上也是引用的同一个vector。通过将lower设置为True,可以在:class:`~fastNLP.embeddings.StaticEmbedding` | |||
实现类似具备相同小写结果的词语引用同一个vector。 | |||
(2) fastNLP如何解决min_freq的问题 | |||
fastNLP通过在:class:`~fastNLP.embeddings.StaticEmbedding`增加了一个min_freq参数解决该问题。如下面的例子所示 | |||
.. code-block:: python | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP import Vocabulary | |||
vocab = Vocabulary().add_word_lst("the the the a".split()) | |||
# 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的 | |||
embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, min_freq=2) | |||
print(embed(torch.LongTensor([vocab.to_index('the')]))) | |||
print(embed(torch.LongTensor([vocab.to_index('a')]))) | |||
print(embed(torch.LongTensor([vocab.unknown_idx]))) | |||
输出为:: | |||
tensor([[ 0.0454, 0.3375, 0.6758, -0.2026, -0.4715]], grad_fn=<EmbeddingBackward>) | |||
tensor([[-0.7602, 0.0149, 0.2733, 0.3974, 0.7371]], grad_fn=<EmbeddingBackward>) | |||
tensor([[-0.7602, 0.0149, 0.2733, 0.3974, 0.7371]], grad_fn=<EmbeddingBackward>) | |||
其中最后一行为unknown值的vector,可以看到a的vector表示与unknown是一样的,这是由于a的频次低于了2,所以被指向了unknown的表示;而the由于 | |||
词频超过了2次,所以它是单独的表示。 | |||
在计算min_freq时,也会考虑到lower的作用,比如 | |||
.. code-block:: python | |||
stack_embed = StackEmbedding([elmo_embedding, glove_embedding]) | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP import Vocabulary | |||
------------------------------------------ | |||
Part VII: fastNLP支持的预训练Embedding | |||
------------------------------------------ | |||
vocab = Vocabulary().add_word_lst("the the the a A".split()) | |||
# 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的 | |||
embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, min_freq=2, lower=True) | |||
print(embed(torch.LongTensor([vocab.to_index('the')]))) | |||
print(embed(torch.LongTensor([vocab.to_index('a')]))) | |||
print(embed(torch.LongTensor([vocab.to_index('A')]))) | |||
print(embed(torch.LongTensor([vocab.unknown_idx]))) | |||
fastNLP支持多种预训练Embedding并提供自动下载功能,具体参见文档 | |||
输出为:: | |||
`fastNLP可加载的embedding与数据集 <https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0>`_ | |||
tensor([[-0.7453, -0.5542, 0.5039, 0.6195, -0.4723]], grad_fn=<EmbeddingBackward>) # the | |||
tensor([[ 0.0170, -0.0995, -0.5743, -0.2469, -0.2095]], grad_fn=<EmbeddingBackward>) # a | |||
tensor([[ 0.0170, -0.0995, -0.5743, -0.2469, -0.2095]], grad_fn=<EmbeddingBackward>) # A | |||
tensor([[ 0.6707, -0.5786, -0.6967, 0.0111, 0.1209]], grad_fn=<EmbeddingBackward>) # unk | |||
可以看到a不再和最后一行的unknown共享一个表示了,这是由于a与A都算入了a的词频,且A的表示也是a的表示。 |
@@ -0,0 +1,219 @@ | |||
======================================= | |||
使用Loader和Pipe加载并处理数据集 | |||
======================================= | |||
这一部分是一个关于如何加载数据集的教程 | |||
教程目录: | |||
- `Part I: 数据集容器DataBundle`_ | |||
- `Part II: 加载的各种数据集的Loader`_ | |||
- `Part III: 使用Pipe对数据集进行预处理`_ | |||
- `Part IV: fastNLP封装好的Loader和Pipe`_ | |||
- `Part V: 不同格式类型的基础Loader`_ | |||
------------------------------------ | |||
Part I: 数据集容器DataBundle | |||
------------------------------------ | |||
而由于对于同一个任务,训练集,验证集和测试集会共用同一个词表以及具有相同的目标值,所以在fastNLP中我们使用了 :class:`~fastNLP.io.DataBundle` | |||
来承载同一个任务的多个数据集 :class:`~fastNLP.DataSet` 以及它们的词表 :class:`~fastNLP.Vocabulary`。下面会有例子介绍:class:`~fastNLP.io.DataBundle` | |||
的相关使用。 | |||
:class: `~fastNLP.io.DataBundle` 在fastNLP中主要在各个 :class: `~fastNLP.io.Loader` 和 :class: `~fastNLP.io.Pipe` 中被使用。 | |||
下面我们将先介绍一下 :class: `~fastNLP.io.Loader` 和 :class: `~fastNLP.io.Pipe`, 之后我们将给出相应的例子。 | |||
------------------------------------- | |||
Part II: 加载的各种数据集的Loader | |||
------------------------------------- | |||
在fastNLP中,所有的数据Loader都可以通过其文档判断其支持读取的数据格式,以及读取之后返回的 :class:`~fastNLP.DataSet` 的格式。例如 | |||
\ref 加个引用。 | |||
- download 函数:自动将该数据集下载到缓存地址,默认缓存地址为~/.fastNLP/datasets/。由于版权等原因,不是所有的Loader都实现了该方法。 | |||
该方法会返回下载后文件所处的缓存地址。可以查看对应Loader的download的方法的文档来判断该Loader加载的数据。 | |||
- _load 函数:从一个数据文件中读取数据,返回一个 :class:`~fastNLP.DataSet`。返回的DataSet的格式可从Loader文档判断。 | |||
- load 函数:从文件或者文件夹中读取数据并组装成 :class:`~fastNLP.io.DataBundle`。支持接受的参数类型有以下的几种 | |||
- None, 将尝试读取自动缓存的数据,仅支持提供了自动下载数据的Loader | |||
- 文件夹路径, 默认将尝试在该路径下匹配文件名中含有`train`, `test`, `dev`的文件,如果有多个文件含有这相同的关键字,将无法通过 | |||
该方式读取 | |||
- dict, 例如{'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | |||
.. code-block:: python | |||
from fastNLP.io import CWSLoader | |||
loader = CWSLoader(dataset_name='pku') | |||
data_bundle = loader.load() | |||
print(data_bundle) | |||
输出内容为:: | |||
In total 3 datasets: | |||
dev has 1831 instances. | |||
train has 17223 instances. | |||
test has 1944 instances. | |||
这里表示一共有3个数据集。其中: | |||
- 3个数据集分别为train、dev、test数据集,分别有17223、1831、1944个instance | |||
也可以取出DataSet并DataSet中的具体内容 | |||
.. code-block:: python | |||
tr_data = data_bundle.get_dataset('train') | |||
print(tr_data[:2]) | |||
输出为:: | |||
+--------------------------------------------------------------------------------------+ | |||
| raw_words | | |||
+--------------------------------------------------------------------------------------+ | |||
| 迈向 充满 希望 的 新 世纪 —— 一九九八年 新年 讲话 ( 附 图片 1 张 ) | | |||
| 中共中央 总书记 、 国家 主席 江 泽民 | | |||
+--------------------------------------------------------------------------------------+ | |||
------------------------------------------ | |||
Part III: 使用Pipe对数据集进行预处理 | |||
------------------------------------------ | |||
通过:class:`~fastNLP.io.Loader` 可以将文本数据读入,但并不能直接被神经网络使用,还需要进行一定的预处理。 | |||
在fastNLP中,我们使用 :class:`~fastNLP.io.Pipe`的子类作为数据预处理的类,Pipe和Loader一般具备一一对应的关系,该关系可以从其名称判断, | |||
例如:class:`~fastNLP.io.CWSLoader`与:class:`~fastNLP.io.CWSPipe`是一一对应的。一般情况下Pipe处理包含以下的几个过程,(1)将raw_words或 | |||
raw_chars进行tokenize以切分成不同的词或字; (2) 再建立词或字的 :class:`~fastNLP.Vocabulary`, 并将词或字转换为index; (3)将target | |||
列建立词表并将target列转为index; | |||
所有的Pipe都可通过其文档查看通过该Pipe之后DataSet中的field的情况; 如 \ref{TODO 添加对例子的引用} | |||
各种数据集的Pipe当中,都包含了以下的两个函数: | |||
- process 函数:对输入的 :class:`~fastNLP.io.DataBundle` 进行处理, 然后返回处理之后的 :class:`~fastNLP.io.DataBundle`。 | |||
process函数的文档中包含了该Pipe支持处理的DataSet的格式。 | |||
- process_from_file 函数:输入数据集所在文件夹,使用对应的Loader读取数据(所以该函数支持的参数类型是由于其对应的Loader的load函数 | |||
决定的),然后调用相对应的process函数对数据进行预处理。相当于是把Load和process放在一个函数中执行。 | |||
接着上面CWSLoader的例子,我们展示一下CWSPipe的功能: | |||
.. code-block:: python | |||
from fastNLP.io import CWSPipe | |||
data_bundle = CWSPipe().process(data_bundle) | |||
print(data_bundle) | |||
输出内容为:: | |||
In total 3 datasets: | |||
dev has 1831 instances. | |||
train has 17223 instances. | |||
test has 1944 instances. | |||
In total 2 vocabs: | |||
chars has 4777 entries. | |||
target has 4 entries. | |||
表示一共有3个数据集和2个词表。其中: | |||
- 3个数据集分别为train、dev、test数据集,分别有17223、1831、1944个instance | |||
- 2个词表分别为chars词表与target词表。其中chars词表为句子文本所构建的词表,一共有4777个字; | |||
target词表为目标标签所构建的词表,一共有4种标签。 | |||
相较于之前CWSLoader读取的DataBundle,新增了两个Vocabulary。 我们可以打印一下处理之后的DataSet | |||
.. code-block:: python | |||
tr_data = data_bundle.get_dataset('train') | |||
print(tr_data[:2]) | |||
输出为:: | |||
+---------------------------------------------------+------------------------------------+------------------------------------+---------+ | |||
| raw_words | chars | target | seq_len | | |||
+---------------------------------------------------+------------------------------------+------------------------------------+---------+ | |||
| 迈向 充满 希望 的 新 世纪 —— 一九九八年... | [1224, 178, 674, 544, 573, 435,... | [0, 1, 0, 1, 0, 1, 2, 2, 0, 1, ... | 29 | | |||
| 中共中央 总书记 、 国家 主席 江 泽民 | [11, 212, 11, 335, 124, 256, 10... | [0, 3, 3, 1, 0, 3, 1, 2, 0, 1, ... | 15 | | |||
+---------------------------------------------------+------------------------------------+------------------------------------+---------+ | |||
可以看到有两列为int的field: chars和target。这两列的名称同时也是DataBundle中的Vocabulary的名称。可以通过下列的代码获取并查看Vocabulary的 | |||
信息 | |||
.. code-block:: python | |||
vocab = data_bundle.get_vocab('target') | |||
print(vocab) | |||
输出为:: | |||
Vocabulary(['B', 'E', 'S', 'M']...) | |||
------------------------------------------ | |||
Part IV: fastNLP封装好的Loader和Pipe | |||
------------------------------------------ | |||
fastNLP封装了多种任务/数据集的Loader和Pipe并提供自动下载功能,具体参见文档 | |||
`fastNLP可加载数据集 <https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0>`_ | |||
-------------------------------------------------------- | |||
Part V: 不同格式类型的基础Loader | |||
-------------------------------------------------------- | |||
除了上面提到的针对具体任务的Loader,我们还提供了CSV格式和JSON格式的Loader | |||
:class:`~fastNLP.io.loader.CSVLoader` | |||
读取CSV类型的数据集文件。例子如下: | |||
.. code-block:: python | |||
from fastNLP.io.loader import CSVLoader | |||
data_set_loader = CSVLoader( | |||
headers=('raw_words', 'target'), sep='\t' | |||
) | |||
# 表示将CSV文件中每一行的第一项填入'words' field,第二项填入'target' field。 | |||
# 其中项之间由'\t'分割开来 | |||
data_set = data_set_loader._load('path/to/your/file') | |||
数据集内容样例如下 :: | |||
But it does not leave you with much . 1 | |||
You could hate it for the same reason . 1 | |||
The performances are an absolute joy . 4 | |||
读取之后的DataSet具有以下的field | |||
.. csv-table:: | |||
:header: raw_words, target | |||
"But it does not leave you with much .", "1" | |||
"You could hate it for the same reason .", "1" | |||
"The performances are an absolute joy .", "4" | |||
:class:`~fastNLP.io.loader.JsonLoader` | |||
读取Json类型的数据集文件,数据必须按行存储,每行是一个包含各类属性的Json对象。例子如下: | |||
.. code-block:: python | |||
from fastNLP.io.loader import JsonLoader | |||
oader = JsonLoader( | |||
fields={'sentence1': 'raw_words1', 'sentence2': 'raw_words2', 'gold_label': 'target'} | |||
) | |||
# 表示将Json对象中'sentence1'、'sentence2'和'gold_label'对应的值赋给'raw_words1'、'raw_words2'、'target'这三个fields | |||
data_set = loader._load('path/to/your/file') | |||
数据集内容样例如下 :: | |||
{"annotator_labels": ["neutral"], "captionID": "3416050480.jpg#4", "gold_label": "neutral", "pairID": "3416050480.jpg#4r1n", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is training his horse for a competition.", "sentence2_binary_parse": "( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (VP (VBG training) (NP (PRP$ his) (NN horse)) (PP (IN for) (NP (DT a) (NN competition))))) (. .)))"} | |||
{"annotator_labels": ["contradiction"], "captionID": "3416050480.jpg#4", "gold_label": "contradiction", "pairID": "3416050480.jpg#4r1c", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is at a diner, ordering an omelette.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is ( at ( a diner ) ) ) , ) ( ordering ( an omelette ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (PP (IN at) (NP (DT a) (NN diner))) (, ,) (S (VP (VBG ordering) (NP (DT an) (NN omelette))))) (. .)))"} | |||
{"annotator_labels": ["entailment"], "captionID": "3416050480.jpg#4", "gold_label": "entailment", "pairID": "3416050480.jpg#4r1e", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is outdoors, on a horse.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is outdoors ) , ) ( on ( a horse ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (ADVP (RB outdoors)) (, ,) (PP (IN on) (NP (DT a) (NN horse)))) (. .)))"} | |||
读取之后的DataSet具有以下的field | |||
.. csv-table:: | |||
:header: raw_words0, raw_words1, target | |||
"A person on a horse jumps over a broken down airplane.", "A person is training his horse for a competition.", "neutral" | |||
"A person on a horse jumps over a broken down airplane.", "A person is at a diner, ordering an omelette.", "contradiction" | |||
"A person on a horse jumps over a broken down airplane.", "A person is outdoors, on a horse.", "entailment" |
@@ -1,114 +0,0 @@ | |||
===================== | |||
快速实现序列标注模型 | |||
===================== | |||
这一部分的内容主要展示如何使用fastNLP 实现序列标注任务。你可以使用fastNLP的各个组件快捷,方便地完成序列标注任务,达到出色的效果。 | |||
在阅读这篇Tutorial前,希望你已经熟悉了fastNLP的基础使用,包括基本数据结构以及数据预处理,embedding的嵌入等,希望你对之前的教程有更进一步的掌握。 | |||
我们将对CoNLL-03的英文数据集进行处理,展示如何完成命名实体标注任务整个训练的过程。 | |||
载入数据 | |||
=================================== | |||
fastNLP可以方便地载入各种类型的数据。同时,针对常见的数据集,我们已经预先实现了载入方法,其中包含CoNLL-03数据集。 | |||
在设计dataloader时,以DataSetLoader为基类,可以改写并应用于其他数据集的载入。 | |||
.. code-block:: python | |||
class Conll2003DataLoader(DataSetLoader): | |||
def __init__(self, task:str='ner', encoding_type:str='bioes'): | |||
assert task in ('ner', 'pos', 'chunk') | |||
index = {'ner':3, 'pos':1, 'chunk':2}[task] | |||
#ConllLoader是fastNLP内置的类 | |||
self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index]) | |||
self._tag_converters = None | |||
if task in ('ner', 'chunk'): | |||
#iob和iob2bioes会对tag进行统一,标准化 | |||
self._tag_converters = [iob2] | |||
if encoding_type == 'bioes': | |||
self._tag_converters.append(iob2bioes) | |||
def load(self, path: str): | |||
dataset = self._loader.load(path) | |||
def convert_tag_schema(tags): | |||
for converter in self._tag_converters: | |||
tags = converter(tags) | |||
return tags | |||
if self._tag_converters: | |||
#使用apply实现convert_tag_schema函数,实际上也支持匿名函数 | |||
dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET) | |||
return dataset | |||
输出数据格式如: | |||
{'raw_words': ['on', 'Friday', ':'] type=list, | |||
'target': ['O', 'O', 'O'] type=list}, | |||
数据处理 | |||
---------------------------- | |||
我们进一步处理数据。将数据和词表封装在 :class:`~fastNLP.DataBundle` 类中。data是DataBundle的实例。 | |||
我们输入模型的数据包括char embedding,以及word embedding。在数据处理部分,我们尝试完成词表的构建。 | |||
使用fastNLP中的Vocabulary类来构建词表。 | |||
.. code-block:: python | |||
word_vocab = Vocabulary(min_freq=2) | |||
word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT) | |||
word_vocab.index_dataset(*data.datasets.values(),field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
处理后的data对象内部为: | |||
dataset | |||
vocabs | |||
dataset保存了train和test中的数据,并保存为dataset类型 | |||
vocab保存了words,raw-words以及target的词表。 | |||
模型构建 | |||
-------------------------------- | |||
我们使用CNN-BILSTM-CRF模型完成这一任务。在网络构建方面,fastNLP的网络定义继承pytorch的 :class:`nn.Module` 类。 | |||
自己可以按照pytorch的方式定义网络。需要注意的是命名。fastNLP的标准命名位于 :class:`~fastNLP.Const` 类。 | |||
模型的训练 | |||
首先实例化模型,导入所需的char embedding以及word embedding。Embedding的载入可以参考教程。 | |||
也可以查看 :mod:`~fastNLP.modules.encoder.embedding` 使用所需的embedding 载入方法。 | |||
fastNLP将模型的训练过程封装在了 :class:`~fastnlp.trainer` 类中。 | |||
根据不同的任务调整trainer中的参数即可。通常,一个trainer实例需要有:指定的训练数据集,模型,优化器,loss函数,评测指标,以及指定训练的epoch数,batch size等参数。 | |||
.. code-block:: python | |||
#实例化模型 | |||
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type) | |||
#定义优化器 | |||
optimizer = Adam(model.parameters(), lr=0.005) | |||
#定义评估指标 | |||
Metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type) | |||
#实例化trainer | |||
trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, dev_data=data.datasets['test'], batch_size=10, metrics=Metrics,callbacks=callbacks, n_epochs=100) | |||
#开始训练 | |||
trainer.train() | |||
训练中会保存最优的参数配置。 | |||
训练的结果如下: | |||
.. code-block:: python | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.727661, pre=0.732293, rec=0.723088 | |||
Evaluation at Epoch 1/100. Step:1405/140500. SpanFPreRecMetric: f=0.727661, pre=0.732293, rec=0.723088 | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.784307, pre=0.779371, rec=0.789306 | |||
Evaluation at Epoch 2/100. Step:2810/140500. SpanFPreRecMetric: f=0.784307, pre=0.779371, rec=0.789306 | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.810068, pre=0.811003, rec=0.809136 | |||
Evaluation at Epoch 3/100. Step:4215/140500. SpanFPreRecMetric: f=0.810068, pre=0.811003, rec=0.809136 | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.829592, pre=0.84153, rec=0.817989 | |||
Evaluation at Epoch 4/100. Step:5620/140500. SpanFPreRecMetric: f=0.829592, pre=0.84153, rec=0.817989 | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.828789, pre=0.837096, rec=0.820644 | |||
Evaluation at Epoch 5/100. Step:7025/140500. SpanFPreRecMetric: f=0.828789, pre=0.837096, rec=0.820644 | |||
@@ -0,0 +1,98 @@ | |||
===================== | |||
快速实现序列标注模型 | |||
===================== | |||
这一部分的内容主要展示如何使用fastNLP 实现序列标注任务。你可以使用fastNLP的各个组件快捷,方便地完成序列标注任务,达到出色的效果。 | |||
在阅读这篇Tutorial前,希望你已经熟悉了fastNLP的基础使用,尤其是数据的载入以及模型的构建,通过这个小任务的能让你进一步熟悉fastNLP的使用。 | |||
我们将对基于Weibo的中文社交数据集进行处理,展示如何完成命名实体标注任务的整个过程。 | |||
载入数据 | |||
=================================== | |||
fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的。通过Loader可以方便地载入各种类型的数据。同时,针对常见的数据集,我们已经预先实现了载入方法,其中包含weibo数据集。 | |||
在设计dataloader时,以DataSetLoader为基类,可以改写并应用于其他数据集的载入。 | |||
.. code-block:: python | |||
from fastNLP.io import WeiboNERLoader | |||
data_bundle = WeiboNERLoader().load() | |||
载入后的数据如 :: | |||
{'dev': DataSet( | |||
{{'raw_chars': ['用', '最', '大', '努', '力', '去', '做''人', '生', '。', '哈', '哈', '哈', '哈', '哈', '哈', ' | |||
'target': ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',, 'O', 'O', 'O', 'O', 'O', 'O'] type=list})} | |||
{'test': DataSet( | |||
{{'raw_chars': ['感', '恩', '大', '回', '馈'] type=list, 'target': ['O', 'O', 'O', 'O', 'O'] type=list})} | |||
{'train': DataSet( | |||
{'raw_chars': ['国', '安', '老', '球', '迷'] type=list, 'target': ['B-ORG.NAM', 'I-ORG.NAM', 'B-PER.NOM', 'I-PER.NOM', 'I-PER.NOM'] type=list})} | |||
数据处理 | |||
---------------------------- | |||
我们进一步处理数据。通过Pipe基类处理Loader载入的数据。 如果你还有印象,应该还能想起,实现自定义数据集的Pipe时,至少要编写process 函数或者process_from_file 函数。前者接受 :class:`~fastNLP.DataBundle` 类的数据,并返回该 :class:`~fastNLP.DataBundle` 。后者接收数据集所在文件夹为参数,读取并处理为 :class:`~fastNLP.DataBundle` 后,通过process 函数处理数据。 | |||
这里我们已经实现通过Loader载入数据,并已返回 :class:`~fastNLP.DataBundle` 类的数据。我们编写process 函数以处理Loader载入后的数据。 | |||
.. code-block:: python | |||
from fastNLP.io import ChineseNERPipe | |||
data_bundle = ChineseNERPipe(encoding_type='bioes', bigram=True).process(data_bundle) | |||
载入后的数据如下 :: | |||
{'raw_chars': ['用', '最', '大', '努', '力', '去', '做', '值', '得', '的', '事', '人', '生', '。', '哈', '哈', '哈', '哈', '哈', '哈', '我', '在'] type=list, | |||
'target': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] type=list, | |||
'chars': [97, 71, 34, 422, 104, 72, 144, 628, 66, 3, 158, 2, 9, 647, 485, 196, 2,19] type=list, | |||
'bigrams': [5948, 1950, 34840, 98, 8413, 3961, 34841, 631, 34842, 407, 462, 45, 3 1959, 1619, 3, 3, 3, 3, 3, 2663, 29, 90] type=list, | |||
'seq_len': 30 type=int} | |||
模型构建 | |||
-------------------------------- | |||
我们使用CNN-BILSTM-CRF模型完成这一任务。在网络构建方面,fastNLP的网络定义继承pytorch的 :class:`nn.Module` 类。 | |||
自己可以按照pytorch的方式定义网络。需要注意的是命名。fastNLP的标准命名位于 :class:`~fastNLP.Const` 类。 | |||
模型的训练 | |||
首先实例化模型,导入所需的char embedding以及word embedding。Embedding的载入可以参考教程。 | |||
也可以查看 :mod:`~fastNLP.embedding` 使用所需的embedding 载入方法。 | |||
fastNLP将模型的训练过程封装在了 :class:`~fastnlp.Trainer` 类中。 | |||
根据不同的任务调整trainer中的参数即可。通常,一个trainer实例需要有:指定的训练数据集,模型,优化器,loss函数,评测指标,以及指定训练的epoch数,batch size等参数。 | |||
.. code-block:: python | |||
#实例化模型 | |||
model = CNBiLSTMCRFNER(char_embed, num_classes=len(data_bundle.vocabs['target']), bigram_embed=bigram_embed) | |||
#定义评估指标 | |||
Metrics=SpanFPreRecMetric(data_bundle.vocabs['target'], encoding_type='bioes') | |||
#实例化trainer并训练 | |||
Trainer(data_bundle.datasets['train'], model, batch_size=20, metrics=Metrics, num_workers=2, dev_data=data_bundle. datasets['dev']).train() | |||
训练中会保存最优的参数配置。 | |||
训练的结果如下 :: | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.727661, pre=0.732293, rec=0.723088 | |||
Evaluation at Epoch 1/100. Step:1405/140500. SpanFPreRecMetric: f=0.727661, pre=0.732293, rec=0.723088 | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.784307, pre=0.779371, rec=0.789306 | |||
Evaluation at Epoch 2/100. Step:2810/140500. SpanFPreRecMetric: f=0.784307, pre=0.779371, rec=0.789306 | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.810068, pre=0.811003, rec=0.809136 | |||
Evaluation at Epoch 3/100. Step:4215/140500. SpanFPreRecMetric: f=0.810068, pre=0.811003, rec=0.809136 | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.829592, pre=0.84153, rec=0.817989 | |||
Evaluation at Epoch 4/100. Step:5620/140500. SpanFPreRecMetric: f=0.829592, pre=0.84153, rec=0.817989 | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.828789, pre=0.837096, rec=0.820644 | |||
Evaluation at Epoch 5/100. Step:7025/140500. SpanFPreRecMetric: f=0.828789, pre=0.837096, rec=0.820644 | |||
@@ -8,13 +8,14 @@ fastNLP 详细使用教程 | |||
:maxdepth: 1 | |||
使用DataSet预处理文本 </tutorials/tutorial_1_data_preprocess> | |||
使用Loader和Pipe加载并处理数据集 </tutorials/tutorial_2_load_dataset> | |||
使用Vocabulary转换文本与index </tutorials/tutorial_2_vocabulary> | |||
使用Embedding模块将文本转成向量 </tutorials/tutorial_3_embedding> | |||
动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试 </tutorials/tutorial_4_loss_optimizer> | |||
使用Loader和Pipe加载并处理数据集 </tutorials/tutorial_4_load_dataset> | |||
动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程 </tutorials/tutorial_5_datasetiter> | |||
快速实现序列标注模型 </tutorials/tutorial_6_seq_labeling> | |||
使用Modules和Models快速搭建自定义模型 </tutorials/tutorial_7_modules_models> | |||
使用Metric快速评测你的模型 </tutorials/tutorial_8_metrics> | |||
使用Callback自定义你的训练过程 </tutorials/tutorial_9_callback> | |||
使用fitlog 辅助 fastNLP 进行科研 </tutorials/tutorial_10_fitlog> | |||
动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试 </tutorials/tutorial_6_loss_optimizer> | |||
使用Metric快速评测你的模型 </tutorials/tutorial_7_metrics> | |||
使用Modules和Models快速搭建自定义模型 </tutorials/tutorial_8_modules_models> | |||
快速实现序列标注模型 </tutorials/tutorial_9_seq_labeling> | |||
使用Callback自定义你的训练过程 </tutorials/tutorial_10_callback> | |||
使用fitlog 辅助 fastNLP 进行科研 </tutorials/tutorial_11_fitlog> | |||
@@ -70,3 +70,7 @@ from . import models | |||
from . import modules | |||
from .core import * | |||
from .io import loader, pipe | |||
import sys | |||
from .doc_utils import doc_process | |||
doc_process(sys.modules[__name__]) |
@@ -9,15 +9,16 @@ __all__ = [ | |||
] | |||
import atexit | |||
from numbers import Number | |||
import numpy as np | |||
import torch | |||
import torch.utils.data | |||
from numbers import Number | |||
from .sampler import SequentialSampler | |||
from .dataset import DataSet | |||
from ._logger import logger | |||
from .dataset import DataSet | |||
from .sampler import SequentialSampler | |||
_python_is_exit = False | |||
@@ -145,8 +146,6 @@ class BatchIter: | |||
class DataSetIter(BatchIter): | |||
""" | |||
别名::class:`fastNLP.DataSetIter` :class:`fastNLP.core.batch.DataSetIter` | |||
DataSetIter 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出, | |||
组成 `x` 和 `y`:: | |||
@@ -155,23 +154,26 @@ class DataSetIter(BatchIter): | |||
for batch_x, batch_y in batch: | |||
# do stuff ... | |||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | |||
:param int batch_size: 取出的batch大小 | |||
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`. | |||
Default: ``None`` | |||
:param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`. | |||
Default: ``False`` | |||
:param int num_workers: 使用多少个进程来预处理数据 | |||
:param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。 | |||
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | |||
:param timeout: | |||
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | |||
""" | |||
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, | |||
num_workers=0, pin_memory=False, drop_last=False, | |||
timeout=0, worker_init_fn=None): | |||
""" | |||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | |||
:param int batch_size: 取出的batch大小 | |||
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`. | |||
Default: ``None`` | |||
:param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`. | |||
Default: ``False`` | |||
:param int num_workers: 使用多少个进程来预处理数据 | |||
:param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。 | |||
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | |||
:param timeout: | |||
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | |||
""" | |||
super().__init__() | |||
assert isinstance(dataset, DataSet) | |||
if not isinstance(sampler, torch.utils.data.Sampler): | |||
@@ -70,10 +70,11 @@ __all__ = [ | |||
] | |||
import os | |||
import sys | |||
from copy import deepcopy | |||
import torch | |||
from copy import deepcopy | |||
import sys | |||
from .utils import _save_model | |||
try: | |||
@@ -95,8 +96,6 @@ except: | |||
class Callback(object): | |||
""" | |||
别名::class:`fastNLP.Callback` :class:`fastNLP.core.callback.Callback` | |||
Callback是fastNLP中被设计用于增强 :class:`~fastNLP.Trainer` 的类。 | |||
如果Callback被传递给了 Trainer , 则 Trainer 会在对应的阶段调用Callback的函数, | |||
具体调用时机可以通过 :doc:`trainer 模块<fastNLP.core.trainer>` 查看。 | |||
@@ -318,9 +317,11 @@ def _transfer(func): | |||
class CallbackManager(Callback): | |||
""" | |||
内部使用的Callback管理类 | |||
""" | |||
def __init__(self, env, callbacks=None): | |||
""" | |||
内部使用的Callback管理类 | |||
:param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | |||
:param List[Callback] callbacks: | |||
@@ -435,26 +436,24 @@ class DistCallbackManager(CallbackManager): | |||
class GradientClipCallback(Callback): | |||
""" | |||
别名::class:`fastNLP.GradientClipCallback` :class:`fastNLP.core.callback.GradientClipCallback` | |||
每次backward前,将parameter的gradient clip到某个范围。 | |||
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 | |||
如果为None则默认对Trainer的model中所有参数进行clip | |||
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 | |||
:param str clip_type: 支持'norm', 'value' | |||
两种:: | |||
1 'norm', 将gradient的norm rescale到[-clip_value, clip_value] | |||
2 'value', 将gradient限制在[-clip_value, clip_value], | |||
小于-clip_value的gradient被赋值为-clip_value; | |||
大于clip_value的gradient被赋值为clip_value. | |||
""" | |||
def __init__(self, parameters=None, clip_value=1, clip_type='norm'): | |||
""" | |||
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 | |||
如果为None则默认对Trainer的model中所有参数进行clip | |||
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 | |||
:param str clip_type: 支持'norm', 'value' | |||
两种:: | |||
1 'norm', 将gradient的norm rescale到[-clip_value, clip_value] | |||
2 'value', 将gradient限制在[-clip_value, clip_value], | |||
小于-clip_value的gradient被赋值为-clip_value; | |||
大于clip_value的gradient被赋值为clip_value. | |||
""" | |||
super().__init__() | |||
from torch import nn | |||
@@ -480,14 +479,14 @@ class GradientClipCallback(Callback): | |||
class EarlyStopCallback(Callback): | |||
""" | |||
别名::class:`fastNLP.EarlyStopCallback` :class:`fastNLP.core.callback.EarlyStopCallback` | |||
多少个epoch没有变好就停止训练,相关类 :class:`EarlyStopError` | |||
:param int patience: epoch的数量 | |||
多少个epoch没有变好就停止训练,相关类 :class:`~fastNLP.core.callback.EarlyStopError` | |||
""" | |||
def __init__(self, patience): | |||
""" | |||
:param int patience: epoch的数量 | |||
""" | |||
super(EarlyStopCallback, self).__init__() | |||
self.patience = patience | |||
self.wait = 0 | |||
@@ -511,23 +510,23 @@ class EarlyStopCallback(Callback): | |||
class FitlogCallback(Callback): | |||
""" | |||
别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback` | |||
该callback可将loss和progress写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 | |||
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 | |||
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 | |||
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 | |||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要 | |||
传入多个DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。data的结果的名称以'data'开头。 | |||
:param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: Tester对象,将在on_valid_end时调用。tester的结果的名称以'tester'开头 | |||
:param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得 | |||
大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。 | |||
:param int verbose: 是否在终端打印evaluation的结果,0不打印。 | |||
:param bool log_exception: fitlog是否记录发生的exception信息 | |||
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 | |||
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 | |||
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 | |||
""" | |||
def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False): | |||
""" | |||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要 | |||
传入多个DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。data的结果的名称以'data'开头。 | |||
:param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: Tester对象,将在on_valid_end时调用。tester的结果的名称以'tester'开头 | |||
:param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得 | |||
大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。 | |||
:param int verbose: 是否在终端打印evaluation的结果,0不打印。 | |||
:param bool log_exception: fitlog是否记录发生的exception信息 | |||
""" | |||
super().__init__() | |||
self.datasets = {} | |||
self.testers = {} | |||
@@ -610,16 +609,15 @@ class FitlogCallback(Callback): | |||
class EvaluateCallback(Callback): | |||
""" | |||
别名: :class:`fastNLP.EvaluateCallback` :class:`fastNLP.core.callback.EvaluateCallback` | |||
该callback用于扩展Trainer训练过程中只能对dev数据进行验证的问题。 | |||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 | |||
DataSet请通过dict的方式传入。 | |||
:param ~fastNLP.Tester,Dict[~fastNLP.DataSet] tester: Tester对象,将在on_valid_end时调用。 | |||
""" | |||
def __init__(self, data=None, tester=None): | |||
""" | |||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 | |||
DataSet请通过dict的方式传入。 | |||
:param ~fastNLP.Tester,Dict[~fastNLP.DataSet] tester: Tester对象,将在on_valid_end时调用。 | |||
""" | |||
super().__init__() | |||
self.datasets = {} | |||
self.testers = {} | |||
@@ -672,15 +670,13 @@ class EvaluateCallback(Callback): | |||
class LRScheduler(Callback): | |||
""" | |||
别名::class:`fastNLP.LRScheduler` :class:`fastNLP.core.callback.LRScheduler` | |||
对PyTorch LR Scheduler的包装以使得其可以被Trainer所使用 | |||
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler | |||
""" | |||
def __init__(self, lr_scheduler): | |||
""" | |||
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler | |||
""" | |||
super(LRScheduler, self).__init__() | |||
import torch.optim | |||
if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler): | |||
@@ -694,13 +690,13 @@ class LRScheduler(Callback): | |||
class ControlC(Callback): | |||
""" | |||
别名::class:`fastNLP.ControlC` :class:`fastNLP.core.callback.ControlC` | |||
:param bool quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer | |||
检测到 control+C 时的反馈 | |||
""" | |||
def __init__(self, quit_all): | |||
""" | |||
:param bool quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer | |||
""" | |||
super(ControlC, self).__init__() | |||
if type(quit_all) != bool: | |||
raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.") | |||
@@ -731,16 +727,15 @@ class SmoothValue(object): | |||
class LRFinder(Callback): | |||
""" | |||
别名::class:`fastNLP.LRFinder` :class:`fastNLP.core.callback.LRFinder` | |||
用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 | |||
:param float start_lr: 学习率下界 | |||
:param float end_lr: 学习率上界 | |||
""" | |||
def __init__(self, start_lr=1e-6, end_lr=10): | |||
""" | |||
:param float start_lr: 学习率下界 | |||
:param float end_lr: 学习率上界 | |||
""" | |||
super(LRFinder, self).__init__() | |||
self.start_lr, self.end_lr = start_lr, end_lr | |||
@@ -803,8 +798,6 @@ class LRFinder(Callback): | |||
class TensorboardCallback(Callback): | |||
""" | |||
别名::class:`fastNLP.TensorboardCallback` :class:`fastNLP.core.callback.TensorboardCallback` | |||
接受以下一个或多个字符串作为参数: | |||
- "model" | |||
- "loss" | |||
@@ -880,13 +873,15 @@ class TensorboardCallback(Callback): | |||
class WarmupCallback(Callback): | |||
""" | |||
按一定的周期调节Learning rate的大小。 | |||
:param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, | |||
如0.1, 则前10%的step是按照schedule策略调整learning rate。 | |||
:param str schedule: 以哪种方式调整。linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后 | |||
warmup的step下降到0; constant前warmup的step上升到指定learning rate,后面的step保持learning rate. | |||
""" | |||
def __init__(self, warmup=0.1, schedule='constant'): | |||
""" | |||
:param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, | |||
如0.1, 则前10%的step是按照schedule策略调整learning rate。 | |||
:param str schedule: 以哪种方式调整。linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后 | |||
warmup的step下降到0; constant前warmup的step上升到指定learning rate,后面的step保持learning rate. | |||
""" | |||
super().__init__() | |||
self.warmup = max(warmup, 0.) | |||
@@ -928,19 +923,23 @@ class WarmupCallback(Callback): | |||
class SaveModelCallback(Callback): | |||
""" | |||
由于Trainer在训练过程中只会保存最佳的模型, 该callback可实现多种方式的结果存储。 | |||
会根据训练开始的时间戳在save_dir下建立文件夹,再在文件夹下存放多个模型 | |||
-save_dir | |||
-2019-07-03-15-06-36 | |||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_performance是性能 | |||
-epoch:1_step:40_{metric_key}:{evaluate_performance}.pt | |||
-2019-07-03-15-10-00 | |||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能 | |||
:param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型 | |||
:param int top: 保存dev表现top多少模型。-1为保存所有模型。 | |||
:param bool only_param: 是否只保存模型d饿权重。 | |||
:param save_on_exception: 发生exception时,是否保存一份发生exception的模型。模型名称为epoch:x_step:x_Exception:{exception_name}. | |||
会根据训练开始的时间戳在save_dir下建立文件夹,再在文件夹下存放多个模型:: | |||
-save_dir | |||
-2019-07-03-15-06-36 | |||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_performance是性能 | |||
-epoch:1_step:40_{metric_key}:{evaluate_performance}.pt | |||
-2019-07-03-15-10-00 | |||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能 | |||
""" | |||
def __init__(self, save_dir, top=3, only_param=False, save_on_exception=False): | |||
""" | |||
:param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型 | |||
:param int top: 保存dev表现top多少模型。-1为保存所有模型。 | |||
:param bool only_param: 是否只保存模型d饿权重。 | |||
:param save_on_exception: 发生exception时,是否保存一份发生exception的模型。模型名称为epoch:x_step:x_Exception:{exception_name}. | |||
""" | |||
super().__init__() | |||
if not os.path.isdir(save_dir): | |||
@@ -1006,11 +1005,13 @@ class SaveModelCallback(Callback): | |||
class CallbackException(BaseException): | |||
""" | |||
当需要通过callback跳出训练的时候可以通过抛出CallbackException并在on_exception中捕获这个值。 | |||
:param str msg: Exception的信息。 | |||
""" | |||
def __init__(self, msg): | |||
""" | |||
:param str msg: Exception的信息。 | |||
""" | |||
super(CallbackException, self).__init__(msg) | |||
@@ -1028,12 +1029,11 @@ class EchoCallback(Callback): | |||
def __init__(self, name, out=sys.stdout): | |||
super(EchoCallback, self).__init__() | |||
self.name = name | |||
self.out = out | |||
self.out = out # deprecated | |||
def __getattribute__(self, item): | |||
if item.startswith('on_'): | |||
logger.info('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()), | |||
file=self.out) | |||
logger.info('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid())) | |||
return super(EchoCallback, self).__getattribute__(item) | |||
@@ -288,32 +288,33 @@ __all__ = [ | |||
] | |||
import _pickle as pickle | |||
import warnings | |||
from copy import deepcopy | |||
import numpy as np | |||
from copy import deepcopy | |||
from ._logger import logger | |||
from .const import Const | |||
from .field import AppendToTargetOrInputException | |||
from .field import AutoPadder | |||
from .field import FieldArray | |||
from .field import SetInputOrTargetException | |||
from .instance import Instance | |||
from .utils import _get_func_signature | |||
from .field import AppendToTargetOrInputException | |||
from .field import SetInputOrTargetException | |||
from .const import Const | |||
from ._logger import logger | |||
from .utils import pretty_table_printer | |||
from prettytable import PrettyTable | |||
class DataSet(object): | |||
""" | |||
别名::class:`fastNLP.DataSet` :class:`fastNLP.core.dataset.DataSet` | |||
fastNLP的数据容器,详细的使用方法见文档 :doc:`fastNLP.core.dataset` | |||
:param data: 如果为dict类型,则每个key的value应该为等长的list; 如果为list, | |||
每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。 | |||
""" | |||
def __init__(self, data=None): | |||
""" | |||
:param data: 如果为dict类型,则每个key的value应该为等长的list; 如果为list, | |||
每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。 | |||
""" | |||
self.field_arrays = {} | |||
if data is not None: | |||
if isinstance(data, dict): | |||
@@ -327,26 +328,26 @@ class DataSet(object): | |||
for ins in data: | |||
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) | |||
self.append(ins) | |||
else: | |||
raise ValueError("data only be dict or list type.") | |||
def __contains__(self, item): | |||
return item in self.field_arrays | |||
def __iter__(self): | |||
def iter_func(): | |||
for idx in range(len(self)): | |||
yield self[idx] | |||
return iter_func() | |||
def _inner_iter(self): | |||
class Iter_ptr: | |||
def __init__(self, dataset, idx): | |||
self.dataset = dataset | |||
self.idx = idx | |||
def __getitem__(self, item): | |||
assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[ | |||
self.idx]) | |||
@@ -359,13 +360,13 @@ class DataSet(object): | |||
def __repr__(self): | |||
return self.dataset[self.idx].__repr__() | |||
def inner_iter_func(): | |||
for idx in range(len(self)): | |||
yield Iter_ptr(self, idx) | |||
return inner_iter_func() | |||
def __getitem__(self, idx): | |||
"""给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。 | |||
@@ -398,20 +399,20 @@ class DataSet(object): | |||
return dataset | |||
else: | |||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||
def __getattr__(self, item): | |||
# Not tested. Don't use !! | |||
if item == "field_arrays": | |||
raise AttributeError | |||
if isinstance(item, str) and item in self.field_arrays: | |||
return self.field_arrays[item] | |||
def __setstate__(self, state): | |||
self.__dict__ = state | |||
def __getstate__(self): | |||
return self.__dict__ | |||
def __len__(self): | |||
"""Fetch the length of the dataset. | |||
@@ -421,16 +422,66 @@ class DataSet(object): | |||
return 0 | |||
field = iter(self.field_arrays.values()).__next__() | |||
return len(field) | |||
def __inner_repr__(self): | |||
if len(self) < 20: | |||
return ",\n".join([ins.__repr__() for ins in self]) | |||
else: | |||
return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__() | |||
def __repr__(self): | |||
return "DataSet(" + self.__inner_repr__() + ")" | |||
return str(pretty_table_printer(self)) | |||
def print_field_meta(self): | |||
""" | |||
输出当前field的meta信息, 形似下列的输出 | |||
+-------------+-------+-------+ | |||
| field_names | x | y | | |||
+-------------+-------+-------+ | |||
| is_input | True | False | | |||
| is_target | False | False | | |||
| ignore_type | False | | | |||
| pad_value | 0 | | | |||
+-------------+-------+-------+ | |||
field_names: DataSet中field的名称 | |||
is_input: field是否为input | |||
is_target: field是否为target | |||
ignore_type: 是否忽略该field的type, 一般仅在该field至少为input或target时才有意义 | |||
pad_value: 该field的pad的值,仅在该field为input或target时有意义 | |||
:return: | |||
""" | |||
if len(self.field_arrays)>0: | |||
field_names = ['field_names'] | |||
is_inputs = ['is_input'] | |||
is_targets = ['is_target'] | |||
pad_values = ['pad_value'] | |||
ignore_types = ['ignore_type'] | |||
for name, field_array in self.field_arrays.items(): | |||
field_names.append(name) | |||
if field_array.is_input: | |||
is_inputs.append(True) | |||
else: | |||
is_inputs.append(False) | |||
if field_array.is_target: | |||
is_targets.append(True) | |||
else: | |||
is_targets.append(False) | |||
if (field_array.is_input or field_array.is_target) and field_array.padder is not None: | |||
pad_values.append(field_array.padder.get_pad_val()) | |||
else: | |||
pad_values.append(' ') | |||
if field_array._ignore_type: | |||
ignore_types.append(True) | |||
elif field_array.is_input or field_array.is_target: | |||
ignore_types.append(False) | |||
else: | |||
ignore_types.append(' ') | |||
table = PrettyTable(field_names=field_names) | |||
fields = [is_inputs, is_targets, ignore_types, pad_values] | |||
for field in fields: | |||
table.add_row(field) | |||
logger.info(table) | |||
def append(self, instance): | |||
""" | |||
将一个instance对象append到DataSet后面。 | |||
@@ -455,7 +506,7 @@ class DataSet(object): | |||
except AppendToTargetOrInputException as e: | |||
logger.error(f"Cannot append to field:{name}.") | |||
raise e | |||
def add_fieldarray(self, field_name, fieldarray): | |||
""" | |||
将fieldarray添加到DataSet中. | |||
@@ -470,7 +521,7 @@ class DataSet(object): | |||
raise RuntimeError(f"The field to add must have the same size as dataset. " | |||
f"Dataset size {len(self)} != field size {len(fieldarray)}") | |||
self.field_arrays[field_name] = fieldarray | |||
def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False): | |||
""" | |||
新增一个field | |||
@@ -482,14 +533,14 @@ class DataSet(object): | |||
:param bool is_target: 新加入的field是否是target | |||
:param bool ignore_type: 是否忽略对新加入的field的类型检查 | |||
""" | |||
if len(self.field_arrays) != 0: | |||
if len(self) != len(fields): | |||
raise RuntimeError(f"The field to add must have the same size as dataset. " | |||
f"Dataset size {len(self)} != field size {len(fields)}") | |||
self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input, | |||
padder=padder, ignore_type=ignore_type) | |||
def delete_instance(self, index): | |||
""" | |||
删除第index个instance | |||
@@ -505,7 +556,7 @@ class DataSet(object): | |||
for field in self.field_arrays.values(): | |||
field.pop(index) | |||
return self | |||
def delete_field(self, field_name): | |||
""" | |||
删除名为field_name的field | |||
@@ -539,7 +590,7 @@ class DataSet(object): | |||
if isinstance(field_name, str): | |||
return field_name in self.field_arrays | |||
return False | |||
def get_field(self, field_name): | |||
""" | |||
获取field_name这个field | |||
@@ -550,7 +601,7 @@ class DataSet(object): | |||
if field_name not in self.field_arrays: | |||
raise KeyError("Field name {} not found in DataSet".format(field_name)) | |||
return self.field_arrays[field_name] | |||
def get_all_fields(self): | |||
""" | |||
返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray` | |||
@@ -558,7 +609,7 @@ class DataSet(object): | |||
:return dict: 返回如上所述的字典 | |||
""" | |||
return self.field_arrays | |||
def get_field_names(self) -> list: | |||
""" | |||
返回一个list,包含所有 field 的名字 | |||
@@ -566,7 +617,7 @@ class DataSet(object): | |||
:return list: 返回如上所述的列表 | |||
""" | |||
return sorted(self.field_arrays.keys()) | |||
def get_length(self): | |||
""" | |||
获取DataSet的元素数量 | |||
@@ -574,21 +625,21 @@ class DataSet(object): | |||
:return: int: DataSet中Instance的个数。 | |||
""" | |||
return len(self) | |||
def rename_field(self, old_name, new_name): | |||
def rename_field(self, field_name, new_field_name): | |||
""" | |||
将某个field重新命名. | |||
:param str old_name: 原来的field名称。 | |||
:param str new_name: 修改为new_name。 | |||
:param str field_name: 原来的field名称。 | |||
:param str new_field_name: 修改为new_name。 | |||
""" | |||
if old_name in self.field_arrays: | |||
self.field_arrays[new_name] = self.field_arrays.pop(old_name) | |||
self.field_arrays[new_name].name = new_name | |||
if field_name in self.field_arrays: | |||
self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) | |||
self.field_arrays[new_field_name].name = new_field_name | |||
else: | |||
raise KeyError("DataSet has no field named {}.".format(old_name)) | |||
raise KeyError("DataSet has no field named {}.".format(field_name)) | |||
return self | |||
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | |||
""" | |||
将field_names的field设置为target | |||
@@ -615,7 +666,7 @@ class DataSet(object): | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
return self | |||
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | |||
""" | |||
将field_names的field设置为input:: | |||
@@ -639,7 +690,7 @@ class DataSet(object): | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
return self | |||
def set_ignore_type(self, *field_names, flag=True): | |||
""" | |||
将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查, | |||
@@ -656,7 +707,7 @@ class DataSet(object): | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
return self | |||
def set_padder(self, field_name, padder): | |||
""" | |||
为field_name设置padder:: | |||
@@ -672,7 +723,7 @@ class DataSet(object): | |||
raise KeyError("There is no field named {}.".format(field_name)) | |||
self.field_arrays[field_name].set_padder(padder) | |||
return self | |||
def set_pad_val(self, field_name, pad_val): | |||
""" | |||
为某个field设置对应的pad_val. | |||
@@ -684,7 +735,7 @@ class DataSet(object): | |||
raise KeyError("There is no field named {}.".format(field_name)) | |||
self.field_arrays[field_name].set_pad_val(pad_val) | |||
return self | |||
def get_input_name(self): | |||
""" | |||
返回所有is_input被设置为True的field名称 | |||
@@ -692,7 +743,7 @@ class DataSet(object): | |||
:return list: 里面的元素为被设置为input的field名称 | |||
""" | |||
return [name for name, field in self.field_arrays.items() if field.is_input] | |||
def get_target_name(self): | |||
""" | |||
返回所有is_target被设置为True的field名称 | |||
@@ -700,7 +751,7 @@ class DataSet(object): | |||
:return list: 里面的元素为被设置为target的field名称 | |||
""" | |||
return [name for name, field in self.field_arrays.items() if field.is_target] | |||
def apply_field(self, func, field_name, new_field_name=None, **kwargs): | |||
""" | |||
将DataSet中的每个instance中的名为 `field_name` 的field传给func,并获取它的返回值。 | |||
@@ -729,16 +780,16 @@ class DataSet(object): | |||
results.append(func(ins[field_name])) | |||
except Exception as e: | |||
if idx != -1: | |||
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx+1)) | |||
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1)) | |||
raise e | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | |||
if new_field_name is not None: | |||
self._add_apply_field(results, new_field_name, kwargs) | |||
return results | |||
def _add_apply_field(self, results, new_field_name, kwargs): | |||
""" | |||
将results作为加入到新的field中,field名称为new_field_name | |||
@@ -770,7 +821,7 @@ class DataSet(object): | |||
self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||
is_target=extra_param.get("is_target", None), | |||
ignore_type=extra_param.get("ignore_type", False)) | |||
def apply(self, func, new_field_name=None, **kwargs): | |||
""" | |||
将DataSet中每个instance传入到func中,并获取它的返回值. | |||
@@ -802,13 +853,13 @@ class DataSet(object): | |||
# results = [func(ins) for ins in self._inner_iter()] | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | |||
if new_field_name is not None: | |||
self._add_apply_field(results, new_field_name, kwargs) | |||
return results | |||
def add_seq_len(self, field_name:str, new_field_name=Const.INPUT_LEN): | |||
def add_seq_len(self, field_name: str, new_field_name=Const.INPUT_LEN): | |||
""" | |||
将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。 | |||
@@ -845,7 +896,7 @@ class DataSet(object): | |||
return dataset | |||
else: | |||
return DataSet() | |||
def split(self, ratio, shuffle=True): | |||
""" | |||
将DataSet按照ratio的比例拆分,返回两个DataSet | |||
@@ -871,9 +922,9 @@ class DataSet(object): | |||
for field_name in self.field_arrays: | |||
train_set.field_arrays[field_name].to(self.field_arrays[field_name]) | |||
dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) | |||
return train_set, dev_set | |||
def save(self, path): | |||
""" | |||
保存DataSet. | |||
@@ -882,7 +933,7 @@ class DataSet(object): | |||
""" | |||
with open(path, 'wb') as f: | |||
pickle.dump(self, f) | |||
@staticmethod | |||
def load(path): | |||
r""" | |||
@@ -53,7 +53,7 @@ class FieldArray: | |||
self.content = _content | |||
self._ignore_type = ignore_type | |||
# 根据input的情况设置input,target等 | |||
self._cell_ndim = None # 多少维度 | |||
self._cell_ndim = None # 多少维度, 如果value是1, dim为0; 如果value是[1, 2], dim=2 | |||
self.dtype = None # 最内层的element都是什么类型的 | |||
self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) | |||
self._is_input = False | |||
@@ -464,29 +464,30 @@ def _get_ele_type_and_dim(cell: Any, dim=0): | |||
class Padder: | |||
""" | |||
别名::class:`fastNLP.Padder` :class:`fastNLP.core.field.Padder` | |||
所有padder都需要继承这个类,并覆盖__call__方法。 | |||
用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 | |||
.. py:function:: __call__(self, contents, field_name, field_ele_dtype): | |||
""" | |||
def __init__(self, pad_val=0, **kwargs): | |||
""" | |||
传入的是List内容。假设有以下的DataSet。 | |||
:param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | |||
deepcopy一份。 | |||
:param str, field_name: field的名称。 | |||
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 | |||
:return: np.array([padded_element]) | |||
""" | |||
def __init__(self, pad_val=0, **kwargs): | |||
""" | |||
self.pad_val = pad_val | |||
def set_pad_val(self, pad_val): | |||
self.pad_val = pad_val | |||
def get_pad_val(self): | |||
return self.pad_val | |||
@abstractmethod | |||
def __call__(self, contents, field_name, field_ele_dtype, dim: int): | |||
""" | |||
@@ -534,8 +535,6 @@ class Padder: | |||
class AutoPadder(Padder): | |||
""" | |||
别名::class:`fastNLP.AutoPadder` :class:`fastNLP.core.field.AutoPadder` | |||
根据contents的数据自动判定是否需要做padding。 | |||
1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | |||
@@ -595,7 +594,7 @@ class AutoPadder(Padder): | |||
max_len = max(map(len, contents)) | |||
tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype) | |||
for i, content_i in enumerate(contents): | |||
tensor[i, :len(content_i)] = torch.tensor(content_i) | |||
tensor[i, :len(content_i)] = content_i.clone().detach() | |||
elif dim == 2: | |||
max_len = max(map(len, contents)) | |||
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||
@@ -604,7 +603,7 @@ class AutoPadder(Padder): | |||
dtype=field_ele_dtype) | |||
for i, content_i in enumerate(contents): | |||
for j, content_ii in enumerate(content_i): | |||
tensor[i, j, :len(content_ii)] = torch.tensor(content_ii) | |||
tensor[i, j, :len(content_ii)] = content_ii.clone().detach() | |||
else: | |||
shapes = set([np.shape(content_i) for content_i in contents]) | |||
if len(shapes) > 1: | |||
@@ -615,7 +614,7 @@ class AutoPadder(Padder): | |||
tensor = torch.full([len(contents)] + list(shape), fill_value=self.pad_val, | |||
dtype=field_ele_dtype) | |||
for i, content_i in enumerate(contents): | |||
tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype) | |||
tensor[i] = content_i.clone().detach().to(field_ele_dtype) | |||
else: | |||
raise RuntimeError( | |||
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
@@ -628,8 +627,6 @@ class AutoPadder(Padder): | |||
class EngChar2DPadder(Padder): | |||
""" | |||
别名::class:`fastNLP.EngChar2DPadder` :class:`fastNLP.core.field.EngChar2DPadder` | |||
用于为英语执行character级别的2D padding操作。对应的field内容应该类似[['T', 'h', 'i', 's'], ['a'], ['d', 'e', 'm', 'o']], | |||
但这个Padder只能处理index为int的情况。 | |||
@@ -3,15 +3,16 @@ instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可 | |||
便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格 | |||
""" | |||
__all__ = [ | |||
"Instance" | |||
] | |||
from .utils import pretty_table_printer | |||
class Instance(object): | |||
""" | |||
别名::class:`fastNLP.Instance` :class:`fastNLP.core.instance.Instance` | |||
Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。 | |||
Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示:: | |||
@@ -22,11 +23,11 @@ class Instance(object): | |||
>>>ins.add_field("field_3", [3, 3, 3]) | |||
>>>ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))}) | |||
""" | |||
def __init__(self, **fields): | |||
self.fields = fields | |||
def add_field(self, field_name, field): | |||
""" | |||
向Instance中增加一个field | |||
@@ -39,21 +40,19 @@ class Instance(object): | |||
def items(self): | |||
""" | |||
返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value | |||
:return: | |||
:return: 一个迭代器 | |||
""" | |||
return self.fields.items() | |||
def __getitem__(self, name): | |||
if name in self.fields: | |||
return self.fields[name] | |||
else: | |||
raise KeyError("{} not found".format(name)) | |||
def __setitem__(self, name, field): | |||
return self.add_field(name, field) | |||
def __repr__(self): | |||
s = '\'' | |||
return "{" + ",\n".join( | |||
"\'" + field_name + "\': " + str(self.fields[field_name]) + \ | |||
f" type={(str(type(self.fields[field_name]))).split(s)[1]}" for field_name in self.fields) + "}" | |||
return str(pretty_table_printer(self)) |
@@ -20,7 +20,6 @@ from collections import defaultdict | |||
import torch | |||
import torch.nn.functional as F | |||
from ..core.const import Const | |||
from .utils import _CheckError | |||
from .utils import _CheckRes | |||
from .utils import _build_args | |||
@@ -28,7 +27,7 @@ from .utils import _check_arg_dict_list | |||
from .utils import _check_function_or_method | |||
from .utils import _get_func_signature | |||
from .utils import seq_len_to_mask | |||
import warnings | |||
from ..core.const import Const | |||
class LossBase(object): | |||
@@ -167,8 +166,6 @@ class LossBase(object): | |||
class LossFunc(LossBase): | |||
""" | |||
别名::class:`fastNLP.LossFunc` :class:`fastNLP.core.losses.LossFunc` | |||
提供给用户使用自定义损失函数的类 | |||
:param func: 用户自行定义的损失函数,应当为一个函数或者callable(func)为True的ojbect | |||
@@ -200,8 +197,6 @@ class LossFunc(LossBase): | |||
class CrossEntropyLoss(LossBase): | |||
""" | |||
别名::class:`fastNLP.CrossEntropyLoss` :class:`fastNLP.core.losses.CrossEntropyLoss` | |||
交叉熵损失函数 | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
@@ -238,8 +233,8 @@ class CrossEntropyLoss(LossBase): | |||
pred = pred.tranpose(-1, pred) | |||
pred = pred.reshape(-1, pred.size(-1)) | |||
target = target.reshape(-1) | |||
if seq_len is not None: | |||
mask = seq_len_to_mask(seq_len).reshape(-1).eq(0) | |||
if seq_len is not None and target.dim()>1: | |||
mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0) | |||
target = target.masked_fill(mask, self.padding_idx) | |||
return F.cross_entropy(input=pred, target=target, | |||
@@ -248,8 +243,6 @@ class CrossEntropyLoss(LossBase): | |||
class L1Loss(LossBase): | |||
""" | |||
别名::class:`fastNLP.L1Loss` :class:`fastNLP.core.losses.L1Loss` | |||
L1损失函数 | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
@@ -270,8 +263,6 @@ class L1Loss(LossBase): | |||
class BCELoss(LossBase): | |||
""" | |||
别名::class:`fastNLP.BCELoss` :class:`fastNLP.core.losses.BCELoss` | |||
二分类交叉熵损失函数 | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
@@ -291,18 +282,18 @@ class BCELoss(LossBase): | |||
class NLLLoss(LossBase): | |||
""" | |||
别名::class:`fastNLP.NLLLoss` :class:`fastNLP.core.losses.NLLLoss` | |||
负对数似然损失函数 | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
:param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容, 可以通过该值代替 | |||
传入seq_len. | |||
:param str reduction: 支持 `mean` ,`sum` 和 `none` . | |||
""" | |||
def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'): | |||
""" | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
:param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容, 可以通过该值代替 | |||
传入seq_len. | |||
:param str reduction: 支持 `mean` ,`sum` 和 `none` . | |||
""" | |||
super(NLLLoss, self).__init__() | |||
self._init_param_map(pred=pred, target=target) | |||
assert reduction in ('mean', 'sum', 'none') | |||
@@ -315,14 +306,14 @@ class NLLLoss(LossBase): | |||
class LossInForward(LossBase): | |||
""" | |||
别名::class:`fastNLP.LossInForward` :class:`fastNLP.core.losses.LossInForward` | |||
从forward()函数返回结果中获取loss | |||
:param str loss_key: 在forward函数中loss的键名,默认为loss | |||
""" | |||
def __init__(self, loss_key=Const.LOSS): | |||
""" | |||
:param str loss_key: 在forward函数中loss的键名,默认为loss | |||
""" | |||
super().__init__() | |||
if not isinstance(loss_key, str): | |||
raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.") | |||
@@ -10,7 +10,10 @@ __all__ = [ | |||
] | |||
import inspect | |||
import warnings | |||
from abc import abstractmethod | |||
from collections import defaultdict | |||
from typing import Union | |||
import numpy as np | |||
import torch | |||
@@ -22,7 +25,6 @@ from .utils import _check_arg_dict_list | |||
from .utils import _get_func_signature | |||
from .utils import seq_len_to_mask | |||
from .vocabulary import Vocabulary | |||
from abc import abstractmethod | |||
class MetricBase(object): | |||
@@ -150,6 +152,7 @@ class MetricBase(object): | |||
def get_metric_name(self): | |||
""" | |||
返回metric的名称 | |||
:return: | |||
""" | |||
return self._metric_name | |||
@@ -293,17 +296,16 @@ class MetricBase(object): | |||
class AccuracyMetric(MetricBase): | |||
""" | |||
别名::class:`fastNLP.AccuracyMetric` :class:`fastNLP.core.metrics.AccuracyMetric` | |||
准确率Metric(其它的Metric参见 :doc:`fastNLP.core.metrics` ) | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
:param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` | |||
""" | |||
def __init__(self, pred=None, target=None, seq_len=None): | |||
""" | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
:param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` | |||
""" | |||
super().__init__() | |||
@@ -336,15 +338,18 @@ class AccuracyMetric(MetricBase): | |||
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(seq_len)}.") | |||
if seq_len is not None: | |||
masks = seq_len_to_mask(seq_len=seq_len) | |||
if seq_len is not None and target.dim()>1: | |||
max_len = target.size(1) | |||
masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len) | |||
else: | |||
masks = None | |||
if pred.size() == target.size(): | |||
if pred.dim() == target.dim(): | |||
pass | |||
elif len(pred.size()) == len(target.size()) + 1: | |||
elif pred.dim() == target.dim() + 1: | |||
pred = pred.argmax(dim=-1) | |||
if seq_len is None and target.dim()>1: | |||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||
else: | |||
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | |||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||
@@ -492,10 +497,75 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] | |||
def _get_encoding_type_from_tag_vocab(tag_vocab:Union[Vocabulary, dict])->str: | |||
""" | |||
给定Vocabulary自动判断是哪种类型的encoding, 支持判断bmes, bioes, bmeso, bio | |||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | |||
:return: | |||
""" | |||
tag_set = set() | |||
unk_token = '<unk>' | |||
pad_token = '<pad>' | |||
if isinstance(tag_vocab, Vocabulary): | |||
unk_token = tag_vocab.unknown | |||
pad_token = tag_vocab.padding | |||
tag_vocab = tag_vocab.idx2word | |||
for idx, tag in tag_vocab.items(): | |||
if tag in (unk_token, pad_token): | |||
continue | |||
tag = tag[:1].lower() | |||
tag_set.add(tag) | |||
bmes_tag_set = set('bmes') | |||
if tag_set == bmes_tag_set: | |||
return 'bmes' | |||
bio_tag_set = set('bio') | |||
if tag_set == bio_tag_set: | |||
return 'bio' | |||
bmeso_tag_set = set('bmeso') | |||
if tag_set == bmeso_tag_set: | |||
return 'bmeso' | |||
bioes_tag_set = set('bioes') | |||
if tag_set == bioes_tag_set: | |||
return 'bioes' | |||
raise RuntimeError("encoding_type cannot be inferred automatically. Only support " | |||
"'bio', 'bmes', 'bmeso', 'bioes' type.") | |||
def _check_tag_vocab_and_encoding_type(tag_vocab:Union[Vocabulary, dict], encoding_type:str): | |||
""" | |||
检查vocab中的tag是否与encoding_type是匹配的 | |||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | |||
:param encoding_type: bio, bmes, bioes, bmeso | |||
:return: | |||
""" | |||
tag_set = set() | |||
unk_token = '<unk>' | |||
pad_token = '<pad>' | |||
if isinstance(tag_vocab, Vocabulary): | |||
unk_token = tag_vocab.unknown | |||
pad_token = tag_vocab.padding | |||
tag_vocab = tag_vocab.idx2word | |||
for idx, tag in tag_vocab.items(): | |||
if tag in (unk_token, pad_token): | |||
continue | |||
tag = tag[:1].lower() | |||
tag_set.add(tag) | |||
tags = encoding_type | |||
for tag in tag_set: | |||
assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \ | |||
f"encoding_type." | |||
tags = tags.replace(tag, '') # 删除该值 | |||
if tags: # 如果不为空,说明出现了未使用的tag | |||
warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " | |||
"encoding_type.") | |||
class SpanFPreRecMetric(MetricBase): | |||
r""" | |||
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` | |||
在序列标注问题中,以span的方式计算F, pre, rec. | |||
比如中文Part of speech中,会以character的方式进行标注,句子 `中国在亚洲` 对应的POS可能为(以BMES为例) | |||
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 | |||
@@ -518,34 +588,36 @@ class SpanFPreRecMetric(MetricBase): | |||
'rec-label':xxx, | |||
... | |||
} | |||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | |||
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 | |||
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 | |||
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 | |||
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes | |||
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 | |||
个label | |||
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 | |||
label的f1, pre, rec | |||
:param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : | |||
分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . | |||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
""" | |||
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, | |||
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type=None, ignore_labels=None, | |||
only_gross=True, f_type='micro', beta=1): | |||
encoding_type = encoding_type.lower() | |||
r""" | |||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | |||
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 | |||
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 | |||
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 | |||
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断. | |||
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label | |||
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec | |||
:param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为 `beta=0.5, 1, 2` 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
""" | |||
if not isinstance(tag_vocab, Vocabulary): | |||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||
if f_type not in ('micro', 'macro'): | |||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
self.encoding_type = encoding_type | |||
if encoding_type: | |||
encoding_type = encoding_type.lower() | |||
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) | |||
self.encoding_type = encoding_type | |||
else: | |||
self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) | |||
if self.encoding_type == 'bmes': | |||
self.tag_to_span_func = _bmes_tag_to_spans | |||
elif self.encoding_type == 'bio': | |||
@@ -555,7 +627,7 @@ class SpanFPreRecMetric(MetricBase): | |||
elif self.encoding_type == 'bioes': | |||
self.tag_to_span_func = _bioes_tag_to_spans | |||
else: | |||
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") | |||
raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.") | |||
self.ignore_labels = ignore_labels | |||
self.f_type = f_type | |||
@@ -757,24 +829,23 @@ def _pred_topk(y_prob, k=1): | |||
class ExtractiveQAMetric(MetricBase): | |||
r""" | |||
别名::class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric` | |||
抽取式QA(如SQuAD)的metric. | |||
:param pred1: 参数映射表中 `pred1` 的映射关系,None表示映射关系为 `pred1` -> `pred1` | |||
:param pred2: 参数映射表中 `pred2` 的映射关系,None表示映射关系为 `pred2` -> `pred2` | |||
:param target1: 参数映射表中 `target1` 的映射关系,None表示映射关系为 `target1` -> `target1` | |||
:param target2: 参数映射表中 `target2` 的映射关系,None表示映射关系为 `target2` -> `target2` | |||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . | |||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
:param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 | |||
:param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 | |||
""" | |||
def __init__(self, pred1=None, pred2=None, target1=None, target2=None, | |||
beta=1, right_open=True, print_predict_stat=False): | |||
r""" | |||
:param pred1: 参数映射表中 `pred1` 的映射关系,None表示映射关系为 `pred1` -> `pred1` | |||
:param pred2: 参数映射表中 `pred2` 的映射关系,None表示映射关系为 `pred2` -> `pred2` | |||
:param target1: 参数映射表中 `target1` 的映射关系,None表示映射关系为 `target1` -> `target1` | |||
:param target2: 参数映射表中 `target2` 的映射关系,None表示映射关系为 `target2` -> `target2` | |||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . | |||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
:param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 | |||
:param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 | |||
""" | |||
super(ExtractiveQAMetric, self).__init__() | |||
self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) | |||
@@ -9,21 +9,23 @@ __all__ = [ | |||
"AdamW" | |||
] | |||
import torch | |||
import math | |||
import torch | |||
from torch.optim.optimizer import Optimizer as TorchOptimizer | |||
class Optimizer(object): | |||
""" | |||
别名::class:`fastNLP.Optimizer` :class:`fastNLP.core.optimizer.Optimizer` | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
:param kwargs: additional parameters. | |||
Optimizer | |||
""" | |||
def __init__(self, model_params, **kwargs): | |||
""" | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
:param kwargs: additional parameters. | |||
""" | |||
if model_params is not None and not hasattr(model_params, "__next__"): | |||
raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params))) | |||
self.model_params = model_params | |||
@@ -60,14 +62,15 @@ class NullOptimizer(Optimizer): | |||
class SGD(Optimizer): | |||
""" | |||
别名::class:`fastNLP.SGD` :class:`fastNLP.core.optimizer.SGD` | |||
:param float lr: learning rate. Default: 0.01 | |||
:param float momentum: momentum. Default: 0 | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
SGD | |||
""" | |||
def __init__(self, lr=0.001, momentum=0, model_params=None): | |||
""" | |||
:param float lr: learning rate. Default: 0.01 | |||
:param float momentum: momentum. Default: 0 | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
""" | |||
if not isinstance(lr, float): | |||
raise TypeError("learning rate has to be float.") | |||
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | |||
@@ -82,14 +85,18 @@ class SGD(Optimizer): | |||
class Adam(Optimizer): | |||
""" | |||
别名::class:`fastNLP.Adam` :class:`fastNLP.core.optimizer.Adam` | |||
:param float lr: learning rate | |||
:param float weight_decay: | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
""" | |||
def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | |||
""" | |||
:param float lr: learning rate | |||
:param float weight_decay: | |||
:param eps: | |||
:param amsgrad: | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
""" | |||
if not isinstance(lr, float): | |||
raise TypeError("learning rate has to be float.") | |||
super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad, | |||
@@ -105,8 +112,6 @@ class Adam(Optimizer): | |||
class AdamW(TorchOptimizer): | |||
r""" | |||
别名::class:`fastNLP.AdamW` :class:`fastNLP.core.optimizer.AdamW` | |||
对AdamW的实现,该实现应该会在pytorch更高版本中出现,https://github.com/pytorch/pytorch/pull/21250。这里提前加入 | |||
.. todo:: | |||
@@ -115,27 +120,28 @@ class AdamW(TorchOptimizer): | |||
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. | |||
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. | |||
:param params (iterable): iterable of parameters to optimize or dicts defining | |||
parameter groups | |||
:param lr (float, optional): learning rate (default: 1e-3) | |||
:param betas (Tuple[float, float], optional): coefficients used for computing | |||
running averages of gradient and its square (default: (0.9, 0.99)) | |||
:param eps (float, optional): term added to the denominator to improve | |||
numerical stability (default: 1e-8) | |||
:param weight_decay (float, optional): weight decay coefficient (default: 1e-2) | |||
algorithm from the paper `On the Convergence of Adam and Beyond`_ | |||
(default: False) | |||
.. _Adam\: A Method for Stochastic Optimization: | |||
https://arxiv.org/abs/1412.6980 | |||
.. _Decoupled Weight Decay Regularization: | |||
https://arxiv.org/abs/1711.05101 | |||
.. _On the Convergence of Adam and Beyond: | |||
https://openreview.net/forum?id=ryQu7f-RZ | |||
.. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 | |||
.. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 | |||
.. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ | |||
""" | |||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, | |||
weight_decay=1e-2, amsgrad=False): | |||
""" | |||
:param params (iterable): iterable of parameters to optimize or dicts defining | |||
parameter groups | |||
:param lr (float, optional): learning rate (default: 1e-3) | |||
:param betas (Tuple[float, float], optional): coefficients used for computing | |||
running averages of gradient and its square (default: (0.9, 0.99)) | |||
:param eps (float, optional): term added to the denominator to improve | |||
numerical stability (default: 1e-8) | |||
:param weight_decay (float, optional): weight decay coefficient (default: 1e-2) | |||
algorithm from the paper `On the Convergence of Adam and Beyond`_ | |||
(default: False) | |||
""" | |||
if not 0.0 <= lr: | |||
raise ValueError("Invalid learning rate: {}".format(lr)) | |||
if not 0.0 <= eps: | |||
@@ -20,11 +20,13 @@ class Predictor(object): | |||
与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 | |||
这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。 | |||
:param torch.nn.Module network: 用来完成预测任务的模型 | |||
""" | |||
def __init__(self, network): | |||
""" | |||
:param torch.nn.Module network: 用来完成预测任务的模型 | |||
""" | |||
if not isinstance(network, torch.nn.Module): | |||
raise ValueError( | |||
"Only fastNLP.models.BaseModel or torch.nn,Module is allowed, not {}".format(type(network))) | |||
@@ -15,9 +15,6 @@ import numpy as np | |||
class Sampler(object): | |||
""" | |||
别名::class:`fastNLP.Sampler` :class:`fastNLP.core.sampler.Sampler` | |||
`Sampler` 类的基类. 规定以何种顺序取出data中的元素 | |||
子类必须实现 ``__call__`` 方法. 输入 `DataSet` 对象, 返回其中元素的下标序列 | |||
@@ -33,8 +30,6 @@ class Sampler(object): | |||
class SequentialSampler(Sampler): | |||
""" | |||
别名::class:`fastNLP.SequentialSampler` :class:`fastNLP.core.sampler.SequentialSampler` | |||
顺序取出元素的 `Sampler` | |||
""" | |||
@@ -45,8 +40,6 @@ class SequentialSampler(Sampler): | |||
class RandomSampler(Sampler): | |||
""" | |||
别名::class:`fastNLP.RandomSampler` :class:`fastNLP.core.sampler.RandomSampler` | |||
随机化取元素的 `Sampler` | |||
""" | |||
@@ -57,17 +50,17 @@ class RandomSampler(Sampler): | |||
class BucketSampler(Sampler): | |||
""" | |||
别名::class:`fastNLP.BucketSampler` :class:`fastNLP.core.sampler.BucketSampler` | |||
带Bucket的 `Random Sampler`. 可以随机地取出长度相似的元素 | |||
:param int num_buckets: bucket的数量 | |||
:param int batch_size: batch的大小. 默认为None,Trainer在调用BucketSampler时,会将该值正确设置,如果是非Trainer场景使用,需 | |||
要显示传递该值 | |||
:param str seq_len_field_name: 对应序列长度的 `field` 的名字 | |||
""" | |||
def __init__(self, num_buckets=10, batch_size=None, seq_len_field_name='seq_len'): | |||
""" | |||
:param int num_buckets: bucket的数量 | |||
:param int batch_size: batch的大小. 默认为None,Trainer在调用BucketSampler时,会将该值正确设置,如果是非Trainer场景使用,需 | |||
要显示传递该值 | |||
:param str seq_len_field_name: 对应序列长度的 `field` 的名字 | |||
""" | |||
self.num_buckets = num_buckets | |||
self.batch_size = batch_size | |||
self.seq_len_field_name = seq_len_field_name | |||
@@ -65,33 +65,33 @@ __all__ = [ | |||
class Tester(object): | |||
""" | |||
别名::class:`fastNLP.Tester` :class:`fastNLP.core.tester.Tester` | |||
Tester是在提供数据,模型以及metric的情况下进行性能测试的类。需要传入模型,数据以及metric进行验证。 | |||
:param ~fastNLP.DataSet data: 需要测试的数据集 | |||
:param torch.nn.module model: 使用的模型 | |||
:param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: 测试时使用的metrics | |||
:param int batch_size: evaluation时使用的batch_size有多大。 | |||
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | |||
的计算位置进行管理。支持以下的输入: | |||
1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中,可见的第一个GPU中,可见的第二个GPU中; | |||
2. torch.device:将模型装载到torch.device上。 | |||
3. int: 将使用device_id为该值的gpu进行训练 | |||
4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 | |||
5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 | |||
如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。 | |||
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 | |||
:param bool use_tqdm: 是否使用tqdm来显示测试进度; 如果为False,则不会显示任何内容。 | |||
""" | |||
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True): | |||
""" | |||
:param ~fastNLP.DataSet data: 需要测试的数据集 | |||
:param torch.nn.module model: 使用的模型 | |||
:param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: 测试时使用的metrics | |||
:param int batch_size: evaluation时使用的batch_size有多大。 | |||
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | |||
的计算位置进行管理。支持以下的输入: | |||
1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中,可见的第一个GPU中,可见的第二个GPU中; | |||
2. torch.device:将模型装载到torch.device上。 | |||
3. int: 将使用device_id为该值的gpu进行训练 | |||
4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 | |||
5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 | |||
如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。 | |||
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 | |||
:param bool use_tqdm: 是否使用tqdm来显示测试进度; 如果为False,则不会显示任何内容。 | |||
""" | |||
super(Tester, self).__init__() | |||
if not isinstance(model, nn.Module): | |||
@@ -139,10 +139,9 @@ class Tester(object): | |||
self._predict_func_wrapper = self._model.forward | |||
def test(self): | |||
"""开始进行验证,并返回验证结果。 | |||
r"""开始进行验证,并返回验证结果。 | |||
:return Dict[Dict] : dict的二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。 | |||
一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。 | |||
:return Dict[Dict]: dict的二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。 | |||
""" | |||
# turn on the testing mode; clean up the history | |||
self._model_device = _get_model_device(self._model) | |||
@@ -357,8 +357,6 @@ from ._logger import logger | |||
class Trainer(object): | |||
""" | |||
别名::class:`fastNLP.Trainer` :class:`fastNLP.core.trainer.Trainer` | |||
Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在不同训练任务中重复撰写 | |||
(1) epoch循环; | |||
(2) 将数据分成不同的Batch; | |||
@@ -367,54 +365,6 @@ class Trainer(object): | |||
(5) 保存获得更好验证性能的模型等。 | |||
详细的介绍参见 :doc:`fastNLP.core.trainer` | |||
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | |||
:param nn.modules model: 待训练的模型 | |||
:param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 | |||
:param int batch_size: 训练和验证的时候的batch大小。 | |||
:param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` | |||
:param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler` | |||
:param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch | |||
:param num_workers: int, 有多少个线程来进行数据pad处理。 | |||
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 | |||
会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 | |||
:param int n_epochs: 需要优化迭代多少次。 | |||
:param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。 | |||
:param dev_data: 用于做验证的DataSet, :class:`~fastNLP.DataSet` 类型。 | |||
:param metrics: 验证的评估函数。可以只使用一个 :class:`Metric<fastNLP.core.metrics.MetricBase>` , | |||
也可以使用多个 :class:`Metric<fastNLP.core.metrics.MetricBase>` ,通过列表传入。 | |||
如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None, | |||
则保存当前模型。Metric种类详见 :doc:`metrics模块 <fastNLP.core.metrics>` 。仅在传入dev_data时有效。 | |||
:param str,None metric_key: :class:`Metric<fastNLP.core.metrics.MetricBase>` 有时会有多个指标, | |||
比如 :class:`~fastNLP.core.metrics.SpanFPreRecMetric` 中包含了'f', 'pre', 'rec'。此时需 | |||
要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表 | |||
明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。 | |||
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。 | |||
:param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存 | |||
最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 | |||
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 | |||
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | |||
的计算位置进行管理。支持以下的输入: | |||
1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, | |||
可见的第二个GPU中; | |||
2. torch.device:将模型装载到torch.device上。 | |||
3. int: 将使用device_id为该值的gpu进行训练 | |||
4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 | |||
5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 | |||
已知可能会出现的问题:Adagrad优化器可能无法正常使用这个参数,请手动管理模型位置。 | |||
:param list(callbacks) callbacks: 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 | |||
通过callback机制实现。 可使用的callback参见 :doc:`callback模块 <fastNLP.core.callback>` | |||
:param int check_code_level: 模型检查等级. -1: 不进行检查; 0: 仅出现错误时停止; 1: 如果有field没有被使用, | |||
报告警告信息; 2: 有任何field没有被使用都报错. 检查的原理是通过使用很小的batch(默认2个sample)来运行代码,但是 | |||
这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个固定值的情况; | |||
(2)模型中存在累加前向计算次数的,可能会多计算1次。以上情况建议将check_code_level设置为-1。 | |||
""" | |||
def __init__(self, train_data, model, optimizer=None, loss=None, | |||
@@ -423,6 +373,56 @@ class Trainer(object): | |||
dev_data=None, metrics=None, metric_key=None, | |||
validate_every=-1, save_path=None, use_tqdm=True, device=None, | |||
callbacks=None, check_code_level=0, **kwargs): | |||
""" | |||
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | |||
:param nn.modules model: 待训练的模型 | |||
:param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 | |||
:param int batch_size: 训练和验证的时候的batch大小。 | |||
:param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` | |||
:param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler` | |||
:param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch | |||
:param num_workers: int, 有多少个线程来进行数据pad处理。 | |||
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 | |||
会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 | |||
:param int n_epochs: 需要优化迭代多少次。 | |||
:param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。 | |||
:param dev_data: 用于做验证的DataSet, :class:`~fastNLP.DataSet` 类型。 | |||
:param metrics: 验证的评估函数。可以只使用一个 :class:`Metric<fastNLP.core.metrics.MetricBase>` , | |||
也可以使用多个 :class:`Metric<fastNLP.core.metrics.MetricBase>` ,通过列表传入。 | |||
如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None, | |||
则保存当前模型。Metric种类详见 :doc:`metrics模块 <fastNLP.core.metrics>` 。仅在传入dev_data时有效。 | |||
:param str,None metric_key: :class:`Metric<fastNLP.core.metrics.MetricBase>` 有时会有多个指标, | |||
比如 :class:`~fastNLP.core.metrics.SpanFPreRecMetric` 中包含了'f', 'pre', 'rec'。此时需 | |||
要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表 | |||
明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。 | |||
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。 | |||
:param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存 | |||
最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 | |||
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 | |||
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | |||
的计算位置进行管理。支持以下的输入: | |||
1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, | |||
可见的第二个GPU中; | |||
2. torch.device:将模型装载到torch.device上。 | |||
3. int: 将使用device_id为该值的gpu进行训练 | |||
4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 | |||
5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 | |||
已知可能会出现的问题:Adagrad优化器可能无法正常使用这个参数,请手动管理模型位置。 | |||
:param list(callbacks) callbacks: 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 | |||
通过callback机制实现。 可使用的callback参见 :doc:`callback模块 <fastNLP.core.callback>` | |||
:param int check_code_level: 模型检查等级. -1: 不进行检查; 0: 仅出现错误时停止; 1: 如果有field没有被使用, | |||
报告警告信息; 2: 有任何field没有被使用都报错. 检查的原理是通过使用很小的batch(默认2个sample)来运行代码,但是 | |||
这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个固定值的情况; | |||
(2)模型中存在累加前向计算次数的,可能会多计算1次。以上情况建议将check_code_level设置为-1。 | |||
""" | |||
super(Trainer, self).__init__() | |||
if not isinstance(model, nn.Module): | |||
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") | |||
@@ -718,7 +718,7 @@ class Trainer(object): | |||
self._save_model(self.model, | |||
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) | |||
elif self._load_best_model: | |||
self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()} | |||
self._best_model_states = {name: param.cpu().clone() for name, param in self.model.state_dict().items()} | |||
self.best_dev_perf = res | |||
self.best_dev_epoch = epoch | |||
self.best_dev_step = step | |||
@@ -1,6 +1,7 @@ | |||
""" | |||
utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 | |||
""" | |||
__all__ = [ | |||
"cache_results", | |||
"seq_len_to_mask", | |||
@@ -12,12 +13,12 @@ import inspect | |||
import os | |||
import warnings | |||
from collections import Counter, namedtuple | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from typing import List | |||
from ._logger import logger | |||
from prettytable import PrettyTable | |||
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | |||
'varargs']) | |||
@@ -25,27 +26,27 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require | |||
class Option(dict): | |||
"""a dict can treat keys as attributes""" | |||
def __getattr__(self, item): | |||
try: | |||
return self.__getitem__(item) | |||
except KeyError: | |||
raise AttributeError(item) | |||
def __setattr__(self, key, value): | |||
if key.startswith('__') and key.endswith('__'): | |||
raise AttributeError(key) | |||
self.__setitem__(key, value) | |||
def __delattr__(self, item): | |||
try: | |||
self.pop(item) | |||
except KeyError: | |||
raise AttributeError(item) | |||
def __getstate__(self): | |||
return self | |||
def __setstate__(self, state): | |||
self.update(state) | |||
@@ -66,8 +67,6 @@ def _prepare_cache_filepath(filepath): | |||
def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
""" | |||
别名::class:`fastNLP.cache_results` :class:`fastNLP.core.uitls.cache_results` | |||
cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: | |||
import time | |||
@@ -114,13 +113,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
:param int _verbose: 是否打印cache的信息。 | |||
:return: | |||
""" | |||
def wrapper_(func): | |||
signature = inspect.signature(func) | |||
for key, _ in signature.parameters.items(): | |||
if key in ('_cache_fp', '_refresh', '_verbose'): | |||
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | |||
def wrapper(*args, **kwargs): | |||
if '_cache_fp' in kwargs: | |||
cache_filepath = kwargs.pop('_cache_fp') | |||
@@ -138,7 +137,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
else: | |||
verbose = _verbose | |||
refresh_flag = True | |||
if cache_filepath is not None and refresh is False: | |||
# load data | |||
if os.path.exists(cache_filepath): | |||
@@ -147,7 +146,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
if verbose == 1: | |||
logger.info("Read cache from {}.".format(cache_filepath)) | |||
refresh_flag = False | |||
if refresh_flag: | |||
results = func(*args, **kwargs) | |||
if cache_filepath is not None: | |||
@@ -157,11 +156,11 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
with open(cache_filepath, 'wb') as f: | |||
_pickle.dump(results, f) | |||
logger.info("Save cache to {}.".format(cache_filepath)) | |||
return results | |||
return wrapper | |||
return wrapper_ | |||
@@ -189,6 +188,7 @@ def _save_model(model, model_name, save_dir, only_param=False): | |||
torch.save(model, model_path) | |||
model.to(_model_device) | |||
def _move_model_to_device(model, device): | |||
""" | |||
将model移动到device | |||
@@ -213,7 +213,7 @@ def _move_model_to_device(model, device): | |||
""" | |||
# if isinstance(model, torch.nn.parallel.DistributedDataParallel): | |||
# raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") | |||
if device is None: | |||
if isinstance(model, torch.nn.DataParallel): | |||
model.cuda() | |||
@@ -222,10 +222,10 @@ def _move_model_to_device(model, device): | |||
if not torch.cuda.is_available() and ( | |||
device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): | |||
raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.") | |||
if isinstance(model, torch.nn.DataParallel): | |||
raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") | |||
if isinstance(device, int): | |||
assert device > -1, "device can only be non-negative integer" | |||
assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format( | |||
@@ -269,7 +269,7 @@ def _get_model_device(model): | |||
""" | |||
# TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding. 或者跨显卡 | |||
assert isinstance(model, nn.Module) | |||
parameters = list(model.parameters()) | |||
if len(parameters) == 0: | |||
return None | |||
@@ -429,10 +429,10 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): | |||
""" | |||
if not torch.cuda.is_available(): | |||
return | |||
if not isinstance(device, torch.device): | |||
raise TypeError(f"device must be `torch.device`, got `{type(device)}`") | |||
for arg in args: | |||
if isinstance(arg, dict): | |||
for key, value in arg.items(): | |||
@@ -447,10 +447,10 @@ class _CheckError(Exception): | |||
_CheckError. Used in losses.LossBase, metrics.MetricBase. | |||
""" | |||
def __init__(self, check_res: _CheckRes, func_signature: str): | |||
errs = [f'Problems occurred when calling `{func_signature}`'] | |||
if check_res.varargs: | |||
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") | |||
if check_res.missing: | |||
@@ -459,9 +459,9 @@ class _CheckError(Exception): | |||
errs.append(f"\tduplicated param: {check_res.duplicated}") | |||
if check_res.unused: | |||
errs.append(f"\tunused param: {check_res.unused}") | |||
Exception.__init__(self, '\n'.join(errs)) | |||
self.check_res = check_res | |||
self.func_signature = func_signature | |||
@@ -481,7 +481,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
# if check_res.varargs: | |||
# errs.append(f"\tvarargs: *{check_res.varargs}") | |||
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||
if check_res.unused: | |||
for _unused in check_res.unused: | |||
if _unused in target_dict: | |||
@@ -492,7 +492,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
unuseds.append(f"\tunused field: {_unused_field}") | |||
if _unused_param: | |||
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward | |||
module_name = func_signature.split('.')[0] | |||
if check_res.missing: | |||
errs.append(f"\tmissing param: {check_res.missing}") | |||
@@ -513,7 +513,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
mapped_missing.append(_miss) | |||
else: | |||
unmapped_missing.append(_miss) | |||
for _miss in mapped_missing + unmapped_missing: | |||
if _miss in dataset: | |||
suggestions.append(f"Set `{_miss}` as target.") | |||
@@ -526,17 +526,17 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
else: | |||
_tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.' | |||
suggestions.append(_tmp) | |||
if check_res.duplicated: | |||
errs.append(f"\tduplicated param: {check_res.duplicated}.") | |||
suggestions.append(f"Delete {check_res.duplicated} in the output of " | |||
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") | |||
if len(errs) > 0: | |||
errs.extend(unuseds) | |||
elif check_level == STRICT_CHECK_LEVEL: | |||
errs.extend(unuseds) | |||
if len(errs) > 0: | |||
errs.insert(0, f'Problems occurred when calling {func_signature}') | |||
sugg_str = "" | |||
@@ -563,11 +563,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||
check_res = _check_arg_dict_list(forward_func, batch_x) | |||
func_signature = _get_func_signature(forward_func) | |||
errs = [] | |||
suggestions = [] | |||
_unused = [] | |||
# if check_res.varargs: | |||
# errs.append(f"\tvarargs: {check_res.varargs}") | |||
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||
@@ -588,14 +588,14 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||
# _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ | |||
# f"rename the field in `unused field:`." | |||
suggestions.append(_tmp) | |||
if check_res.unused: | |||
_unused = [f"\tunused field: {check_res.unused}"] | |||
if len(errs) > 0: | |||
errs.extend(_unused) | |||
elif check_level == STRICT_CHECK_LEVEL: | |||
errs.extend(_unused) | |||
if len(errs) > 0: | |||
errs.insert(0, f'Problems occurred when calling {func_signature}') | |||
sugg_str = "" | |||
@@ -643,7 +643,7 @@ def seq_len_to_mask(seq_len, max_len=None): | |||
max_len = int(max_len) if max_len else int(seq_len.max()) | |||
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) | |||
mask = broad_cast_seq_len < seq_len.reshape(-1, 1) | |||
elif isinstance(seq_len, torch.Tensor): | |||
assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." | |||
batch_size = seq_len.size(0) | |||
@@ -652,7 +652,7 @@ def seq_len_to_mask(seq_len, max_len=None): | |||
mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) | |||
else: | |||
raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.") | |||
return mask | |||
@@ -660,24 +660,25 @@ class _pseudo_tqdm: | |||
""" | |||
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | |||
""" | |||
def __init__(self, **kwargs): | |||
self.logger = logger | |||
def write(self, info): | |||
self.logger.info(info) | |||
def set_postfix_str(self, info): | |||
self.logger.info(info) | |||
def __getattr__(self, item): | |||
def pass_func(*args, **kwargs): | |||
pass | |||
return pass_func | |||
def __enter__(self): | |||
return self | |||
def __exit__(self, exc_type, exc_val, exc_tb): | |||
del self | |||
@@ -751,3 +752,56 @@ def get_seq_len(words, pad_value=0): | |||
""" | |||
mask = words.ne(pad_value) | |||
return mask.sum(dim=-1) | |||
def pretty_table_printer(dataset_or_ins) -> PrettyTable: | |||
""" | |||
:param dataset_or_ins: 传入一个dataSet或者instance | |||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) | |||
+-----------+-----------+-----------------+ | |||
| field_1 | field_2 | field_3 | | |||
+-----------+-----------+-----------------+ | |||
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] | | |||
+-----------+-----------+-----------------+ | |||
:return: 以 pretty table的形式返回根据terminal大小进行自动截断 | |||
""" | |||
x = PrettyTable() | |||
try: | |||
sz = os.get_terminal_size() | |||
column = sz.columns | |||
row = sz.lines | |||
except OSError: | |||
column = 144 | |||
row = 11 | |||
if type(dataset_or_ins).__name__ == "DataSet": | |||
x.field_names = list(dataset_or_ins.field_arrays.keys()) | |||
c_size = len(x.field_names) | |||
for ins in dataset_or_ins: | |||
x.add_row([sub_column(ins[k], column, c_size, k) for k in x.field_names]) | |||
row -= 1 | |||
if row < 0: | |||
x.add_row(["..." for _ in range(c_size)]) | |||
break | |||
elif type(dataset_or_ins).__name__ == "Instance": | |||
x.field_names = list(dataset_or_ins.fields.keys()) | |||
c_size = len(x.field_names) | |||
x.add_row([sub_column(dataset_or_ins[k], column, c_size, k) for k in x.field_names]) | |||
else: | |||
raise Exception("only accept DataSet and Instance") | |||
return x | |||
def sub_column(string: str, c: int, c_size: int, title: str) -> str: | |||
""" | |||
:param string: 要被截断的字符串 | |||
:param c: 命令行列数 | |||
:param c_size: instance或dataset field数 | |||
:param title: 列名 | |||
:return: 对一个过长的列进行截断的结果 | |||
""" | |||
avg = max(int(c / c_size), len(title)) | |||
string = str(string) | |||
if len(string) > avg: | |||
string = string[:(avg - 3)] + "..." | |||
return string |
@@ -39,7 +39,7 @@ def _check_build_vocab(func): | |||
@wraps(func) # to solve missing docstring | |||
def _wrapper(self, *args, **kwargs): | |||
if self.word2idx is None or self.rebuild is True: | |||
if self._word2idx is None or self.rebuild is True: | |||
self.build_vocab() | |||
return func(self, *args, **kwargs) | |||
@@ -66,8 +66,6 @@ def _check_build_status(func): | |||
class Vocabulary(object): | |||
""" | |||
别名::class:`fastNLP.Vocabulary` :class:`fastNLP.core.vocabulary.Vocabulary` | |||
用于构建, 存储和使用 `str` 到 `int` 的一一映射:: | |||
vocab = Vocabulary() | |||
@@ -75,32 +73,52 @@ class Vocabulary(object): | |||
vocab.update(word_list) | |||
vocab["word"] # str to int | |||
vocab.to_word(5) # int to str | |||
:param int max_size: `Vocabulary` 的最大大小, 即能存储词的最大数量 | |||
若为 ``None`` , 则不限制大小. Default: ``None`` | |||
:param int min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1. | |||
若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录. Default: ``None`` | |||
:param str optional padding: padding的字符. 如果设置为 ``None`` , | |||
则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
Default: '<pad>' | |||
:param str optional unknown: unknown的字符,所有未被记录的词在转为 `int` 时将被视为unknown. | |||
如果设置为 ``None`` ,则vocabulary中不考虑unknow, 也不计入词表大小. | |||
为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
Default: '<unk>' | |||
""" | |||
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | |||
""" | |||
:param int max_size: `Vocabulary` 的最大大小, 即能存储词的最大数量 | |||
若为 ``None`` , 则不限制大小. Default: ``None`` | |||
:param int min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1. | |||
若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录. Default: ``None`` | |||
:param str optional padding: padding的字符. 如果设置为 ``None`` , | |||
则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
Default: '<pad>' | |||
:param str optional unknown: unknown的字符,所有未被记录的词在转为 `int` 时将被视为unknown. | |||
如果设置为 ``None`` ,则vocabulary中不考虑unknow, 也不计入词表大小. | |||
为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
Default: '<unk>' | |||
""" | |||
self.max_size = max_size | |||
self.min_freq = min_freq | |||
self.word_count = Counter() | |||
self.unknown = unknown | |||
self.padding = padding | |||
self.word2idx = None | |||
self.idx2word = None | |||
self._word2idx = None | |||
self._idx2word = None | |||
self.rebuild = True | |||
# 用于承载不需要单独创建entry的词语,具体见from_dataset()方法 | |||
self._no_create_word = Counter() | |||
@property | |||
@_check_build_vocab | |||
def word2idx(self): | |||
return self._word2idx | |||
@word2idx.setter | |||
def word2idx(self, value): | |||
self._word2idx = value | |||
@property | |||
@_check_build_vocab | |||
def idx2word(self): | |||
return self._idx2word | |||
@idx2word.setter | |||
def idx2word(self, value): | |||
self._word2idx = value | |||
@_check_build_status | |||
def update(self, word_lst, no_create_entry=False): | |||
"""依次增加序列中词在词典中的出现频率 | |||
@@ -187,21 +205,21 @@ class Vocabulary(object): | |||
但已经记录在词典中的词, 不会改变对应的 `int` | |||
""" | |||
if self.word2idx is None: | |||
self.word2idx = {} | |||
if self._word2idx is None: | |||
self._word2idx = {} | |||
if self.padding is not None: | |||
self.word2idx[self.padding] = len(self.word2idx) | |||
self._word2idx[self.padding] = len(self._word2idx) | |||
if self.unknown is not None: | |||
self.word2idx[self.unknown] = len(self.word2idx) | |||
self._word2idx[self.unknown] = len(self._word2idx) | |||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | |||
words = self.word_count.most_common(max_size) | |||
if self.min_freq is not None: | |||
words = filter(lambda kv: kv[1] >= self.min_freq, words) | |||
if self.word2idx is not None: | |||
words = filter(lambda kv: kv[0] not in self.word2idx, words) | |||
start_idx = len(self.word2idx) | |||
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
if self._word2idx is not None: | |||
words = filter(lambda kv: kv[0] not in self._word2idx, words) | |||
start_idx = len(self._word2idx) | |||
self._word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
self.build_reverse_vocab() | |||
self.rebuild = False | |||
return self | |||
@@ -211,12 +229,12 @@ class Vocabulary(object): | |||
基于 `word to index` dict, 构建 `index to word` dict. | |||
""" | |||
self.idx2word = {i: w for w, i in self.word2idx.items()} | |||
self._idx2word = {i: w for w, i in self._word2idx.items()} | |||
return self | |||
@_check_build_vocab | |||
def __len__(self): | |||
return len(self.word2idx) | |||
return len(self._word2idx) | |||
@_check_build_vocab | |||
def __contains__(self, item): | |||
@@ -226,7 +244,7 @@ class Vocabulary(object): | |||
:param item: the word | |||
:return: True or False | |||
""" | |||
return item in self.word2idx | |||
return item in self._word2idx | |||
def has_word(self, w): | |||
""" | |||
@@ -248,12 +266,12 @@ class Vocabulary(object): | |||
vocab[w] | |||
""" | |||
if w in self.word2idx: | |||
return self.word2idx[w] | |||
if w in self._word2idx: | |||
return self._word2idx[w] | |||
if self.unknown is not None: | |||
return self.word2idx[self.unknown] | |||
return self._word2idx[self.unknown] | |||
else: | |||
raise ValueError("word {} not in vocabulary".format(w)) | |||
raise ValueError("word `{}` not in vocabulary".format(w)) | |||
@_check_build_vocab | |||
def index_dataset(self, *datasets, field_name, new_field_name=None): | |||
@@ -360,7 +378,7 @@ class Vocabulary(object): | |||
try: | |||
dataset.apply(construct_vocab) | |||
except BaseException as e: | |||
log("When processing the `{}` dataset, the following error occurred:".format(idx)) | |||
logger.error("When processing the `{}` dataset, the following error occurred:".format(idx)) | |||
raise e | |||
else: | |||
raise TypeError("Only DataSet type is allowed.") | |||
@@ -386,7 +404,7 @@ class Vocabulary(object): | |||
def to_index(self, w): | |||
""" | |||
将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出``ValueError``:: | |||
将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 ``ValueError`` :: | |||
index = vocab.to_index('abc') | |||
# equals to | |||
@@ -405,7 +423,7 @@ class Vocabulary(object): | |||
""" | |||
if self.unknown is None: | |||
return None | |||
return self.word2idx[self.unknown] | |||
return self._word2idx[self.unknown] | |||
@property | |||
@_check_build_vocab | |||
@@ -415,7 +433,7 @@ class Vocabulary(object): | |||
""" | |||
if self.padding is None: | |||
return None | |||
return self.word2idx[self.padding] | |||
return self._word2idx[self.padding] | |||
@_check_build_vocab | |||
def to_word(self, idx): | |||
@@ -425,7 +443,7 @@ class Vocabulary(object): | |||
:param int idx: the index | |||
:return str word: the word | |||
""" | |||
return self.idx2word[idx] | |||
return self._idx2word[idx] | |||
def clear(self): | |||
""" | |||
@@ -434,8 +452,8 @@ class Vocabulary(object): | |||
:return: | |||
""" | |||
self.word_count.clear() | |||
self.word2idx = None | |||
self.idx2word = None | |||
self._word2idx = None | |||
self._idx2word = None | |||
self.rebuild = True | |||
self._no_create_word.clear() | |||
return self | |||
@@ -446,8 +464,8 @@ class Vocabulary(object): | |||
""" | |||
len(self) # make sure vocab has been built | |||
state = self.__dict__.copy() | |||
# no need to pickle idx2word as it can be constructed from word2idx | |||
del state['idx2word'] | |||
# no need to pickle _idx2word as it can be constructed from _word2idx | |||
del state['_idx2word'] | |||
return state | |||
def __setstate__(self, state): | |||
@@ -462,5 +480,5 @@ class Vocabulary(object): | |||
@_check_build_vocab | |||
def __iter__(self): | |||
for word, index in self.word2idx.items(): | |||
for word, index in self._word2idx.items(): | |||
yield word, index |
@@ -0,0 +1,27 @@ | |||
"""undocumented""" | |||
__all__ = [] | |||
import inspect | |||
import sys | |||
def doc_process(m): | |||
for name, obj in inspect.getmembers(m): | |||
if inspect.isclass(obj) or inspect.isfunction(obj): | |||
if obj.__module__ != m.__name__: | |||
if obj.__doc__ is None: | |||
# print(name, obj.__doc__) | |||
pass | |||
else: | |||
module_name = obj.__module__ | |||
while 1: | |||
defined_m = sys.modules[module_name] | |||
if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__: | |||
obj.__doc__ = r"别名 :class:`" + m.__name__ + "." + name + "`" \ | |||
+ " :class:`" + module_name + "." + name + "`\n" + obj.__doc__ | |||
break | |||
module_name = ".".join(module_name.split('.')[:-1]) | |||
if module_name == m.__name__: | |||
# print(name, ": not found defined doc.") | |||
break |
@@ -25,3 +25,7 @@ from .bert_embedding import BertEmbedding, BertWordPieceEncoder | |||
from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding | |||
from .stack_embedding import StackEmbedding | |||
from .utils import get_embeddings | |||
import sys | |||
from ..doc_utils import doc_process | |||
doc_process(sys.modules[__name__]) |
@@ -17,7 +17,7 @@ import numpy as np | |||
from itertools import chain | |||
from ..core.vocabulary import Vocabulary | |||
from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR | |||
from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer | |||
from .contextual_embedding import ContextualEmbedding | |||
import warnings | |||
@@ -26,8 +26,6 @@ from ..core import logger | |||
class BertEmbedding(ContextualEmbedding): | |||
""" | |||
别名::class:`fastNLP.embeddings.BertEmbedding` :class:`fastNLP.embeddings.bert_embedding.BertEmbedding` | |||
使用BERT对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于 | |||
预训练的bert模型长度限制为512个token,而因为输入的word是未进行word piece分割的(word piece的分割有BertEmbedding在输入word | |||
时切分),在分割之后长度可能会超过最大长度限制。 | |||
@@ -68,27 +66,21 @@ class BertEmbedding(ContextualEmbedding): | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', | |||
pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False, | |||
pooled_cls=True, requires_grad: bool = False, auto_truncate: bool = False): | |||
pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False): | |||
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
# 根据model_dir_or_name检查是否存在并下载 | |||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'): | |||
logger.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve" | |||
" faster speed.") | |||
warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve" | |||
" faster speed.") | |||
model_url = _get_embedding_url('bert', model_dir_or_name.lower()) | |||
model_dir = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||
model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
self._word_sep_index = None | |||
if '[SEP]' in vocab: | |||
self._word_sep_index = vocab['[SEP]'] | |||
self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers, | |||
self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | |||
pool_method=pool_method, include_cls_sep=include_cls_sep, | |||
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2) | |||
@@ -134,27 +126,6 @@ class BertEmbedding(ContextualEmbedding): | |||
if self._word_sep_index: | |||
words.masked_fill_(sep_mask, self._word_sep_index) | |||
return words | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | |||
:return: | |||
""" | |||
requires_grads = set([param.requires_grad for name, param in self.named_parameters() | |||
if 'word_pieces_lengths' not in name]) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
if 'word_pieces_lengths' in name: # 这个不能加入到requires_grad中 | |||
continue | |||
param.requires_grad = value | |||
class BertWordPieceEncoder(nn.Module): | |||
@@ -171,19 +142,10 @@ class BertWordPieceEncoder(nn.Module): | |||
""" | |||
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, | |||
word_dropout=0, dropout=0, requires_grad: bool = False): | |||
word_dropout=0, dropout=0, requires_grad: bool = True): | |||
super().__init__() | |||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
model_url = _get_embedding_url('bert', model_dir_or_name.lower()) | |||
model_dir = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_dir = model_dir_or_name | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls) | |||
self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls) | |||
self._sep_index = self.model._sep_index | |||
self._wordpiece_pad_index = self.model._wordpiece_pad_index | |||
self._wordpiece_unk_index = self.model._wordpiece_unknown_index | |||
@@ -192,23 +154,6 @@ class BertWordPieceEncoder(nn.Module): | |||
self.word_dropout = word_dropout | |||
self.dropout_layer = nn.Dropout(dropout) | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | |||
:return: | |||
""" | |||
requires_grads = set([param.requires_grad for name, param in self.named_parameters()]) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
param.requires_grad = value | |||
@property | |||
def embed_size(self): | |||
return self._embed_size | |||
@@ -278,12 +223,12 @@ class BertWordPieceEncoder(nn.Module): | |||
class _WordBertModel(nn.Module): | |||
def __init__(self, model_dir: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', | |||
def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', | |||
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2): | |||
super().__init__() | |||
self.tokenzier = BertTokenizer.from_pretrained(model_dir) | |||
self.encoder = BertModel.from_pretrained(model_dir) | |||
self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name) | |||
self.encoder = BertModel.from_pretrained(model_dir_or_name) | |||
self._max_position_embeddings = self.encoder.config.max_position_embeddings | |||
# 检查encoder_layer_number是否合理 | |||
encoder_layer_number = len(self.encoder.encoder.layer) | |||
@@ -303,7 +248,7 @@ class _WordBertModel(nn.Module): | |||
self.auto_truncate = auto_truncate | |||
# 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP] | |||
logger.info("Start to generating word pieces for word.") | |||
logger.info("Start to generate word pieces for word.") | |||
# 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 | |||
word_piece_dict = {'[CLS]': 1, '[SEP]': 1} # 用到的word_piece以及新增的 | |||
found_count = 0 | |||
@@ -364,7 +309,7 @@ class _WordBertModel(nn.Module): | |||
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # 需要用于生成word_piece | |||
logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab))) | |||
self.word_to_wordpieces = np.array(word_to_wordpieces) | |||
self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False) | |||
self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths)) | |||
logger.debug("Successfully generate word pieces.") | |||
def forward(self, words): | |||
@@ -389,7 +334,8 @@ class _WordBertModel(nn.Module): | |||
else: | |||
raise RuntimeError( | |||
"After split words into word pieces, the lengths of word pieces are longer than the " | |||
f"maximum allowed sequence length:{self._max_position_embeddings} of bert.") | |||
f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set " | |||
f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.") | |||
# +2是由于需要加入[CLS]与[SEP] | |||
word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)), | |||
@@ -408,7 +354,7 @@ class _WordBertModel(nn.Module): | |||
batch_indexes = torch.arange(batch_size).to(words) | |||
word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index | |||
if self._has_sep_in_vocab: # 但[SEP]在vocab中出现应该才会需要token_ids | |||
sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len | |||
sep_mask = word_pieces.eq(self._sep_index).long() # batch_size x max_len | |||
sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) | |||
token_type_ids = sep_mask_cumsum.fmod(2) | |||
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 | |||
@@ -422,15 +368,26 @@ class _WordBertModel(nn.Module): | |||
# output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size | |||
if self.include_cls_sep: | |||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, | |||
bert_outputs[-1].size(-1)) | |||
s_shift = 1 | |||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, | |||
bert_outputs[-1].size(-1)) | |||
else: | |||
s_shift = 0 | |||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len, | |||
bert_outputs[-1].size(-1)) | |||
s_shift = 0 | |||
batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1) | |||
batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len | |||
if self.pool_method == 'first': | |||
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()] | |||
batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) | |||
_batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||
elif self.pool_method == 'last': | |||
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max()+1] - 1 | |||
batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) | |||
_batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||
for l_index, l in enumerate(self.layers): | |||
output_layer = bert_outputs[l] | |||
real_word_piece_length = output_layer.size(1) - 2 | |||
@@ -441,16 +398,15 @@ class _WordBertModel(nn.Module): | |||
output_layer = torch.cat((output_layer, paddings), dim=1).contiguous() | |||
# 从word_piece collapse到word的表示 | |||
truncate_output_layer = output_layer[:, 1:-1] # 删除[CLS]与[SEP] batch_size x len x hidden_size | |||
outputs_seq_len = seq_len + s_shift | |||
if self.pool_method == 'first': | |||
for i in range(batch_size): | |||
i_word_pieces_cum_length = batch_word_pieces_cum_length[i, :seq_len[i]] # 每个word的start位置 | |||
outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[ | |||
i, i_word_pieces_cum_length] # num_layer x batch_size x len x hidden_size | |||
tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length] | |||
tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) | |||
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp | |||
elif self.pool_method == 'last': | |||
for i in range(batch_size): | |||
i_word_pieces_cum_length = batch_word_pieces_cum_length[i, 1:seq_len[i] + 1] - 1 # 每个word的end | |||
outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[i, i_word_pieces_cum_length] | |||
tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length] | |||
tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) | |||
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp | |||
elif self.pool_method == 'max': | |||
for i in range(batch_size): | |||
for j in range(seq_len[i]): | |||
@@ -467,5 +423,6 @@ class _WordBertModel(nn.Module): | |||
else: | |||
outputs[l_index, :, 0] = output_layer[:, 0] | |||
outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift] | |||
# 3. 最终的embedding结果 | |||
return outputs |
@@ -24,8 +24,6 @@ from ..core import logger | |||
class CNNCharEmbedding(TokenEmbedding): | |||
""" | |||
别名::class:`fastNLP.embeddings.CNNCharEmbedding` :class:`fastNLP.embeddings.char_embedding.CNNCharEmbedding` | |||
使用CNN生成character embedding。CNN的结构为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout. | |||
不同的kernel大小的fitler结果是concat起来然后通过一层fully connected layer, 然后输出word的表示。 | |||
@@ -89,10 +87,9 @@ class CNNCharEmbedding(TokenEmbedding): | |||
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.") | |||
# 对vocab进行index | |||
max_word_len = max(map(lambda x: len(x[0]), vocab)) | |||
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), max_word_len), | |||
fill_value=self.char_pad_index, dtype=torch.long), | |||
requires_grad=False) | |||
self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False) | |||
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len), | |||
fill_value=self.char_pad_index, dtype=torch.long)) | |||
self.register_buffer('word_lengths', torch.zeros(len(vocab)).long()) | |||
for word, index in vocab: | |||
# if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了。修改为不区分pad, 这样所有的<pad>也是同一个embed | |||
self.words_to_chars_embedding[index, :len(word)] = \ | |||
@@ -109,8 +106,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||
for i in range(len(kernel_sizes))]) | |||
self._embed_size = embed_size | |||
self.fc = nn.Linear(sum(filter_nums), embed_size) | |||
self.reset_parameters() | |||
def forward(self, words): | |||
""" | |||
输入words的index后,生成对应的words的表示。 | |||
@@ -142,46 +138,10 @@ class CNNCharEmbedding(TokenEmbedding): | |||
chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float() | |||
chars = self.fc(chars) | |||
return self.dropout(chars) | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | |||
:return: | |||
""" | |||
params = [] | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' not in name and 'word_lengths' not in name: | |||
params.append(param.requires_grad) | |||
requires_grads = set(params) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中 | |||
continue | |||
param.requires_grad = value | |||
def reset_parameters(self): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能reset | |||
continue | |||
if 'char_embedding' in name: | |||
continue | |||
if param.data.dim() > 1: | |||
nn.init.xavier_uniform_(param, 1) | |||
else: | |||
nn.init.uniform_(param, -1, 1) | |||
class LSTMCharEmbedding(TokenEmbedding): | |||
""" | |||
别名::class:`fastNLP.embeddings.LSTMCharEmbedding` :class:`fastNLP.embeddings.char_embedding.LSTMCharEmbedding` | |||
使用LSTM的方式对character进行encode. embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool -> Dropout | |||
Example:: | |||
@@ -244,10 +204,9 @@ class LSTMCharEmbedding(TokenEmbedding): | |||
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.") | |||
# 对vocab进行index | |||
self.max_word_len = max(map(lambda x: len(x[0]), vocab)) | |||
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len), | |||
fill_value=self.char_pad_index, dtype=torch.long), | |||
requires_grad=False) | |||
self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False) | |||
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), self.max_word_len), | |||
fill_value=self.char_pad_index, dtype=torch.long)) | |||
self.register_buffer('word_lengths', torch.zeros(len(vocab)).long()) | |||
for word, index in vocab: | |||
# if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了. 修改为不区分pad与否 | |||
self.words_to_chars_embedding[index, :len(word)] = \ | |||
@@ -299,27 +258,3 @@ class LSTMCharEmbedding(TokenEmbedding): | |||
chars = self.fc(chars) | |||
return self.dropout(chars) | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | |||
:return: | |||
""" | |||
params = [] | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' not in name and 'word_lengths' not in name: | |||
params.append(param) | |||
requires_grads = set(params) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中 | |||
continue | |||
param.requires_grad = value |
@@ -22,8 +22,6 @@ from ..core import logger | |||
class ElmoEmbedding(ContextualEmbedding): | |||
""" | |||
别名::class:`fastNLP.embeddings.ElmoEmbedding` :class:`fastNLP.embeddings.elmo_embedding.ElmoEmbedding` | |||
使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。当前支持的使用名称初始化的模型有以下的这些(待补充) | |||
Example:: | |||
@@ -57,7 +55,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
并删除character encoder,之后将直接使用cache的embedding。默认为False。 | |||
""" | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str = '2', requires_grad: bool = False, | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str = '2', requires_grad: bool = True, | |||
word_dropout=0.0, dropout=0.0, cache_word_reprs: bool = False): | |||
super(ElmoEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
@@ -71,6 +69,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs) | |||
num_layers = self.model.encoder.num_layers | |||
if layers == 'mix': | |||
self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers'] + 1), | |||
@@ -80,9 +79,9 @@ class ElmoEmbedding(ContextualEmbedding): | |||
self._embed_size = self.model.config['lstm']['projection_dim'] * 2 | |||
else: | |||
layers = list(map(int, layers.split(','))) | |||
assert len(layers) > 0, "Must choose one output" | |||
assert len(layers) > 0, "Must choose at least one output, but got None." | |||
for layer in layers: | |||
assert 0 <= layer <= 2, "Layer index should be in range [0, 2]." | |||
assert 0 <= layer <= num_layers, f"Layer index should be in range [0, {num_layers}], but got {layer}." | |||
self.layers = layers | |||
self._get_outputs = self._get_layer_outputs | |||
self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2 | |||
@@ -137,27 +136,6 @@ class ElmoEmbedding(ContextualEmbedding): | |||
for name in ['layers', 'model', 'layer_weights', 'gamma']: | |||
if hasattr(self, name): | |||
delattr(self, name) | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | |||
:return: | |||
""" | |||
requires_grads = set([param.requires_grad for name, param in self.named_parameters() | |||
if 'words_to_chars_embedding' not in name and 'words_to_words' not in name]) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' in name or 'words_to_words' in name: # 这个不能加入到requires_grad中 | |||
continue | |||
param.requires_grad = value | |||
class _ElmoModel(nn.Module): | |||
@@ -246,11 +224,9 @@ class _ElmoModel(nn.Module): | |||
logger.info(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") | |||
# 生成words到chars的映射 | |||
max_chars = config['char_cnn']['max_characters_per_token'] | |||
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab) + 2, max_chars), | |||
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab) + 2, max_chars), | |||
fill_value=len(char_vocab), | |||
dtype=torch.long), | |||
requires_grad=False) | |||
dtype=torch.long)) | |||
for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab) + 1)]: | |||
if len(word) + 2 > max_chars: | |||
word = word[:max_chars - 2] | |||
@@ -17,8 +17,6 @@ from .utils import get_embeddings | |||
class Embedding(nn.Module): | |||
""" | |||
别名::class:`fastNLP.embeddings.Embedding` :class:`fastNLP.embeddings.embedding.Embedding` | |||
词向量嵌入,支持输入多种方式初始化. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度. | |||
Example:: | |||
@@ -117,6 +115,10 @@ class Embedding(nn.Module): | |||
class TokenEmbedding(nn.Module): | |||
""" | |||
fastNLP中各种Embedding的基类 | |||
""" | |||
def __init__(self, vocab, word_dropout=0.0, dropout=0.0): | |||
super(TokenEmbedding, self).__init__() | |||
if vocab.rebuild: | |||
@@ -17,17 +17,16 @@ from .embedding import TokenEmbedding | |||
class StackEmbedding(TokenEmbedding): | |||
""" | |||
别名::class:`fastNLP.embeddings.StackEmbedding` :class:`fastNLP.embeddings.stack_embedding.StackEmbedding` | |||
支持将多个embedding集合成一个embedding。 | |||
Example:: | |||
>>> from fastNLP import Vocabulary | |||
>>> from fastNLP.embeddings import StaticEmbedding | |||
>>> from fastNLP.embeddings import StaticEmbedding, StackEmbedding | |||
>>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) | |||
>>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', requires_grad=True) | |||
>>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True) | |||
>>> embed = StackEmbedding([embed_1, embed_2]) | |||
:param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致 | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。不同embedidng会在相同的位置 | |||
@@ -59,35 +58,26 @@ class StackEmbedding(TokenEmbedding): | |||
:return: | |||
""" | |||
assert isinstance(embed, TokenEmbedding) | |||
self._embed_size += embed.embed_size | |||
self.embeds.append(embed) | |||
return self | |||
def pop(self): | |||
""" | |||
弹出最后一个embed | |||
:return: | |||
""" | |||
return self.embeds.pop() | |||
embed = self.embeds.pop() | |||
self._embed_size -= embed.embed_size | |||
return embed | |||
@property | |||
def embed_size(self): | |||
return self._embed_size | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | |||
该Embedding输出的vector的最后一维的维度。 | |||
:return: | |||
""" | |||
requires_grads = set([embed.requires_grad for embed in self.embeds()]) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for embed in self.embeds(): | |||
embed.requires_grad = value | |||
return self._embed_size | |||
def forward(self, words): | |||
""" | |||
@@ -24,8 +24,6 @@ from ..core import logger | |||
class StaticEmbedding(TokenEmbedding): | |||
""" | |||
别名::class:`fastNLP.embeddings.StaticEmbedding` :class:`fastNLP.embeddings.static_embedding.StaticEmbedding` | |||
StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的数据(只会将出现在vocab中的词抽取出来, | |||
如果没有找到,则会随机初始化一个值(但如果该word是被标记为no_create_entry的话,则不会单独创建一个值,而是会被指向unk的index))。 | |||
当前支持自动下载的预训练vector有以下的几种(待补充); | |||
@@ -56,13 +54,16 @@ class StaticEmbedding(TokenEmbedding): | |||
如果输入为None则使用embedding_dim的维度随机初始化一个embedding。 | |||
:param int embedding_dim: 随机初始化的embedding的维度,当该值为大于0的值时,将忽略model_dir_or_name。 | |||
:param bool requires_grad: 是否需要gradient. 默认为True | |||
:param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对 | |||
:param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法, 传入的方法应该接受一个tensor,并 | |||
inplace地修改其值。 | |||
:param bool lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独 | |||
为大写的词语开辟一个vector表示,则将lower设置为False。 | |||
:param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
:param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。 | |||
:param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。 | |||
:param dict **kwarngs: only_train_min_freq, 仅对train中的词语使用min_freq筛选; only_norm_found_vector是否仅对在预训练中 | |||
找到的词语使用normalize。 | |||
""" | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True, | |||
@@ -131,28 +132,27 @@ class StaticEmbedding(TokenEmbedding): | |||
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method) | |||
else: | |||
embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method) | |||
self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False) | |||
self.register_buffer('words_to_words', torch.arange(len(vocab)).long()) | |||
if lowered_vocab.unknown: | |||
unknown_idx = lowered_vocab.unknown_idx | |||
else: | |||
unknown_idx = embedding.size(0) - 1 # 否则是最后一个为unknow | |||
self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False) | |||
words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), | |||
requires_grad=False) | |||
self.register_buffer('words_to_words', torch.arange(len(vocab)).long()) | |||
words_to_words = torch.full((len(vocab),), fill_value=unknown_idx).long() | |||
for word, index in vocab: | |||
if word not in lowered_vocab: | |||
word = word.lower() | |||
if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word): | |||
continue # 如果不需要创建entry,已经默认unknown了 | |||
words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)] | |||
self.words_to_words = words_to_words | |||
self.register_buffer('words_to_words', words_to_words) | |||
self._word_unk_index = lowered_vocab.unknown_idx # 替换一下unknown的index | |||
else: | |||
if model_path: | |||
embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method) | |||
else: | |||
embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method) | |||
self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False) | |||
self.register_buffer('words_to_words', torch.arange(len(vocab)).long()) | |||
if not self.only_norm_found_vector and normalize: | |||
embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12) | |||
@@ -161,8 +161,7 @@ class StaticEmbedding(TokenEmbedding): | |||
index_in_truncated_vocab = truncated_words_to_words[i] | |||
truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab] | |||
del self.words_to_words | |||
self.words_to_words = nn.Parameter(truncated_words_to_words, requires_grad=False) | |||
self.register_buffer('words_to_words', truncated_words_to_words) | |||
self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], | |||
padding_idx=vocab.padding_idx, | |||
max_norm=None, norm_type=2, scale_grad_by_freq=False, | |||
@@ -187,27 +186,6 @@ class StaticEmbedding(TokenEmbedding): | |||
return embed | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | |||
:return: | |||
""" | |||
requires_grads = set([param.requires_grad for name, param in self.named_parameters() | |||
if 'words_to_words' not in name]) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_words' in name: | |||
continue | |||
param.requires_grad = value | |||
def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', | |||
error='ignore', init_method=None): | |||
""" | |||
@@ -283,9 +261,7 @@ class StaticEmbedding(TokenEmbedding): | |||
vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous() | |||
else: | |||
unknown_idx = vocab.unknown_idx | |||
self.words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), | |||
requires_grad=False) | |||
self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx).long()) | |||
for index, (index_in_vocab, vec) in enumerate(matrix.items()): | |||
if vec is not None: | |||
vectors[index] = vec | |||
@@ -24,6 +24,7 @@ __all__ = [ | |||
'IMDBLoader', | |||
'SSTLoader', | |||
'SST2Loader', | |||
"ChnSentiCorpLoader", | |||
'ConllLoader', | |||
'Conll2003Loader', | |||
@@ -52,8 +53,9 @@ __all__ = [ | |||
"SSTPipe", | |||
"SST2Pipe", | |||
"IMDBPipe", | |||
"Conll2003Pipe", | |||
"ChnSentiCorpPipe", | |||
"Conll2003Pipe", | |||
"Conll2003NERPipe", | |||
"OntoNotesNERPipe", | |||
"MsraNERPipe", | |||
@@ -82,8 +84,11 @@ __all__ = [ | |||
from .embed_loader import EmbedLoader | |||
from .data_bundle import DataBundle | |||
from .dataset_loader import CSVLoader, JsonLoader | |||
from .model_io import ModelLoader, ModelSaver | |||
from .loader import * | |||
from .pipe import * | |||
import sys | |||
from ..doc_utils import doc_process | |||
doc_process(sys.modules[__name__]) |
@@ -6,111 +6,9 @@ __all__ = [ | |||
'DataBundle', | |||
] | |||
import _pickle as pickle | |||
import os | |||
from typing import Union, Dict | |||
from ..core.dataset import DataSet | |||
from ..core.vocabulary import Vocabulary | |||
class BaseLoader(object): | |||
""" | |||
各个 Loader 的基类,提供了 API 的参考。 | |||
""" | |||
def __init__(self): | |||
super(BaseLoader, self).__init__() | |||
@staticmethod | |||
def load_lines(data_path): | |||
""" | |||
按行读取,舍弃每行两侧空白字符,返回list of str | |||
:param data_path: 读取数据的路径 | |||
""" | |||
with open(data_path, "r", encoding="utf=8") as f: | |||
text = f.readlines() | |||
return [line.strip() for line in text] | |||
@classmethod | |||
def load(cls, data_path): | |||
""" | |||
先按行读取,去除一行两侧空白,再提取每行的字符。返回list of list of str | |||
:param data_path: | |||
""" | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
text = f.readlines() | |||
return [[word for word in sent.strip()] for sent in text] | |||
@classmethod | |||
def load_with_cache(cls, data_path, cache_path): | |||
"""缓存版的load | |||
""" | |||
if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path): | |||
with open(cache_path, 'rb') as f: | |||
return pickle.load(f) | |||
else: | |||
obj = cls.load(data_path) | |||
with open(cache_path, 'wb') as f: | |||
pickle.dump(obj, f) | |||
return obj | |||
def _download_from_url(url, path): | |||
try: | |||
from tqdm.auto import tqdm | |||
except: | |||
from ..core.utils import _pseudo_tqdm as tqdm | |||
import requests | |||
"""Download file""" | |||
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) | |||
chunk_size = 16 * 1024 | |||
total_size = int(r.headers.get('Content-length', 0)) | |||
with open(path, "wb") as file, \ | |||
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t: | |||
for chunk in r.iter_content(chunk_size): | |||
if chunk: | |||
file.write(chunk) | |||
t.update(len(chunk)) | |||
def _uncompress(src, dst): | |||
import zipfile | |||
import gzip | |||
import tarfile | |||
import os | |||
def unzip(src, dst): | |||
with zipfile.ZipFile(src, 'r') as f: | |||
f.extractall(dst) | |||
def ungz(src, dst): | |||
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: | |||
length = 16 * 1024 # 16KB | |||
buf = f.read(length) | |||
while buf: | |||
uf.write(buf) | |||
buf = f.read(length) | |||
def untar(src, dst): | |||
with tarfile.open(src, 'r:gz') as f: | |||
f.extractall(dst) | |||
fn, ext = os.path.splitext(src) | |||
_, ext_2 = os.path.splitext(fn) | |||
if ext == '.zip': | |||
unzip(src, dst) | |||
elif ext == '.gz' and ext_2 != '.tar': | |||
ungz(src, dst) | |||
elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz': | |||
untar(src, dst) | |||
else: | |||
raise ValueError('unsupported file {}'.format(src)) | |||
from typing import Union | |||
class DataBundle: | |||
""" | |||
@@ -154,7 +52,7 @@ class DataBundle: | |||
self.datasets[name] = dataset | |||
return self | |||
def get_dataset(self, name:str)->DataSet: | |||
def get_dataset(self, name: str) -> DataSet: | |||
""" | |||
获取名为name的dataset | |||
@@ -163,7 +61,7 @@ class DataBundle: | |||
""" | |||
return self.datasets[name] | |||
def delete_dataset(self, name:str): | |||
def delete_dataset(self, name: str): | |||
""" | |||
删除名为name的DataSet | |||
@@ -173,7 +71,7 @@ class DataBundle: | |||
self.datasets.pop(name, None) | |||
return self | |||
def get_vocab(self, field_name:str)->Vocabulary: | |||
def get_vocab(self, field_name: str) -> Vocabulary: | |||
""" | |||
获取field名为field_name对应的vocab | |||
@@ -182,7 +80,7 @@ class DataBundle: | |||
""" | |||
return self.vocabs[field_name] | |||
def delete_vocab(self, field_name:str): | |||
def delete_vocab(self, field_name: str): | |||
""" | |||
删除vocab | |||
:param str field_name: | |||
@@ -204,7 +102,7 @@ class DataBundle: | |||
行的数据进行类型和维度推断本列的数据的类型和维度。 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:return self | |||
:return: self | |||
""" | |||
for field_name in field_names: | |||
for name, dataset in self.datasets.items(): | |||
@@ -229,7 +127,7 @@ class DataBundle: | |||
行的数据进行类型和维度推断本列的数据的类型和维度。 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:return self | |||
:return: self | |||
""" | |||
for field_name in field_names: | |||
for name, dataset in self.datasets.items(): | |||
@@ -241,9 +139,44 @@ class DataBundle: | |||
dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) | |||
return self | |||
def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True): | |||
""" | |||
将DataBundle中所有的DataSet中名为field_name的Field的padding值设置为pad_val. | |||
:param str field_name: | |||
:param int pad_val: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
if dataset.has_field(field_name=field_name): | |||
dataset.set_pad_val(field_name=field_name, pad_val=pad_val) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
return self | |||
def set_ignore_type(self, *field_names, flag=True, ignore_miss_dataset=True): | |||
""" | |||
将DataBundle中所有的DataSet中名为*field_names的Field的ignore_type设置为flag状态 | |||
:param str field_names: | |||
:param bool flag: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
for field_name in field_names: | |||
if dataset.has_field(field_name=field_name): | |||
dataset.set_ignore_type(field_name, flag=flag) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
return self | |||
def copy_field(self, field_name, new_field_name, ignore_miss_dataset=True): | |||
""" | |||
将DataBundle中所有的field_name复制一份叫new_field_name. | |||
将DataBundle中所有的DataSet中名为field_name的Field复制一份并命名为叫new_field_name. | |||
:param str field_name: | |||
:param str new_field_name: | |||
@@ -258,9 +191,79 @@ class DataBundle: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
return self | |||
def rename_field(self, field_name, new_field_name, ignore_miss_dataset=True, rename_vocab=True): | |||
""" | |||
将DataBundle中所有DataSet中名为field_name的field重命名为new_field_name. | |||
:param str field_name: | |||
:param str new_field_name: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param bool rename_vocab: 如果该field同时也存在于vocabs中,会将该field的名称对应修改 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
if dataset.has_field(field_name=field_name): | |||
dataset.rename_field(field_name=field_name, new_field_name=new_field_name) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
if rename_vocab: | |||
if field_name in self.vocabs: | |||
self.vocabs[new_field_name] = self.vocabs.pop(field_name) | |||
return self | |||
def delete_field(self, field_name, ignore_miss_dataset=True, delete_vocab=True): | |||
""" | |||
将DataBundle中所有DataSet中名为field_name的field删除掉. | |||
:param str field_name: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param bool delete_vocab: 如果该field也在vocabs中存在,将该值也一并删除 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
if dataset.has_field(field_name=field_name): | |||
dataset.delete_field(field_name=field_name) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
if delete_vocab: | |||
if field_name in self.vocabs: | |||
self.vocabs.pop(field_name) | |||
return self | |||
def iter_datasets(self)->Union[str, DataSet]: | |||
""" | |||
迭代data_bundle中的DataSet | |||
Example:: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
pass | |||
:return: | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
yield name, dataset | |||
def iter_vocabs(self)->Union[str, Vocabulary]: | |||
""" | |||
迭代data_bundle中的DataSet | |||
Example: | |||
for field_name, vocab in data_bundle.iter_vocabs(): | |||
pass | |||
:return: | |||
""" | |||
for field_name, vocab in self.vocabs.items(): | |||
yield field_name, vocab | |||
def apply_field(self, func, field_name:str, new_field_name:str, ignore_miss_dataset=True, **kwargs): | |||
""" | |||
对DataBundle中所有的dataset使用apply方法 | |||
对DataBundle中所有的dataset使用apply_field方法 | |||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
:param str field_name: 传入func的是哪个field。 | |||
@@ -303,99 +306,15 @@ class DataBundle: | |||
return self | |||
def __repr__(self): | |||
_str = 'In total {} datasets:\n'.format(len(self.datasets)) | |||
for name, dataset in self.datasets.items(): | |||
_str += '\t{} has {} instances.\n'.format(name, len(dataset)) | |||
_str += 'In total {} vocabs:\n'.format(len(self.vocabs)) | |||
for name, vocab in self.vocabs.items(): | |||
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) | |||
_str = '' | |||
if len(self.datasets): | |||
_str += 'In total {} datasets:\n'.format(len(self.datasets)) | |||
for name, dataset in self.datasets.items(): | |||
_str += '\t{} has {} instances.\n'.format(name, len(dataset)) | |||
if len(self.vocabs): | |||
_str += 'In total {} vocabs:\n'.format(len(self.vocabs)) | |||
for name, vocab in self.vocabs.items(): | |||
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) | |||
return _str | |||
class DataSetLoader: | |||
""" | |||
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` | |||
定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。 | |||
开发者至少应该编写如下内容: | |||
- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet` | |||
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet` | |||
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet` | |||
**process 函数中可以 调用load 函数或 _load 函数** | |||
""" | |||
URL = '' | |||
DATA_DIR = '' | |||
ROOT_DIR = '.fastnlp/datasets/' | |||
UNCOMPRESS = True | |||
def _download(self, url: str, pdir: str, uncompress=True) -> str: | |||
""" | |||
从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 | |||
:param url: 下载的网站 | |||
:param pdir: 下载到的目录 | |||
:param uncompress: 是否自动解压缩 | |||
:return: 数据的存放路径 | |||
""" | |||
fn = os.path.basename(url) | |||
path = os.path.join(pdir, fn) | |||
"""check data exists""" | |||
if not os.path.exists(path): | |||
os.makedirs(pdir, exist_ok=True) | |||
_download_from_url(url, path) | |||
if uncompress: | |||
dst = os.path.join(pdir, 'data') | |||
if not os.path.exists(dst): | |||
_uncompress(path, dst) | |||
return dst | |||
return path | |||
def download(self): | |||
return self._download( | |||
self.URL, | |||
os.path.join(self.ROOT_DIR, self.DATA_DIR), | |||
uncompress=self.UNCOMPRESS) | |||
def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]: | |||
""" | |||
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。 | |||
如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。 | |||
:param Union[str, Dict[str, str]] paths: 文件路径 | |||
:return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典 | |||
""" | |||
if isinstance(paths, str): | |||
return self._load(paths) | |||
return {name: self._load(path) for name, path in paths.items()} | |||
def _load(self, path: str) -> DataSet: | |||
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象 | |||
:param str path: 文件路径 | |||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||
""" | |||
raise NotImplementedError | |||
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataBundle: | |||
""" | |||
对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。 | |||
从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。 | |||
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。 | |||
返回的 :class:`DataBundle` 对象有如下属性: | |||
- vocabs: 由从数据集中获取的词表组成的字典,每个词表 | |||
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` | |||
:param paths: 原始数据读取的路径 | |||
:param options: 根据不同的任务和数据集,设计自己的参数 | |||
:return: 返回一个 DataBundle | |||
""" | |||
raise NotImplementedError |
@@ -1,39 +0,0 @@ | |||
"""undocumented | |||
.. warning:: | |||
本模块在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。 | |||
用于读数据集的模块, 可以读取文本分类、序列标注、Matching任务的数据集 | |||
这些模块的具体介绍如下,您可以通过阅读 :doc:`教程</tutorials/tutorial_2_load_dataset>` 来进行了解。 | |||
""" | |||
__all__ = [ | |||
'ConllLoader', | |||
'Conll2003Loader', | |||
'IMDBLoader', | |||
'MatchingLoader', | |||
'SNLILoader', | |||
'MNLILoader', | |||
'MTL16Loader', | |||
'PeopleDailyCorpusLoader', | |||
'QNLILoader', | |||
'QuoraLoader', | |||
'RTELoader', | |||
'SSTLoader', | |||
'SST2Loader', | |||
'YelpLoader', | |||
] | |||
from .conll import ConllLoader, Conll2003Loader | |||
from .imdb import IMDBLoader | |||
from .matching import MatchingLoader | |||
from .mnli import MNLILoader | |||
from .mtl import MTL16Loader | |||
from .people_daily import PeopleDailyCorpusLoader | |||
from .qnli import QNLILoader | |||
from .quora import QuoraLoader | |||
from .rte import RTELoader | |||
from .snli import SNLILoader | |||
from .sst import SSTLoader, SST2Loader | |||
from .yelp import YelpLoader |
@@ -1,109 +0,0 @@ | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ..data_bundle import DataSetLoader | |||
from ..file_reader import _read_conll | |||
from typing import Union, Dict | |||
from ..utils import check_loader_paths | |||
from ..data_bundle import DataBundle | |||
class ConllLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader` | |||
该ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: | |||
Example:: | |||
# 文件中的内容 | |||
Nadim NNP B-NP B-PER | |||
Ladki NNP I-NP I-PER | |||
AL-AIN NNP B-NP B-LOC | |||
United NNP B-NP B-LOC | |||
Arab NNP I-NP I-LOC | |||
Emirates NNPS I-NP I-LOC | |||
1996-12-06 CD I-NP O | |||
... | |||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列 | |||
dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll') | |||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列 | |||
dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll') | |||
# 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field | |||
dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll') | |||
dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll')中DataSet的raw_words | |||
列与pos列的内容都是List[str] | |||
数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 | |||
:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | |||
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | |||
""" | |||
def __init__(self, headers, indexes=None, dropna=True): | |||
super(ConllLoader, self).__init__() | |||
if not isinstance(headers, (list, tuple)): | |||
raise TypeError( | |||
'invalid headers: {}, should be list of strings'.format(headers)) | |||
self.headers = headers | |||
self.dropna = dropna | |||
if indexes is None: | |||
self.indexes = list(range(len(self.headers))) | |||
else: | |||
if len(indexes) != len(headers): | |||
raise ValueError | |||
self.indexes = indexes | |||
def _load(self, path): | |||
""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由Loader初始化时指定的headers决定。 | |||
:param str path: 文件的路径 | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||
ds.append(Instance(**ins)) | |||
return ds | |||
def load(self, paths: Union[str, Dict[str, str]]) -> DataBundle: | |||
""" | |||
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 | |||
读取的field根据ConllLoader初始化时传入的headers决定。 | |||
:param Union[str, Dict[str, str]] paths: | |||
:return: :class:`~fastNLP.DataSet` 类的对象或 :class:`~fastNLP.io.DataBundle` 的字典 | |||
""" | |||
paths = check_loader_paths(paths) | |||
datasets = {name: self._load(path) for name, path in paths.items()} | |||
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
class Conll2003Loader(ConllLoader): | |||
""" | |||
别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.data_loader.Conll2003Loader` | |||
该Loader用以读取Conll2003数据,conll2003的数据可以在https://github.com/davidsbatista/NER-datasets/tree/master/CONLL2003 | |||
找到。数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 | |||
返回的DataSet将具有以下['raw_words', 'pos', 'chunks', 'ner']四个field, 每个field中的内容都是List[str]。 | |||
.. csv-table:: Conll2003Loader处理之 :header: "raw_words", "words", "target", "seq_len" | |||
"[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2 | |||
"[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 5 | |||
"[...]", "[...]", "[...]", . | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'raw_words', 'pos', 'chunks', 'ner', | |||
] | |||
super(Conll2003Loader, self).__init__(headers=headers) |
@@ -1,99 +0,0 @@ | |||
from typing import Union, Dict | |||
from ..embed_loader import EmbeddingOption, EmbedLoader | |||
from ..data_bundle import DataSetLoader, DataBundle | |||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ...core.const import Const | |||
from ..utils import get_tokenizer | |||
class IMDBLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.IMDBLoader` :class:`fastNLP.io.data_loader.IMDBLoader` | |||
读取IMDB数据集,DataSet包含以下fields: | |||
words: list(str), 需要分类的文本 | |||
target: str, 文本的标签 | |||
""" | |||
def __init__(self): | |||
super(IMDBLoader, self).__init__() | |||
self.tokenizer = get_tokenizer() | |||
def _load(self, path): | |||
dataset = DataSet() | |||
with open(path, 'r', encoding="utf-8") as f: | |||
for line in f: | |||
line = line.strip() | |||
if not line: | |||
continue | |||
parts = line.split('\t') | |||
target = parts[0] | |||
words = self.tokenizer(parts[1].lower()) | |||
dataset.append(Instance(words=words, target=target)) | |||
if len(dataset) == 0: | |||
raise RuntimeError(f"{path} has no valid data.") | |||
return dataset | |||
def process(self, | |||
paths: Union[str, Dict[str, str]], | |||
src_vocab_opt: VocabularyOption = None, | |||
tgt_vocab_opt: VocabularyOption = None, | |||
char_level_op=False): | |||
datasets = {} | |||
info = DataBundle() | |||
for name, path in paths.items(): | |||
dataset = self.load(path) | |||
datasets[name] = dataset | |||
def wordtochar(words): | |||
chars = [] | |||
for word in words: | |||
word = word.lower() | |||
for char in word: | |||
chars.append(char) | |||
chars.append('') | |||
chars.pop() | |||
return chars | |||
if char_level_op: | |||
for dataset in datasets.values(): | |||
dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') | |||
datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False) | |||
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||
src_vocab.from_dataset(datasets['train'], field_name='words') | |||
src_vocab.index_dataset(*datasets.values(), field_name='words') | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||
tgt_vocab.from_dataset(datasets['train'], field_name='target') | |||
tgt_vocab.index_dataset(*datasets.values(), field_name='target') | |||
info.vocabs = { | |||
Const.INPUT: src_vocab, | |||
Const.TARGET: tgt_vocab | |||
} | |||
info.datasets = datasets | |||
for name, dataset in info.datasets.items(): | |||
dataset.set_input(Const.INPUT) | |||
dataset.set_target(Const.TARGET) | |||
return info | |||
@@ -1,248 +0,0 @@ | |||
import os | |||
from typing import Union, Dict, List | |||
from ...core.const import Const | |||
from ...core.vocabulary import Vocabulary | |||
from ..data_bundle import DataBundle, DataSetLoader | |||
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
from ...modules.encoder.bert import BertTokenizer | |||
class MatchingLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.data_loader.MatchingLoader` | |||
读取Matching任务的数据集 | |||
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 | |||
""" | |||
def __init__(self, paths: dict=None): | |||
self.paths = paths | |||
def _load(self, path): | |||
""" | |||
:param str path: 待读取数据集的路径名 | |||
:return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子 | |||
的原始字符串文本,第三个为标签 | |||
""" | |||
raise NotImplementedError | |||
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, | |||
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, | |||
cut_text: int = None, get_index=True, auto_pad_length: int=None, | |||
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True, | |||
set_target: Union[list, str, bool]=True, concat: Union[str, list, bool]=None, | |||
extra_split: List[str]=None, ) -> DataBundle: | |||
""" | |||
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, | |||
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 | |||
对应的全路径文件名。 | |||
:param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义 | |||
这个数据集的名字,如果不定义则默认为train。 | |||
:param bool to_lower: 是否将文本自动转为小写。默认值为False。 | |||
:param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` : | |||
提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和 | |||
attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len | |||
:param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 | |||
:param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 | |||
:param bool get_index: 是否需要根据词表将文本转为index | |||
:param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad | |||
:param str auto_pad_token: 自动pad的内容 | |||
:param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False | |||
则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, | |||
于此同时其他field不会被设置为input。默认值为True。 | |||
:param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。 | |||
:param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个<sep>。 | |||
如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果 | |||
传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]']. | |||
:param extra_split: 额外的分隔符,即除了空格之外的用于分词的字符。 | |||
:return: | |||
""" | |||
if isinstance(set_input, str): | |||
set_input = [set_input] | |||
if isinstance(set_target, str): | |||
set_target = [set_target] | |||
if isinstance(set_input, bool): | |||
auto_set_input = set_input | |||
else: | |||
auto_set_input = False | |||
if isinstance(set_target, bool): | |||
auto_set_target = set_target | |||
else: | |||
auto_set_target = False | |||
if isinstance(paths, str): | |||
if os.path.isdir(paths): | |||
path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()} | |||
else: | |||
path = {dataset_name if dataset_name is not None else 'train': paths} | |||
else: | |||
path = paths | |||
data_info = DataBundle() | |||
for data_name in path.keys(): | |||
data_info.datasets[data_name] = self._load(path[data_name]) | |||
for data_name, data_set in data_info.datasets.items(): | |||
if auto_set_input: | |||
data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) | |||
if auto_set_target: | |||
if Const.TARGET in data_set.get_field_names(): | |||
data_set.set_target(Const.TARGET) | |||
if extra_split is not None: | |||
for data_name, data_set in data_info.datasets.items(): | |||
data_set.apply(lambda x: ' '.join(x[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) | |||
data_set.apply(lambda x: ' '.join(x[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) | |||
for s in extra_split: | |||
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '), | |||
new_field_name=Const.INPUTS(0)) | |||
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '), | |||
new_field_name=Const.INPUTS(0)) | |||
_filt = lambda x: x | |||
data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(0)].split(' '))), | |||
new_field_name=Const.INPUTS(0), is_input=auto_set_input) | |||
data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(1)].split(' '))), | |||
new_field_name=Const.INPUTS(1), is_input=auto_set_input) | |||
_filt = None | |||
if to_lower: | |||
for data_name, data_set in data_info.datasets.items(): | |||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0), | |||
is_input=auto_set_input) | |||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1), | |||
is_input=auto_set_input) | |||
if bert_tokenizer is not None: | |||
if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
PRETRAIN_URL = _get_base_url('bert') | |||
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] | |||
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isdir(bert_tokenizer): | |||
model_dir = bert_tokenizer | |||
else: | |||
raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") | |||
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') | |||
with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: | |||
lines = f.readlines() | |||
lines = [line.strip() for line in lines] | |||
words_vocab.add_word_lst(lines) | |||
words_vocab.build_vocab() | |||
tokenizer = BertTokenizer.from_pretrained(model_dir) | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields, | |||
is_input=auto_set_input) | |||
if isinstance(concat, bool): | |||
concat = 'default' if concat else None | |||
if concat is not None: | |||
if isinstance(concat, str): | |||
CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'], | |||
'default': ['', '<sep>', '', '']} | |||
if concat.lower() in CONCAT_MAP: | |||
concat = CONCAT_MAP[concat] | |||
else: | |||
concat = 4 * [concat] | |||
assert len(concat) == 4, \ | |||
f'Please choose a list with 4 symbols which at the beginning of first sentence ' \ | |||
f'the end of first sentence, the begin of second sentence, and the end of second' \ | |||
f'sentence. Your input is {concat}' | |||
for data_name, data_set in data_info.datasets.items(): | |||
data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] + | |||
x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT) | |||
data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT, | |||
is_input=auto_set_input) | |||
if seq_len_type is not None: | |||
if seq_len_type == 'seq_len': # | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: len(x[fields]), | |||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||
is_input=auto_set_input) | |||
elif seq_len_type == 'mask': | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: [1] * len(x[fields]), | |||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||
is_input=auto_set_input) | |||
elif seq_len_type == 'bert': | |||
for data_name, data_set in data_info.datasets.items(): | |||
if Const.INPUT not in data_set.get_field_names(): | |||
raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: ' | |||
f'got {data_set.get_field_names()}') | |||
data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), | |||
new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input) | |||
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), | |||
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | |||
if auto_pad_length is not None: | |||
cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) | |||
if cut_text is not None: | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')): | |||
data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields, | |||
is_input=auto_set_input) | |||
data_set_list = [d for n, d in data_info.datasets.items()] | |||
assert len(data_set_list) > 0, f'There are NO data sets in data info!' | |||
if bert_tokenizer is None: | |||
words_vocab = Vocabulary(padding=auto_pad_token) | |||
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||
field_name=[n for n in data_set_list[0].get_field_names() | |||
if (Const.INPUT in n)], | |||
no_create_entry_dataset=[d for n, d in data_info.datasets.items() | |||
if 'train' not in n]) | |||
target_vocab = Vocabulary(padding=None, unknown=None) | |||
target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||
field_name=Const.TARGET) | |||
data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} | |||
if get_index: | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, | |||
is_input=auto_set_input) | |||
if Const.TARGET in data_set.get_field_names(): | |||
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | |||
is_input=auto_set_input, is_target=auto_set_target) | |||
if auto_pad_length is not None: | |||
if seq_len_type == 'seq_len': | |||
raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' | |||
f'so the seq_len_type cannot be `{seq_len_type}`!') | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * | |||
(auto_pad_length - len(x[fields])), new_field_name=fields, | |||
is_input=auto_set_input) | |||
elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): | |||
data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])), | |||
new_field_name=fields, is_input=auto_set_input) | |||
for data_name, data_set in data_info.datasets.items(): | |||
if isinstance(set_input, list): | |||
data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) | |||
if isinstance(set_target, list): | |||
data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) | |||
return data_info |
@@ -1,62 +0,0 @@ | |||
from ...core.const import Const | |||
from .matching import MatchingLoader | |||
from ..dataset_loader import CSVLoader | |||
class MNLILoader(MatchingLoader, CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.data_loader.MNLILoader` | |||
读取MNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: | |||
""" | |||
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev_matched': 'dev_matched.tsv', | |||
'dev_mismatched': 'dev_mismatched.tsv', | |||
'test_matched': 'test_matched.tsv', | |||
'test_mismatched': 'test_mismatched.tsv', | |||
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | |||
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | |||
# test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
CSVLoader.__init__(self, sep='\t') | |||
self.fields = { | |||
'sentence1_binary_parse': Const.INPUTS(0), | |||
'sentence2_binary_parse': Const.INPUTS(1), | |||
'gold_label': Const.TARGET, | |||
} | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
if k in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
if Const.TARGET in ds.get_field_names(): | |||
if ds[0][Const.TARGET] == 'hidden': | |||
ds.delete_field(Const.TARGET) | |||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(0)) | |||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(1)) | |||
if Const.TARGET in ds.get_field_names(): | |||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||
return ds |
@@ -1,68 +0,0 @@ | |||
from typing import Union, Dict | |||
from ..data_bundle import DataBundle | |||
from ..dataset_loader import CSVLoader | |||
from ...core.vocabulary import Vocabulary, VocabularyOption | |||
from ...core.const import Const | |||
from ..utils import check_loader_paths | |||
class MTL16Loader(CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.MTL16Loader` :class:`fastNLP.io.data_loader.MTL16Loader` | |||
读取MTL16数据集,DataSet包含以下fields: | |||
words: list(str), 需要分类的文本 | |||
target: str, 文本的标签 | |||
数据来源:https://pan.baidu.com/s/1c2L6vdA | |||
""" | |||
def __init__(self): | |||
super(MTL16Loader, self).__init__(headers=(Const.TARGET, Const.INPUT), sep='\t') | |||
def _load(self, path): | |||
dataset = super(MTL16Loader, self)._load(path) | |||
dataset.apply(lambda x: x[Const.INPUT].lower().split(), new_field_name=Const.INPUT) | |||
if len(dataset) == 0: | |||
raise RuntimeError(f"{path} has no valid data.") | |||
return dataset | |||
def process(self, | |||
paths: Union[str, Dict[str, str]], | |||
src_vocab_opt: VocabularyOption = None, | |||
tgt_vocab_opt: VocabularyOption = None,): | |||
paths = check_loader_paths(paths) | |||
datasets = {} | |||
info = DataBundle() | |||
for name, path in paths.items(): | |||
dataset = self.load(path) | |||
datasets[name] = dataset | |||
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT) | |||
src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||
tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET) | |||
tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET) | |||
info.vocabs = { | |||
Const.INPUT: src_vocab, | |||
Const.TARGET: tgt_vocab | |||
} | |||
info.datasets = datasets | |||
for name, dataset in info.datasets.items(): | |||
dataset.set_input(Const.INPUT) | |||
dataset.set_target(Const.TARGET) | |||
return info |
@@ -1,85 +0,0 @@ | |||
from ..data_bundle import DataSetLoader | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ...core.const import Const | |||
class PeopleDailyCorpusLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.PeopleDailyCorpusLoader` :class:`fastNLP.io.data_loader.PeopleDailyCorpusLoader` | |||
读取人民日报数据集 | |||
""" | |||
def __init__(self, pos=True, ner=True): | |||
super(PeopleDailyCorpusLoader, self).__init__() | |||
self.pos = pos | |||
self.ner = ner | |||
def _load(self, data_path): | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
sents = f.readlines() | |||
examples = [] | |||
for sent in sents: | |||
if len(sent) <= 2: | |||
continue | |||
inside_ne = False | |||
sent_pos_tag = [] | |||
sent_words = [] | |||
sent_ner = [] | |||
words = sent.strip().split()[1:] | |||
for word in words: | |||
if "[" in word and "]" in word: | |||
ner_tag = "U" | |||
print(word) | |||
elif "[" in word: | |||
inside_ne = True | |||
ner_tag = "B" | |||
word = word[1:] | |||
elif "]" in word: | |||
ner_tag = "L" | |||
word = word[:word.index("]")] | |||
if inside_ne is True: | |||
inside_ne = False | |||
else: | |||
raise RuntimeError("only ] appears!") | |||
else: | |||
if inside_ne is True: | |||
ner_tag = "I" | |||
else: | |||
ner_tag = "O" | |||
tmp = word.split("/") | |||
token, pos = tmp[0], tmp[1] | |||
sent_ner.append(ner_tag) | |||
sent_pos_tag.append(pos) | |||
sent_words.append(token) | |||
example = [sent_words] | |||
if self.pos is True: | |||
example.append(sent_pos_tag) | |||
if self.ner is True: | |||
example.append(sent_ner) | |||
examples.append(example) | |||
return self.convert(examples) | |||
def convert(self, data): | |||
""" | |||
:param data: python 内置对象 | |||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||
""" | |||
data_set = DataSet() | |||
for item in data: | |||
sent_words = item[0] | |||
if self.pos is True and self.ner is True: | |||
instance = Instance( | |||
words=sent_words, pos_tags=item[1], ner=item[2]) | |||
elif self.pos is True: | |||
instance = Instance(words=sent_words, pos_tags=item[1]) | |||
elif self.ner is True: | |||
instance = Instance(words=sent_words, ner=item[1]) | |||
else: | |||
instance = Instance(words=sent_words) | |||
data_set.append(instance) | |||
data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN) | |||
return data_set |
@@ -1,47 +0,0 @@ | |||
from ...core.const import Const | |||
from .matching import MatchingLoader | |||
from ..dataset_loader import CSVLoader | |||
class QNLILoader(MatchingLoader, CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.data_loader.QNLILoader` | |||
读取QNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: | |||
""" | |||
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
'test': 'test.tsv' # test set has not label | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
self.fields = { | |||
'question': Const.INPUTS(0), | |||
'sentence': Const.INPUTS(1), | |||
'label': Const.TARGET, | |||
} | |||
CSVLoader.__init__(self, sep='\t') | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
if k in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
for fields in ds.get_all_fields(): | |||
if Const.INPUT in fields: | |||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
return ds |
@@ -1,34 +0,0 @@ | |||
from ...core.const import Const | |||
from .matching import MatchingLoader | |||
from ..dataset_loader import CSVLoader | |||
class QuoraLoader(MatchingLoader, CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.data_loader.QuoraLoader` | |||
读取MNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: | |||
""" | |||
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
'test': 'test.tsv', | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
return ds |
@@ -1,47 +0,0 @@ | |||
from ...core.const import Const | |||
from .matching import MatchingLoader | |||
from ..dataset_loader import CSVLoader | |||
class RTELoader(MatchingLoader, CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.data_loader.RTELoader` | |||
读取RTE数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: | |||
""" | |||
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
'test': 'test.tsv' # test set has not label | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
self.fields = { | |||
'sentence1': Const.INPUTS(0), | |||
'sentence2': Const.INPUTS(1), | |||
'label': Const.TARGET, | |||
} | |||
CSVLoader.__init__(self, sep='\t') | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
if k in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
for fields in ds.get_all_fields(): | |||
if Const.INPUT in fields: | |||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
return ds |
@@ -1,46 +0,0 @@ | |||
from ...core.const import Const | |||
from .matching import MatchingLoader | |||
from ..dataset_loader import JsonLoader | |||
class SNLILoader(MatchingLoader, JsonLoader): | |||
""" | |||
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.data_loader.SNLILoader` | |||
读取SNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||
""" | |||
def __init__(self, paths: dict=None): | |||
fields = { | |||
'sentence1_binary_parse': Const.INPUTS(0), | |||
'sentence2_binary_parse': Const.INPUTS(1), | |||
'gold_label': Const.TARGET, | |||
} | |||
paths = paths if paths is not None else { | |||
'train': 'snli_1.0_train.jsonl', | |||
'dev': 'snli_1.0_dev.jsonl', | |||
'test': 'snli_1.0_test.jsonl'} | |||
MatchingLoader.__init__(self, paths=paths) | |||
JsonLoader.__init__(self, fields=fields) | |||
def _load(self, path): | |||
ds = JsonLoader._load(self, path) | |||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(0)) | |||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(1)) | |||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||
return ds |
@@ -1,180 +0,0 @@ | |||
from typing import Union, Dict | |||
from nltk import Tree | |||
from ..data_bundle import DataBundle, DataSetLoader | |||
from ..dataset_loader import CSVLoader | |||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||
from ...core.dataset import DataSet | |||
from ...core.const import Const | |||
from ...core.instance import Instance | |||
from ..utils import check_loader_paths, get_tokenizer | |||
class SSTLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.data_loader.SSTLoader` | |||
读取SST数据集, DataSet包含fields:: | |||
words: list(str) 需要分类的文本 | |||
target: str 文本的标签 | |||
数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip | |||
:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` | |||
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||
""" | |||
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||
DATA_DIR = 'sst/' | |||
def __init__(self, subtree=False, fine_grained=False): | |||
self.subtree = subtree | |||
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', | |||
'3': 'positive', '4': 'very positive'} | |||
if not fine_grained: | |||
tag_v['0'] = tag_v['1'] | |||
tag_v['4'] = tag_v['3'] | |||
self.tag_v = tag_v | |||
self.tokenizer = get_tokenizer() | |||
def _load(self, path): | |||
""" | |||
:param str path: 存储数据的路径 | |||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
datas = [] | |||
for l in f: | |||
datas.extend([(s, self.tag_v[t]) | |||
for s, t in self._get_one(l, self.subtree)]) | |||
ds = DataSet() | |||
for words, tag in datas: | |||
ds.append(Instance(words=words, target=tag)) | |||
return ds | |||
def _get_one(self, data, subtree): | |||
tree = Tree.fromstring(data) | |||
if subtree: | |||
return [(self.tokenizer(' '.join(t.leaves())), t.label()) for t in tree.subtrees() ] | |||
return [(self.tokenizer(' '.join(tree.leaves())), tree.label())] | |||
def process(self, | |||
paths, train_subtree=True, | |||
src_vocab_op: VocabularyOption = None, | |||
tgt_vocab_op: VocabularyOption = None,): | |||
paths = check_loader_paths(paths) | |||
input_name, target_name = 'words', 'target' | |||
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | |||
info = DataBundle() | |||
origin_subtree = self.subtree | |||
self.subtree = train_subtree | |||
info.datasets['train'] = self._load(paths['train']) | |||
self.subtree = origin_subtree | |||
for n, p in paths.items(): | |||
if n != 'train': | |||
info.datasets[n] = self._load(p) | |||
src_vocab.from_dataset( | |||
info.datasets['train'], | |||
field_name=input_name, | |||
no_create_entry_dataset=[ds for n, ds in info.datasets.items() if n != 'train']) | |||
tgt_vocab.from_dataset(info.datasets['train'], field_name=target_name) | |||
src_vocab.index_dataset( | |||
*info.datasets.values(), | |||
field_name=input_name, new_field_name=input_name) | |||
tgt_vocab.index_dataset( | |||
*info.datasets.values(), | |||
field_name=target_name, new_field_name=target_name) | |||
info.vocabs = { | |||
input_name: src_vocab, | |||
target_name: tgt_vocab | |||
} | |||
return info | |||
class SST2Loader(CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.SST2Loader` :class:`fastNLP.io.data_loader.SST2Loader` | |||
数据来源 SST: https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8 | |||
""" | |||
def __init__(self): | |||
super(SST2Loader, self).__init__(sep='\t') | |||
self.tokenizer = get_tokenizer() | |||
self.field = {'sentence': Const.INPUT, 'label': Const.TARGET} | |||
def _load(self, path: str) -> DataSet: | |||
ds = super(SST2Loader, self)._load(path) | |||
for k, v in self.field.items(): | |||
if k in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT) | |||
print("all count:", len(ds)) | |||
return ds | |||
def process(self, | |||
paths: Union[str, Dict[str, str]], | |||
src_vocab_opt: VocabularyOption = None, | |||
tgt_vocab_opt: VocabularyOption = None, | |||
char_level_op=False): | |||
paths = check_loader_paths(paths) | |||
datasets = {} | |||
info = DataBundle() | |||
for name, path in paths.items(): | |||
dataset = self.load(path) | |||
dataset.apply_field(lambda words:words.copy(), field_name='words', new_field_name='raw_words') | |||
datasets[name] = dataset | |||
def wordtochar(words): | |||
chars = [] | |||
for word in words: | |||
word = word.lower() | |||
for char in word: | |||
chars.append(char) | |||
chars.append('') | |||
chars.pop() | |||
return chars | |||
input_name, target_name = Const.INPUT, Const.TARGET | |||
info.vocabs={} | |||
# 就分隔为char形式 | |||
if char_level_op: | |||
for dataset in datasets.values(): | |||
dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) | |||
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT, no_create_entry_dataset=[ | |||
dataset for name, dataset in datasets.items() if name!='train' | |||
]) | |||
src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||
tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET) | |||
tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET) | |||
info.vocabs = { | |||
Const.INPUT: src_vocab, | |||
Const.TARGET: tgt_vocab | |||
} | |||
info.datasets = datasets | |||
for name, dataset in info.datasets.items(): | |||
dataset.set_input(Const.INPUT) | |||
dataset.set_target(Const.TARGET) | |||
return info | |||
@@ -1,132 +0,0 @@ | |||
import csv | |||
from typing import Iterable | |||
from ...core.const import Const | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||
from ..data_bundle import DataBundle, DataSetLoader | |||
from typing import Union, Dict | |||
from ..utils import check_loader_paths, get_tokenizer | |||
class YelpLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.data_loader.YelpLoader` | |||
读取Yelp_full/Yelp_polarity数据集, DataSet包含fields: | |||
words: list(str), 需要分类的文本 | |||
target: str, 文本的标签 | |||
chars:list(str),未index的字符列表 | |||
数据集:yelp_full/yelp_polarity | |||
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||
:param lower: 是否需要自动转小写,默认为False。 | |||
""" | |||
def __init__(self, fine_grained=False, lower=False): | |||
super(YelpLoader, self).__init__() | |||
tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', | |||
'4.0': 'positive', '5.0': 'very positive'} | |||
if not fine_grained: | |||
tag_v['1.0'] = tag_v['2.0'] | |||
tag_v['5.0'] = tag_v['4.0'] | |||
self.fine_grained = fine_grained | |||
self.tag_v = tag_v | |||
self.lower = lower | |||
self.tokenizer = get_tokenizer() | |||
def _load(self, path): | |||
ds = DataSet() | |||
csv_reader = csv.reader(open(path, encoding='utf-8')) | |||
all_count = 0 | |||
real_count = 0 | |||
for row in csv_reader: | |||
all_count += 1 | |||
if len(row) == 2: | |||
target = self.tag_v[row[0] + ".0"] | |||
words = clean_str(row[1], self.tokenizer, self.lower) | |||
if len(words) != 0: | |||
ds.append(Instance(words=words, target=target)) | |||
real_count += 1 | |||
print("all count:", all_count) | |||
print("real count:", real_count) | |||
return ds | |||
def process(self, paths: Union[str, Dict[str, str]], | |||
train_ds: Iterable[str] = None, | |||
src_vocab_op: VocabularyOption = None, | |||
tgt_vocab_op: VocabularyOption = None, | |||
char_level_op=False): | |||
paths = check_loader_paths(paths) | |||
info = DataBundle(datasets=self.load(paths)) | |||
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | |||
_train_ds = [info.datasets[name] | |||
for name in train_ds] if train_ds else info.datasets.values() | |||
def wordtochar(words): | |||
chars = [] | |||
for word in words: | |||
word = word.lower() | |||
for char in word: | |||
chars.append(char) | |||
chars.append('') | |||
chars.pop() | |||
return chars | |||
input_name, target_name = Const.INPUT, Const.TARGET | |||
info.vocabs = {} | |||
# 就分隔为char形式 | |||
if char_level_op: | |||
for dataset in info.datasets.values(): | |||
dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) | |||
else: | |||
src_vocab.from_dataset(*_train_ds, field_name=input_name) | |||
src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name) | |||
info.vocabs[input_name] = src_vocab | |||
tgt_vocab.from_dataset(*_train_ds, field_name=target_name) | |||
tgt_vocab.index_dataset( | |||
*info.datasets.values(), | |||
field_name=target_name, new_field_name=target_name) | |||
info.vocabs[target_name] = tgt_vocab | |||
info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False) | |||
for name, dataset in info.datasets.items(): | |||
dataset.set_input(Const.INPUT) | |||
dataset.set_target(Const.TARGET) | |||
return info | |||
def clean_str(sentence, tokenizer, char_lower=False): | |||
""" | |||
heavily borrowed from github | |||
https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb | |||
:param sentence: is a str | |||
:return: | |||
""" | |||
if char_lower: | |||
sentence = sentence.lower() | |||
import re | |||
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | |||
words = tokenizer(sentence) | |||
words_collection = [] | |||
for word in words: | |||
if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']: | |||
continue | |||
tt = nonalpnum.split(word) | |||
t = ''.join(tt) | |||
if t != '': | |||
words_collection.append(t) | |||
return words_collection | |||
@@ -1,121 +0,0 @@ | |||
"""undocumented | |||
.. warning:: | |||
本模块将在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。 | |||
dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的数据, 并返回 `DataSet` , | |||
得到的 :class:`~fastNLP.DataSet` 对象可以直接传入 :class:`~fastNLP.Trainer` 和 :class:`~fastNLP.Tester`, 用于模型的训练和测试。 | |||
以SNLI数据集为例:: | |||
loader = SNLILoader() | |||
train_ds = loader.load('path/to/train') | |||
dev_ds = loader.load('path/to/dev') | |||
test_ds = loader.load('path/to/test') | |||
# ... do stuff | |||
为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。 | |||
""" | |||
__all__ = [ | |||
'CSVLoader', | |||
'JsonLoader', | |||
] | |||
from .data_bundle import DataSetLoader | |||
from .file_reader import _read_csv, _read_json | |||
from ..core.dataset import DataSet | |||
from ..core.instance import Instance | |||
class JsonLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader` | |||
读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 | |||
:param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name | |||
``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` , | |||
`value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名 | |||
``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` | |||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
Default: ``False`` | |||
""" | |||
def __init__(self, fields=None, dropna=False): | |||
super(JsonLoader, self).__init__() | |||
self.dropna = dropna | |||
self.fields = None | |||
self.fields_list = None | |||
if fields: | |||
self.fields = {} | |||
for k, v in fields.items(): | |||
self.fields[k] = k if v is None else v | |||
self.fields_list = list(self.fields.keys()) | |||
def _load(self, path): | |||
ds = DataSet() | |||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | |||
if self.fields: | |||
ins = {self.fields[k]: v for k, v in d.items()} | |||
else: | |||
ins = d | |||
ds.append(Instance(**ins)) | |||
return ds | |||
class CSVLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` | |||
读取CSV格式的数据集。返回 ``DataSet`` | |||
:param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 | |||
若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` | |||
:param str sep: CSV文件中列与列之间的分隔符. Default: "," | |||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
Default: ``False`` | |||
""" | |||
def __init__(self, headers=None, sep=",", dropna=False): | |||
self.headers = headers | |||
self.sep = sep | |||
self.dropna = dropna | |||
def _load(self, path): | |||
ds = DataSet() | |||
for idx, data in _read_csv(path, headers=self.headers, | |||
sep=self.sep, dropna=self.dropna): | |||
ds.append(Instance(**data)) | |||
return ds | |||
def _cut_long_sentence(sent, max_sample_length=200): | |||
""" | |||
将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。 | |||
所以截取的句子可能长于或者短于max_sample_length | |||
:param sent: str. | |||
:param max_sample_length: int. | |||
:return: list of str. | |||
""" | |||
sent_no_space = sent.replace(' ', '') | |||
cutted_sentence = [] | |||
if len(sent_no_space) > max_sample_length: | |||
parts = sent.strip().split() | |||
new_line = '' | |||
length = 0 | |||
for part in parts: | |||
length += len(part) | |||
new_line += part + ' ' | |||
if length > max_sample_length: | |||
new_line = new_line[:-1] | |||
cutted_sentence.append(new_line) | |||
length = 0 | |||
new_line = '' | |||
if new_line != '': | |||
cutted_sentence.append(new_line[:-1]) | |||
else: | |||
cutted_sentence.append(sent) | |||
return cutted_sentence |
@@ -13,7 +13,6 @@ import warnings | |||
import numpy as np | |||
from .data_bundle import BaseLoader | |||
from ..core.utils import Option | |||
from ..core.vocabulary import Vocabulary | |||
@@ -32,10 +31,8 @@ class EmbeddingOption(Option): | |||
) | |||
class EmbedLoader(BaseLoader): | |||
class EmbedLoader: | |||
""" | |||
别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader` | |||
用于读取预训练的embedding, 读取结果可直接载入为模型参数。 | |||
""" | |||
@@ -84,9 +81,9 @@ class EmbedLoader(BaseLoader): | |||
word = ''.join(parts[:-dim]) | |||
nums = parts[-dim:] | |||
# 对齐unk与pad | |||
if word==padding and vocab.padding is not None: | |||
if word == padding and vocab.padding is not None: | |||
word = vocab.padding | |||
elif word==unknown and vocab.unknown is not None: | |||
elif word == unknown and vocab.unknown is not None: | |||
word = vocab.unknown | |||
if word in vocab: | |||
index = vocab.to_index(word) | |||
@@ -171,7 +168,7 @@ class EmbedLoader(BaseLoader): | |||
index = vocab.to_index(key) | |||
matrix[index] = vec | |||
if (unknown is not None and not found_unknown) or (padding is not None and not found_pad): | |||
if ((unknown is not None) and (not found_unknown)) or ((padding is not None) and (not found_pad)): | |||
start_idx = 0 | |||
if padding is not None: | |||
start_idx += 1 | |||
@@ -180,9 +177,9 @@ class EmbedLoader(BaseLoader): | |||
mean = np.mean(matrix[start_idx:], axis=0, keepdims=True) | |||
std = np.std(matrix[start_idx:], axis=0, keepdims=True) | |||
if (unknown is not None and not found_unknown): | |||
if (unknown is not None) and (not found_unknown): | |||
matrix[start_idx - 1] = np.random.randn(1, dim).astype(dtype) * std + mean | |||
if (padding is not None and not found_pad): | |||
if (padding is not None) and (not found_pad): | |||
matrix[0] = np.random.randn(1, dim).astype(dtype) * std + mean | |||
if normalize: | |||
@@ -5,6 +5,7 @@ | |||
__all__ = [] | |||
import json | |||
import csv | |||
from ..core import logger | |||
@@ -21,17 +22,17 @@ def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | |||
:if False, raise ValueError when reading invalid data. default: True | |||
:return: generator, every time yield (line number, csv item) | |||
""" | |||
with open(path, 'r', encoding=encoding) as f: | |||
with open(path, 'r', encoding=encoding) as csv_file: | |||
f = csv.reader(csv_file, delimiter=sep) | |||
start_idx = 0 | |||
if headers is None: | |||
headers = f.readline().rstrip('\r\n') | |||
headers = headers.split(sep) | |||
headers = next(f) | |||
start_idx += 1 | |||
elif not isinstance(headers, (list, tuple)): | |||
raise TypeError("headers should be list or tuple, not {}." \ | |||
.format(type(headers))) | |||
for line_idx, line in enumerate(f, start_idx): | |||
contents = line.rstrip('\r\n').split(sep) | |||
contents = line | |||
if len(contents) != len(headers): | |||
if dropna: | |||
continue | |||
@@ -77,6 +77,9 @@ PRETRAIN_STATIC_FILES = { | |||
'cn-tencent': "tencent_cn.zip", | |||
'cn-fasttext': "cc.zh.300.vec.gz", | |||
'cn-sgns-literature-word': 'sgns.literature.word.txt.zip', | |||
'cn-char-fastnlp-100d': "cn_char_fastnlp_100d.zip", | |||
'cn-bi-fastnlp-100d': "cn_bi_fastnlp_100d.zip", | |||
"cn-tri-fastnlp-100d": "cn_tri_fastnlp_100d.zip" | |||
} | |||
DATASET_DIR = { | |||
@@ -96,7 +99,9 @@ DATASET_DIR = { | |||
"cws-pku": 'cws_pku.zip', | |||
"cws-cityu": "cws_cityu.zip", | |||
"cws-as": 'cws_as.zip', | |||
"cws-msra": 'cws_msra.zip' | |||
"cws-msra": 'cws_msra.zip', | |||
"chn-senti-corp":"chn_senti_corp.zip" | |||
} | |||
PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, | |||
@@ -52,6 +52,7 @@ __all__ = [ | |||
'IMDBLoader', | |||
'SSTLoader', | |||
'SST2Loader', | |||
"ChnSentiCorpLoader", | |||
'ConllLoader', | |||
'Conll2003Loader', | |||
@@ -75,7 +76,7 @@ __all__ = [ | |||
"CRLoader" | |||
] | |||
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader | |||
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader | |||
from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader | |||
from .csv import CSVLoader | |||
from .cws import CWSLoader | |||
@@ -7,6 +7,7 @@ __all__ = [ | |||
"IMDBLoader", | |||
"SSTLoader", | |||
"SST2Loader", | |||
"ChnSentiCorpLoader" | |||
] | |||
import glob | |||
@@ -23,8 +24,6 @@ from ...core.instance import Instance | |||
class YelpLoader(Loader): | |||
""" | |||
别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader` | |||
原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 | |||
Example:: | |||
@@ -32,7 +31,6 @@ class YelpLoader(Loader): | |||
"1","I got 'new' tires from the..." | |||
"1","Don't waste your time..." | |||
读取YelpFull, YelpPolarity的数据。可以通过xxx下载并预处理数据。 | |||
读取的DataSet将具备以下的数据结构 | |||
.. csv-table:: | |||
@@ -163,8 +161,6 @@ class YelpPolarityLoader(YelpLoader): | |||
class IMDBLoader(Loader): | |||
""" | |||
别名::class:`fastNLP.io.IMDBLoader` :class:`fastNLP.io.loader.IMDBLoader` | |||
IMDBLoader读取后的数据将具有以下两列内容: raw_words: str, 需要分类的文本; target: str, 文本的标签 | |||
DataSet具备以下的结构: | |||
@@ -243,8 +239,6 @@ class IMDBLoader(Loader): | |||
class SSTLoader(Loader): | |||
""" | |||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.loader.SSTLoader` | |||
读取之后的DataSet具有以下的结构 | |||
.. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field | |||
@@ -346,3 +340,59 @@ class SST2Loader(Loader): | |||
""" | |||
output_dir = self._get_dataset_path(dataset_name='sst-2') | |||
return output_dir | |||
class ChnSentiCorpLoader(Loader): | |||
""" | |||
支持读取的数据的格式为,第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是label,第 | |||
一个制表符及之后认为是句子 | |||
Example:: | |||
label raw_chars | |||
1 這間酒店環境和服務態度亦算不錯,但房間空間太小~~ | |||
1 <荐书> 推荐所有喜欢<红楼>的红迷们一定要收藏这本书,要知道... | |||
0 商品的不足暂时还没发现,京东的订单处理速度实在.......周二就打包完成,周五才发货... | |||
读取后的DataSet具有以下的field | |||
.. csv-table:: | |||
:header: "raw_chars", "target" | |||
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1" | |||
"<荐书> 推荐所有喜欢<红楼>...", "1" | |||
"..." | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path:str): | |||
""" | |||
从path中读取数据 | |||
:param path: | |||
:return: | |||
""" | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() | |||
for line in f: | |||
line = line.strip() | |||
tab_index = line.index('\t') | |||
if tab_index!=-1: | |||
target = line[:tab_index] | |||
raw_chars = line[tab_index+1:] | |||
if raw_chars: | |||
ds.append(Instance(raw_chars=raw_chars, target=target)) | |||
return ds | |||
def download(self)->str: | |||
""" | |||
自动下载数据,该数据取自https://github.com/pengming617/bert_classification/tree/master/data,在 | |||
https://arxiv.org/pdf/1904.09223.pdf与https://arxiv.org/pdf/1906.08101.pdf有使用 | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('chn-senti-corp') | |||
return output_dir |
@@ -27,8 +27,6 @@ from ...core.instance import Instance | |||
class ConllLoader(Loader): | |||
""" | |||
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.loader.ConllLoader` | |||
ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: | |||
Example:: | |||
@@ -12,8 +12,6 @@ from ...core.instance import Instance | |||
class CSVLoader(Loader): | |||
""" | |||
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.loader.CSVLoader` | |||
读取CSV格式的数据集, 返回 ``DataSet`` 。 | |||
:param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 | |||
@@ -12,8 +12,6 @@ from ...core.instance import Instance | |||
class JsonLoader(Loader): | |||
""" | |||
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader` | |||
读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 | |||
:param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name | |||
@@ -34,29 +34,27 @@ class Loader: | |||
""" | |||
从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | |||
读取的field根据ConllLoader初始化时传入的headers决定。 | |||
:param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式 | |||
(0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 | |||
(1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 | |||
名包含'train'、 'dev'、 'test'则会报错:: | |||
data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、 | |||
# dev、 test等有所变化,可以通过以下的方式取出DataSet | |||
tr_data = data_bundle.datasets['train'] | |||
te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 | |||
data_bundle = xxxLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、 | |||
# dev、 test等有所变化,可以通过以下的方式取出DataSet | |||
tr_data = data_bundle.get_dataset('train') | |||
te_data = data_bundle.get_dataset('test') # 如果目录下有文件包含test这个字段 | |||
(2) 传入文件路径:: | |||
(2) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: | |||
data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | |||
tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet | |||
paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | |||
data_bundle = xxxLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | |||
dev_data = data_bundle.get_dataset('dev') | |||
(3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: | |||
(3) 传入文件路径:: | |||
paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | |||
data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | |||
dev_data = data_bundle.datasets['dev'] | |||
data_bundle = xxxLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | |||
tr_data = data_bundle.get_dataset('train') # 取出DataSet | |||
:return: 返回的 :class:`~fastNLP.io.DataBundle` | |||
""" | |||
@@ -78,7 +76,7 @@ class Loader: | |||
@staticmethod | |||
def _get_dataset_path(dataset_name): | |||
""" | |||
传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 | |||
传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存(如果支持的话) | |||
:param str dataset_name: 数据集的名称 | |||
:return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 | |||
@@ -41,7 +41,7 @@ class MNLILoader(Loader): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() # 跳过header | |||
if path.endswith("test.tsv"): | |||
if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'): | |||
warnings.warn("RTE's test file has no target.") | |||
for line in f: | |||
line = line.strip() | |||
@@ -8,13 +8,9 @@ __all__ = [ | |||
import torch | |||
from .data_bundle import BaseLoader | |||
class ModelLoader(BaseLoader): | |||
class ModelLoader: | |||
""" | |||
别名::class:`fastNLP.io.ModelLoader` :class:`fastNLP.io.model_io.ModelLoader` | |||
用于读取模型 | |||
""" | |||
@@ -43,8 +39,6 @@ class ModelLoader(BaseLoader): | |||
class ModelSaver(object): | |||
""" | |||
别名::class:`fastNLP.io.ModelSaver` :class:`fastNLP.io.model_io.ModelSaver` | |||
用于保存模型 | |||
Example:: | |||
@@ -17,6 +17,7 @@ __all__ = [ | |||
"SSTPipe", | |||
"SST2Pipe", | |||
"IMDBPipe", | |||
"ChnSentiCorpPipe", | |||
"Conll2003NERPipe", | |||
"OntoNotesNERPipe", | |||
@@ -41,7 +42,7 @@ __all__ = [ | |||
"CoreferencePipe" | |||
] | |||
from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe | |||
from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe | |||
from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe | |||
from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ | |||
MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe | |||
@@ -5,7 +5,8 @@ __all__ = [ | |||
"YelpPolarityPipe", | |||
"SSTPipe", | |||
"SST2Pipe", | |||
'IMDBPipe' | |||
'IMDBPipe', | |||
"ChnSentiCorpPipe" | |||
] | |||
import re | |||
@@ -13,18 +14,18 @@ import re | |||
from nltk import Tree | |||
from .pipe import Pipe | |||
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance | |||
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance, _add_chars_field | |||
from ..data_bundle import DataBundle | |||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | |||
from ...core.const import Const | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ...core.vocabulary import Vocabulary | |||
from ..loader.classification import ChnSentiCorpLoader | |||
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | |||
class _CLSPipe(Pipe): | |||
""" | |||
分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列 | |||
@@ -227,8 +228,6 @@ class YelpPolarityPipe(_CLSPipe): | |||
class SSTPipe(_CLSPipe): | |||
""" | |||
别名::class:`fastNLP.io.SSTPipe` :class:`fastNLP.io.pipe.SSTPipe` | |||
经过该Pipe之后,DataSet中具备的field如下所示 | |||
.. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field | |||
@@ -457,3 +456,97 @@ class IMDBPipe(_CLSPipe): | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
class ChnSentiCorpPipe(Pipe): | |||
""" | |||
处理之后的DataSet有以下的结构 | |||
.. csv-table:: | |||
:header: "raw_chars", "chars", "target", "seq_len" | |||
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "[2, 3, 4, 5, ...]", 1, 31 | |||
"<荐书> 推荐所有喜欢<红楼>...", "[10, 21, ....]", 1, 25 | |||
"..." | |||
其中chars, seq_len是input,target是target | |||
:param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 | |||
设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 | |||
data_bundle.get_vocab('bigrams')获取. | |||
:param bool trigrams: 是否增加一列trigrams. trigrams的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] | |||
。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 | |||
data_bundle.get_vocab('trigrams')获取. | |||
""" | |||
def __init__(self, bigrams=False, trigrams=False): | |||
super().__init__() | |||
self.bigrams = bigrams | |||
self.trigrams = trigrams | |||
def _tokenize(self, data_bundle): | |||
""" | |||
将DataSet中的"复旦大学"拆分为["复", "旦", "大", "学"]. 未来可以通过扩展这个函数实现分词。 | |||
:param data_bundle: | |||
:return: | |||
""" | |||
data_bundle.apply_field(list, field_name=Const.CHAR_INPUT, new_field_name=Const.CHAR_INPUT) | |||
return data_bundle | |||
def process(self, data_bundle:DataBundle): | |||
""" | |||
可以处理的DataSet应该具备以下的field | |||
.. csv-table:: | |||
:header: "raw_chars", "target" | |||
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1" | |||
"<荐书> 推荐所有喜欢<红楼>...", "1" | |||
"..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
_add_chars_field(data_bundle, lower=False) | |||
data_bundle = self._tokenize(data_bundle) | |||
input_field_names = [Const.CHAR_INPUT] | |||
if self.bigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||
field_name=Const.CHAR_INPUT, new_field_name='bigrams') | |||
input_field_names.append('bigrams') | |||
if self.trigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||
field_name=Const.CHAR_INPUT, new_field_name='trigrams') | |||
input_field_names.append('trigrams') | |||
# index | |||
_indexize(data_bundle, input_field_names, Const.TARGET) | |||
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names | |||
target_fields = [Const.TARGET] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.CHAR_INPUT) | |||
data_bundle.set_input(*input_fields) | |||
data_bundle.set_target(*target_fields) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | |||
:return: DataBundle | |||
""" | |||
# 读取数据 | |||
data_bundle = ChnSentiCorpLoader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle |
@@ -51,7 +51,7 @@ class _NERPipe(Pipe): | |||
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | |||
"[...]", "[...]" | |||
:param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 | |||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 | |||
在传入DataBundle基础上原位修改。 | |||
:return: DataBundle | |||
""" | |||
@@ -193,7 +193,7 @@ class OntoNotesNERPipe(_NERPipe): | |||
""" | |||
处理OntoNotes的NER数据,处理之后DataSet中的field情况为 | |||
.. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader | |||
.. csv-table:: | |||
:header: "raw_words", "words", "target", "seq_len" | |||
"[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2 | |||
@@ -222,14 +222,23 @@ class _CNNERPipe(Pipe): | |||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target, seq_len。 | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
:param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 | |||
设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 | |||
data_bundle.get_vocab('bigrams')获取. | |||
:param bool trigrams: 是否增加一列trigrams. trigrams的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] | |||
。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 | |||
data_bundle.get_vocab('trigrams')获取. | |||
""" | |||
def __init__(self, encoding_type: str = 'bio'): | |||
def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False): | |||
if encoding_type == 'bio': | |||
self.convert_tag = iob2 | |||
else: | |||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | |||
self.bigrams = bigrams | |||
self.trigrams = trigrams | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
""" | |||
支持的DataSet的field为 | |||
@@ -241,11 +250,11 @@ class _CNNERPipe(Pipe): | |||
"[青, 岛, 海, 牛, 队, 和, ...]", "[B-ORG, I-ORG, I-ORG, ...]" | |||
"[...]", "[...]" | |||
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | |||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | |||
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int], | |||
是转换为index的target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | |||
:param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 | |||
在传入DataBundle基础上原位修改。 | |||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field | |||
的内容均为List[str]。在传入DataBundle基础上原位修改。 | |||
:return: DataBundle | |||
""" | |||
# 转换tag | |||
@@ -253,11 +262,24 @@ class _CNNERPipe(Pipe): | |||
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) | |||
_add_chars_field(data_bundle, lower=False) | |||
input_field_names = [Const.CHAR_INPUT] | |||
if self.bigrams: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||
field_name=Const.CHAR_INPUT, new_field_name='bigrams') | |||
input_field_names.append('bigrams') | |||
if self.trigrams: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||
field_name=Const.CHAR_INPUT, new_field_name='trigrams') | |||
input_field_names.append('trigrams') | |||
# index | |||
_indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET) | |||
_indexize(data_bundle, input_field_names, Const.TARGET) | |||
input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN] | |||
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names | |||
target_fields = [Const.TARGET, Const.INPUT_LEN] | |||
for name, dataset in data_bundle.datasets.items(): | |||
@@ -177,7 +177,7 @@ class MatchingPipe(Pipe): | |||
def _tokenize(self, data_bundle, field_names, new_field_names): | |||
""" | |||
:param DataBundle data_bundle: DataBundle. | |||
:param ~fastNLP.DataBundle data_bundle: DataBundle. | |||
:param list field_names: List[str], 需要tokenize的field名称 | |||
:param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。 | |||
:return: 输入的DataBundle对象 | |||
@@ -199,7 +199,7 @@ class MatchingPipe(Pipe): | |||
"This site includes a...", "The Government Executive...", "not_entailment" | |||
"...", "..." | |||
:param data_bundle: 通过loader读取得到的data_bundle,里面包含了数据集的原始数据内容 | |||
:param ~fastNLP.DataBundle data_bundle: 通过loader读取得到的data_bundle,里面包含了数据集的原始数据内容 | |||
:return: data_bundle | |||
""" | |||
data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)], | |||
@@ -9,13 +9,15 @@ from .. import DataBundle | |||
class Pipe: | |||
""" | |||
别名::class:`fastNLP.io.Pipe` :class:`fastNLP.io.pipe.Pipe` | |||
.. todo:: | |||
doc | |||
""" | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
""" | |||
对输入的DataBundle进行处理,然后返回该DataBundle。 | |||
:param data_bundle: 需要处理的DataBundle对象 | |||
:param ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象 | |||
:return: | |||
""" | |||
raise NotImplementedError | |||
@@ -92,7 +92,7 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con | |||
""" | |||
在dataset中的field_name列建立词表,Const.TARGET列建立词表,并把词表加入到data_bundle中。 | |||
:param data_bundle: | |||
:param ~fastNLP.DataBundle data_bundle: | |||
:param: str,list input_field_names: | |||
:param: str,list target_field_names: 这一列的vocabulary没有unknown和padding | |||
:return: | |||
@@ -154,7 +154,7 @@ def _drop_empty_instance(data_bundle, field_name): | |||
""" | |||
删除data_bundle的DataSet中存在的某个field为空的情况 | |||
:param data_bundle: DataBundle | |||
:param ~fastNLP.DataBundle data_bundle: | |||
:param str field_name: 对哪个field进行检查,如果为None,则任意field为空都会删掉 | |||
:return: 传入的DataBundle | |||
""" | |||
@@ -21,14 +21,24 @@ __all__ = [ | |||
"STSeqCls", | |||
"BiaffineParser", | |||
"GraphParser" | |||
"GraphParser", | |||
"BertForSequenceClassification", | |||
"BertForSentenceMatching", | |||
"BertForMultipleChoice", | |||
"BertForTokenClassification", | |||
"BertForQuestionAnswering" | |||
] | |||
from .base_model import BaseModel | |||
from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \ | |||
BertForTokenClassification | |||
BertForTokenClassification, BertForSentenceMatching | |||
from .biaffine_parser import BiaffineParser, GraphParser | |||
from .cnn_text_classification import CNNText | |||
from .sequence_labeling import SeqLabeling, AdvSeqLabel | |||
from .snli import ESIM | |||
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel | |||
import sys | |||
from ..doc_utils import doc_process | |||
doc_process(sys.modules[__name__]) |