@@ -3,6 +3,7 @@ | |||
# You can set these variables from the command line. | |||
SPHINXOPTS = | |||
SPHINXAPIDOC = sphinx-apidoc | |||
SPHINXBUILD = sphinx-build | |||
SPHINXPROJ = fastNLP | |||
SOURCEDIR = source | |||
@@ -12,6 +13,12 @@ BUILDDIR = build | |||
help: | |||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) | |||
apidoc: | |||
@$(SPHINXAPIDOC) -f -o source ../fastNLP | |||
server: | |||
cd build/html && python -m http.server | |||
.PHONY: help Makefile | |||
# Catch-all target: route all unknown targets to Sphinx using the new | |||
@@ -23,9 +23,9 @@ copyright = '2018, xpqiu' | |||
author = 'xpqiu' | |||
# The short X.Y version | |||
version = '0.2' | |||
version = '0.4' | |||
# The full version, including alpha/beta/rc tags | |||
release = '0.2' | |||
release = '0.4' | |||
# -- General configuration --------------------------------------------------- | |||
@@ -67,7 +67,7 @@ language = None | |||
# List of patterns, relative to source directory, that match files and | |||
# directories to ignore when looking for source files. | |||
# This pattern also affects html_static_path and html_extra_path . | |||
exclude_patterns = [] | |||
exclude_patterns = ['modules.rst'] | |||
# The name of the Pygments (syntax highlighting) style to use. | |||
pygments_style = 'sphinx' | |||
@@ -1,36 +1,62 @@ | |||
fastNLP.api | |||
============ | |||
fastNLP.api package | |||
=================== | |||
fastNLP.api.api | |||
---------------- | |||
Submodules | |||
---------- | |||
fastNLP.api.api module | |||
---------------------- | |||
.. automodule:: fastNLP.api.api | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.api.converter | |||
---------------------- | |||
fastNLP.api.converter module | |||
---------------------------- | |||
.. automodule:: fastNLP.api.converter | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.api.model\_zoo | |||
----------------------- | |||
fastNLP.api.examples module | |||
--------------------------- | |||
.. automodule:: fastNLP.api.model_zoo | |||
.. automodule:: fastNLP.api.examples | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.api.pipeline | |||
--------------------- | |||
fastNLP.api.pipeline module | |||
--------------------------- | |||
.. automodule:: fastNLP.api.pipeline | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.api.processor | |||
---------------------- | |||
fastNLP.api.processor module | |||
---------------------------- | |||
.. automodule:: fastNLP.api.processor | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.api.utils module | |||
------------------------ | |||
.. automodule:: fastNLP.api.utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.api | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,84 +1,126 @@ | |||
fastNLP.core | |||
============= | |||
fastNLP.core package | |||
==================== | |||
fastNLP.core.batch | |||
------------------- | |||
Submodules | |||
---------- | |||
fastNLP.core.batch module | |||
------------------------- | |||
.. automodule:: fastNLP.core.batch | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.callback module | |||
---------------------------- | |||
fastNLP.core.dataset | |||
--------------------- | |||
.. automodule:: fastNLP.core.callback | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.dataset module | |||
--------------------------- | |||
.. automodule:: fastNLP.core.dataset | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.fieldarray | |||
------------------------ | |||
fastNLP.core.fieldarray module | |||
------------------------------ | |||
.. automodule:: fastNLP.core.fieldarray | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.instance | |||
---------------------- | |||
fastNLP.core.instance module | |||
---------------------------- | |||
.. automodule:: fastNLP.core.instance | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.losses | |||
-------------------- | |||
fastNLP.core.losses module | |||
-------------------------- | |||
.. automodule:: fastNLP.core.losses | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.metrics | |||
--------------------- | |||
fastNLP.core.metrics module | |||
--------------------------- | |||
.. automodule:: fastNLP.core.metrics | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.optimizer | |||
----------------------- | |||
fastNLP.core.optimizer module | |||
----------------------------- | |||
.. automodule:: fastNLP.core.optimizer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.predictor | |||
----------------------- | |||
fastNLP.core.predictor module | |||
----------------------------- | |||
.. automodule:: fastNLP.core.predictor | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.sampler | |||
--------------------- | |||
fastNLP.core.sampler module | |||
--------------------------- | |||
.. automodule:: fastNLP.core.sampler | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.tester | |||
-------------------- | |||
fastNLP.core.tester module | |||
-------------------------- | |||
.. automodule:: fastNLP.core.tester | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.trainer | |||
--------------------- | |||
fastNLP.core.trainer module | |||
--------------------------- | |||
.. automodule:: fastNLP.core.trainer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.utils | |||
------------------- | |||
fastNLP.core.utils module | |||
------------------------- | |||
.. automodule:: fastNLP.core.utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.core.vocabulary | |||
------------------------ | |||
fastNLP.core.vocabulary module | |||
------------------------------ | |||
.. automodule:: fastNLP.core.vocabulary | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.core | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,42 +1,62 @@ | |||
fastNLP.io | |||
=========== | |||
fastNLP.io package | |||
================== | |||
fastNLP.io.base\_loader | |||
------------------------ | |||
Submodules | |||
---------- | |||
fastNLP.io.base\_loader module | |||
------------------------------ | |||
.. automodule:: fastNLP.io.base_loader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.io.config\_io | |||
---------------------- | |||
fastNLP.io.config\_io module | |||
---------------------------- | |||
.. automodule:: fastNLP.io.config_io | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.io.dataset\_loader | |||
--------------------------- | |||
fastNLP.io.dataset\_loader module | |||
--------------------------------- | |||
.. automodule:: fastNLP.io.dataset_loader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.io.embed\_loader | |||
------------------------- | |||
fastNLP.io.embed\_loader module | |||
------------------------------- | |||
.. automodule:: fastNLP.io.embed_loader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.io.logger | |||
------------------ | |||
fastNLP.io.file\_reader module | |||
------------------------------ | |||
.. automodule:: fastNLP.io.logger | |||
.. automodule:: fastNLP.io.file_reader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.io.model\_io | |||
--------------------- | |||
fastNLP.io.model\_io module | |||
--------------------------- | |||
.. automodule:: fastNLP.io.model_io | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.io | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,42 +1,110 @@ | |||
fastNLP.models | |||
=============== | |||
fastNLP.models package | |||
====================== | |||
fastNLP.models.base\_model | |||
--------------------------- | |||
Submodules | |||
---------- | |||
fastNLP.models.base\_model module | |||
--------------------------------- | |||
.. automodule:: fastNLP.models.base_model | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.bert module | |||
-------------------------- | |||
fastNLP.models.biaffine\_parser | |||
-------------------------------- | |||
.. automodule:: fastNLP.models.bert | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.biaffine\_parser module | |||
-------------------------------------- | |||
.. automodule:: fastNLP.models.biaffine_parser | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.char\_language\_model | |||
------------------------------------- | |||
fastNLP.models.char\_language\_model module | |||
------------------------------------------- | |||
.. automodule:: fastNLP.models.char_language_model | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.cnn\_text\_classification | |||
----------------------------------------- | |||
fastNLP.models.cnn\_text\_classification module | |||
----------------------------------------------- | |||
.. automodule:: fastNLP.models.cnn_text_classification | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.enas\_controller module | |||
-------------------------------------- | |||
.. automodule:: fastNLP.models.enas_controller | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.enas\_model module | |||
--------------------------------- | |||
.. automodule:: fastNLP.models.enas_model | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.sequence\_modeling | |||
---------------------------------- | |||
fastNLP.models.enas\_trainer module | |||
----------------------------------- | |||
.. automodule:: fastNLP.models.enas_trainer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.enas\_utils module | |||
--------------------------------- | |||
.. automodule:: fastNLP.models.enas_utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.sequence\_modeling module | |||
---------------------------------------- | |||
.. automodule:: fastNLP.models.sequence_modeling | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.snli | |||
-------------------- | |||
fastNLP.models.snli module | |||
-------------------------- | |||
.. automodule:: fastNLP.models.snli | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.models.star\_transformer module | |||
--------------------------------------- | |||
.. automodule:: fastNLP.models.star_transformer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.models | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,36 +1,54 @@ | |||
fastNLP.modules.aggregator | |||
=========================== | |||
fastNLP.modules.aggregator package | |||
================================== | |||
fastNLP.modules.aggregator.attention | |||
------------------------------------- | |||
Submodules | |||
---------- | |||
fastNLP.modules.aggregator.attention module | |||
------------------------------------------- | |||
.. automodule:: fastNLP.modules.aggregator.attention | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.aggregator.avg\_pool | |||
------------------------------------- | |||
fastNLP.modules.aggregator.avg\_pool module | |||
------------------------------------------- | |||
.. automodule:: fastNLP.modules.aggregator.avg_pool | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.aggregator.kmax\_pool | |||
-------------------------------------- | |||
fastNLP.modules.aggregator.kmax\_pool module | |||
-------------------------------------------- | |||
.. automodule:: fastNLP.modules.aggregator.kmax_pool | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.aggregator.max\_pool | |||
------------------------------------- | |||
fastNLP.modules.aggregator.max\_pool module | |||
------------------------------------------- | |||
.. automodule:: fastNLP.modules.aggregator.max_pool | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.aggregator.self\_attention | |||
------------------------------------------- | |||
fastNLP.modules.aggregator.self\_attention module | |||
------------------------------------------------- | |||
.. automodule:: fastNLP.modules.aggregator.self_attention | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.modules.aggregator | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,18 +1,38 @@ | |||
fastNLP.modules.decoder | |||
======================== | |||
fastNLP.modules.decoder package | |||
=============================== | |||
fastNLP.modules.decoder.CRF | |||
---------------------------- | |||
Submodules | |||
---------- | |||
fastNLP.modules.decoder.CRF module | |||
---------------------------------- | |||
.. automodule:: fastNLP.modules.decoder.CRF | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.decoder.MLP | |||
---------------------------- | |||
fastNLP.modules.decoder.MLP module | |||
---------------------------------- | |||
.. automodule:: fastNLP.modules.decoder.MLP | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.decoder.utils module | |||
------------------------------------ | |||
.. automodule:: fastNLP.modules.decoder.utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.modules.decoder | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,60 +1,94 @@ | |||
fastNLP.modules.encoder | |||
======================== | |||
fastNLP.modules.encoder package | |||
=============================== | |||
fastNLP.modules.encoder.char\_embedding | |||
---------------------------------------- | |||
Submodules | |||
---------- | |||
fastNLP.modules.encoder.char\_embedding module | |||
---------------------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.char_embedding | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.conv | |||
----------------------------- | |||
fastNLP.modules.encoder.conv module | |||
----------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.conv | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.conv\_maxpool | |||
-------------------------------------- | |||
fastNLP.modules.encoder.conv\_maxpool module | |||
-------------------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.conv_maxpool | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.embedding | |||
---------------------------------- | |||
fastNLP.modules.encoder.embedding module | |||
---------------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.embedding | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.linear | |||
------------------------------- | |||
fastNLP.modules.encoder.linear module | |||
------------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.linear | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.lstm | |||
----------------------------- | |||
fastNLP.modules.encoder.lstm module | |||
----------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.lstm | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.masked\_rnn | |||
------------------------------------ | |||
fastNLP.modules.encoder.masked\_rnn module | |||
------------------------------------------ | |||
.. automodule:: fastNLP.modules.encoder.masked_rnn | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.transformer | |||
------------------------------------ | |||
fastNLP.modules.encoder.star\_transformer module | |||
------------------------------------------------ | |||
.. automodule:: fastNLP.modules.encoder.star_transformer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.transformer module | |||
------------------------------------------ | |||
.. automodule:: fastNLP.modules.encoder.transformer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.encoder.variational\_rnn | |||
----------------------------------------- | |||
fastNLP.modules.encoder.variational\_rnn module | |||
----------------------------------------------- | |||
.. automodule:: fastNLP.modules.encoder.variational_rnn | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.modules.encoder | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,5 +1,8 @@ | |||
fastNLP.modules | |||
================ | |||
fastNLP.modules package | |||
======================= | |||
Subpackages | |||
----------- | |||
.. toctree:: | |||
@@ -7,24 +10,38 @@ fastNLP.modules | |||
fastNLP.modules.decoder | |||
fastNLP.modules.encoder | |||
fastNLP.modules.dropout | |||
------------------------ | |||
Submodules | |||
---------- | |||
fastNLP.modules.dropout module | |||
------------------------------ | |||
.. automodule:: fastNLP.modules.dropout | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.other\_modules | |||
------------------------------- | |||
fastNLP.modules.other\_modules module | |||
------------------------------------- | |||
.. automodule:: fastNLP.modules.other_modules | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
fastNLP.modules.utils | |||
---------------------- | |||
fastNLP.modules.utils module | |||
---------------------------- | |||
.. automodule:: fastNLP.modules.utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP.modules | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,13 +1,22 @@ | |||
fastNLP | |||
======== | |||
fastNLP package | |||
=============== | |||
Subpackages | |||
----------- | |||
.. toctree:: | |||
fastNLP.api | |||
fastNLP.automl | |||
fastNLP.core | |||
fastNLP.io | |||
fastNLP.models | |||
fastNLP.modules | |||
Module contents | |||
--------------- | |||
.. automodule:: fastNLP | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,3 +1,41 @@ | |||
""" | |||
api.api的介绍文档 | |||
直接缩进会把上面的文字变成标题 | |||
空行缩进的写法比较合理 | |||
比较合理 | |||
*这里是斜体内容* | |||
**这里是粗体内容** | |||
数学公式块 | |||
.. math:: | |||
E = mc^2 | |||
.. note:: | |||
注解型提示。 | |||
.. warning:: | |||
警告型提示。 | |||
.. seealso:: | |||
`参考与超链接 <https://willqvq.github.io/doc_guide/%E6%B3%A8%E9%87%8A%E6%8C%87%E5%AF%BC>`_ | |||
普通代码块需要空一行, Example:: | |||
from fitlog import fitlog | |||
fitlog.commit() | |||
普通下标和上标: | |||
H\ :sub:`2`\ O | |||
E = mc\ :sup:`2` | |||
""" | |||
import warnings | |||
import torch | |||
@@ -9,7 +47,7 @@ from fastNLP.core.dataset import DataSet | |||
from fastNLP.api.utils import load_url | |||
from fastNLP.api.processor import ModelProcessor | |||
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader | |||
from fastNLP.io.dataset_loader import cut_long_sentence, ConllLoader | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.core.metrics import SpanFPreRecMetric | |||
@@ -23,7 +61,89 @@ model_urls = { | |||
} | |||
class ConllCWSReader(object): | |||
"""Deprecated. Use ConllLoader for all types of conll-format files.""" | |||
def __init__(self): | |||
pass | |||
def load(self, path, cut_long_sent=False): | |||
""" | |||
返回的DataSet只包含raw_sentence这个field,内容为str。 | |||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||
:: | |||
1 编者按 编者按 NN O 11 nmod:topic | |||
2 : : PU O 11 punct | |||
3 7月 7月 NT DATE 4 compound:nn | |||
4 12日 12日 NT DATE 11 nmod:tmod | |||
5 , , PU O 11 punct | |||
1 这 这 DT O 3 det | |||
2 款 款 M O 1 mark:clf | |||
3 飞行 飞行 NN O 8 nsubj | |||
4 从 从 P O 5 case | |||
5 外型 外型 NN O 8 nmod:prep | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
sample = [] | |||
for line in f: | |||
if line.startswith('\n'): | |||
datalist.append(sample) | |||
sample = [] | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
sample.append(line.strip().split()) | |||
if len(sample) > 0: | |||
datalist.append(sample) | |||
ds = DataSet() | |||
for sample in datalist: | |||
# print(sample) | |||
res = self.get_char_lst(sample) | |||
if res is None: | |||
continue | |||
line = ' '.join(res) | |||
if cut_long_sent: | |||
sents = cut_long_sentence(line) | |||
else: | |||
sents = [line] | |||
for raw_sentence in sents: | |||
ds.append(Instance(raw_sentence=raw_sentence)) | |||
return ds | |||
def get_char_lst(self, sample): | |||
if len(sample) == 0: | |||
return None | |||
text = [] | |||
for w in sample: | |||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||
if t3 == '_': | |||
return None | |||
text.append(t1) | |||
return text | |||
class ConllxDataLoader(ConllLoader): | |||
"""返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 | |||
Deprecated. Use ConllLoader for all types of conll-format files. | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'words', 'pos_tags', 'heads', 'labels', | |||
] | |||
indexs = [ | |||
1, 3, 6, 7, | |||
] | |||
super(ConllxDataLoader, self).__init__(headers=headers, indexs=indexs) | |||
class API: | |||
""" | |||
这是 API 类的文档 | |||
""" | |||
def __init__(self): | |||
self.pipeline = None | |||
self._dict = None | |||
@@ -69,8 +189,9 @@ class POS(API): | |||
self.load(model_path, device) | |||
def predict(self, content): | |||
""" | |||
"""predict函数的介绍, | |||
函数介绍的第二句,这句话不会换行 | |||
:param content: list of list of str. Each string is a token(word). | |||
:return answer: list of list of str. Each string is a tag. | |||
""" | |||
@@ -136,13 +257,14 @@ class POS(API): | |||
class CWS(API): | |||
def __init__(self, model_path=None, device='cpu'): | |||
""" | |||
中文分词高级接口。 | |||
""" | |||
中文分词高级接口。 | |||
:param model_path: 当model_path为None,使用默认位置的model。如果默认位置不存在,则自动下载模型 | |||
:param device: str,可以为'cpu', 'cuda'或'cuda:0'等。会将模型load到相应device进行推断。 | |||
""" | |||
:param model_path: 当model_path为None,使用默认位置的model。如果默认位置不存在,则自动下载模型 | |||
:param device: str,可以为'cpu', 'cuda'或'cuda:0'等。会将模型load到相应device进行推断。 | |||
""" | |||
def __init__(self, model_path=None, device='cpu'): | |||
super(CWS, self).__init__() | |||
if model_path is None: | |||
model_path = model_urls['cws'] | |||
@@ -183,18 +305,20 @@ class CWS(API): | |||
def test(self, filepath): | |||
""" | |||
传入一个分词文件路径,返回该数据集上分词f1, precision, recall。 | |||
分词文件应该为: | |||
分词文件应该为:: | |||
1 编者按 编者按 NN O 11 nmod:topic | |||
2 : : PU O 11 punct | |||
3 7月 7月 NT DATE 4 compound:nn | |||
4 12日 12日 NT DATE 11 nmod:tmod | |||
5 , , PU O 11 punct | |||
1 这 这 DT O 3 det | |||
2 款 款 M O 1 mark:clf | |||
3 飞行 飞行 NN O 8 nsubj | |||
4 从 从 P O 5 case | |||
5 外型 外型 NN O 8 nmod:prep | |||
以空行分割两个句子,有内容的每行有7列。 | |||
:param filepath: str, 文件路径路径。 | |||
@@ -62,13 +62,14 @@ class ENASTrainer(fastNLP.Trainer): | |||
""" | |||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||
最好的模型参数。 | |||
:return results: 返回一个字典类型的数据, 内含以下内容:: | |||
seconds: float, 表示训练时长 | |||
以下三个内容只有在提供了dev_data的情况下会有。 | |||
best_eval: Dict of Dict, 表示evaluation的结果 | |||
best_epoch: int,在第几个epoch取得的最佳值 | |||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
:return results: 返回一个字典类型的数据, | |||
内含以下内容:: | |||
seconds: float, 表示训练时长 | |||
以下三个内容只有在提供了dev_data的情况下会有。 | |||
best_eval: Dict of Dict, 表示evaluation的结果 | |||
best_epoch: int,在第几个epoch取得的最佳值 | |||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
""" | |||
results = {} | |||
@@ -1,3 +1,18 @@ | |||
""" | |||
fastNLP.core.DataSet的介绍文档 | |||
DataSet是fastNLP中用于承载数据的容器。可以将DataSet看做是一个表格,每一行是一个instance(或sample),每一列是一个feature。 | |||
csv-table:: | |||
:header: "Field1", "Field2", "Field3" | |||
:widths:20, 10, 10 | |||
"This is the first instance", ['This', 'is', 'the', 'first', 'instance'], 5 | |||
"Second instance", ['Second', 'instance'], 2 | |||
""" | |||
import _pickle as pickle | |||
import numpy as np | |||
@@ -31,7 +46,7 @@ class DataSet(object): | |||
length_set.add(len(value)) | |||
assert len(length_set) == 1, "Arrays must all be same length." | |||
for key, value in data.items(): | |||
self.add_field(name=key, fields=value) | |||
self.add_field(field_name=key, fields=value) | |||
elif isinstance(data, list): | |||
for ins in data: | |||
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) | |||
@@ -88,7 +103,7 @@ class DataSet(object): | |||
raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self)-1}") | |||
data_set = DataSet() | |||
for field in self.field_arrays.values(): | |||
data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder, | |||
data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder, | |||
is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type) | |||
return data_set | |||
elif isinstance(idx, str): | |||
@@ -131,7 +146,7 @@ class DataSet(object): | |||
return "DataSet(" + self.__inner_repr__() + ")" | |||
def append(self, ins): | |||
"""Add an instance to the DataSet. | |||
"""将一个instance对象append到DataSet后面。 | |||
If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet. | |||
:param ins: an Instance object | |||
@@ -151,54 +166,60 @@ class DataSet(object): | |||
assert name in self.field_arrays | |||
self.field_arrays[name].append(field) | |||
def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False, ignore_type=False): | |||
"""Add a new field to the DataSet. | |||
def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False): | |||
"""新增一个field | |||
:param str name: the name of the field. | |||
:param fields: a list of int, float, or other objects. | |||
:param int padder: PadBase对象,如何对该Field进行padding。大部分情况使用默认值即可 | |||
:param bool is_input: whether this field is model input. | |||
:param bool is_target: whether this field is label or target. | |||
:param bool ignore_type: If True, do not perform type check. (Default: False) | |||
:param str field_name: 新增的field的名称 | |||
:param list fields: 需要新增的field的内容 | |||
:param None, Padder padder: 如果为None,则不进行pad。 | |||
:param bool is_input: 新加入的field是否是input | |||
:param bool is_target: 新加入的field是否是target | |||
:param bool ignore_type: 是否忽略对新加入的field的类型检查 | |||
""" | |||
if len(self.field_arrays) != 0: | |||
if len(self) != len(fields): | |||
raise RuntimeError(f"The field to append must have the same size as dataset. " | |||
f"Dataset size {len(self)} != field size {len(fields)}") | |||
self.field_arrays[name] = FieldArray(name, fields, is_target=is_target, is_input=is_input, | |||
padder=padder, ignore_type=ignore_type) | |||
self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input, | |||
padder=padder, ignore_type=ignore_type) | |||
def delete_field(self, name): | |||
"""Delete a field based on the field name. | |||
def delete_field(self, field_name): | |||
"""删除field | |||
:param name: the name of the field to be deleted. | |||
:param str field_name: 需要删除的field的名称. | |||
""" | |||
self.field_arrays.pop(name) | |||
self.field_arrays.pop(field_name) | |||
def get_field(self, field_name): | |||
"""获取field_name这个field | |||
:param str field_name: field的名称 | |||
:return: FieldArray | |||
""" | |||
if field_name not in self.field_arrays: | |||
raise KeyError("Field name {} not found in DataSet".format(field_name)) | |||
return self.field_arrays[field_name] | |||
def get_all_fields(self): | |||
"""Return all the fields with their names. | |||
"""返回一个dict,key为field_name, value为对应的FieldArray | |||
:return field_arrays: the internal data structure of DataSet. | |||
:return: dict: | |||
""" | |||
return self.field_arrays | |||
def get_length(self): | |||
"""Fetch the length of the dataset. | |||
"""获取DataSet的元素数量 | |||
:return length: | |||
:return: int length: | |||
""" | |||
return len(self) | |||
def rename_field(self, old_name, new_name): | |||
"""Rename a field. | |||
"""将某个field重新命名. | |||
:param str old_name: | |||
:param str new_name: | |||
:param str old_name: 原来的field名称 | |||
:param str new_name: 修改为new_name | |||
""" | |||
if old_name in self.field_arrays: | |||
self.field_arrays[new_name] = self.field_arrays.pop(old_name) | |||
@@ -207,34 +228,62 @@ class DataSet(object): | |||
raise KeyError("DataSet has no field named {}.".format(old_name)) | |||
def set_target(self, *field_names, flag=True): | |||
"""Change the target flag of these fields. | |||
"""将field_names的target设置为flag状态 | |||
Example:: | |||
:param field_names: a sequence of str, indicating field names | |||
:param bool flag: Set these fields as target if True. Unset them if False. | |||
dataset.set_target('labels', 'seq_len') # 将labels和seq_len这两个field的target属性设置为True | |||
dataset.set_target('labels', 'seq_lens', flag=False) # 将labels和seq_len的target属性设置为False | |||
:param str field_names: field的名称 | |||
:param bool flag: 将field_name的target状态设置为flag | |||
""" | |||
assert isinstance(flag, bool), "Only bool type supported." | |||
for name in field_names: | |||
if name in self.field_arrays: | |||
self.field_arrays[name].is_target = flag | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
def set_input(self, *field_name, flag=True): | |||
"""Set the input flag of these fields. | |||
def set_input(self, *field_names, flag=True): | |||
"""将field_name的input设置为flag状态 | |||
Example:: | |||
dataset.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True | |||
dataset.set_input('words', flag=False) # 将words这个field的input属性设置为False | |||
:param field_name: a sequence of str, indicating field names. | |||
:param bool flag: Set these fields as input if True. Unset them if False. | |||
:param str field_names: field的名称 | |||
:param bool flag: 将field_name的input状态设置为flag | |||
""" | |||
for name in field_name: | |||
for name in field_names: | |||
if name in self.field_arrays: | |||
self.field_arrays[name].is_input = flag | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
def set_padder(self, field_name, padder): | |||
def set_ignore_type(self, *field_names, flag=True): | |||
"""将field_names的ignore_type设置为flag状态 | |||
:param str field_names: field的名称 | |||
:param bool flag: 将field_name的ignore_type状态设置为flag | |||
:return: | |||
""" | |||
为field_name设置padder | |||
:param field_name: str, 设置field的padding方式为padder | |||
:param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. | |||
assert isinstance(flag, bool), "Only bool type supported." | |||
for name in field_names: | |||
if name in self.field_arrays: | |||
self.field_arrays[name].ignore_type = flag | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
def set_padder(self, field_name, padder): | |||
"""为field_name设置padder | |||
Example:: | |||
from fastNLP import EngChar2DPadder | |||
padder = EngChar2DPadder() | |||
dataset.set_padder('chars', padder) # 则chars这个field会使用EngChar2DPadder进行pad操作 | |||
:param str field_name: 设置field的padding方式为padder | |||
:param None, Padder padder: 设置为None即删除padder, 即对该field不进行pad操作. | |||
:return: | |||
""" | |||
if field_name not in self.field_arrays: | |||
@@ -242,11 +291,10 @@ class DataSet(object): | |||
self.field_arrays[field_name].set_padder(padder) | |||
def set_pad_val(self, field_name, pad_val): | |||
""" | |||
为某个 | |||
"""为某个field设置对应的pad_val. | |||
:param field_name: str,修改该field的pad_val | |||
:param pad_val: int,该field的padder会以pad_val作为padding index | |||
:param str field_name: 修改该field的pad_val | |||
:param int pad_val: 该field的padder会以pad_val作为padding index | |||
:return: | |||
""" | |||
if field_name not in self.field_arrays: | |||
@@ -254,43 +302,68 @@ class DataSet(object): | |||
self.field_arrays[field_name].set_pad_val(pad_val) | |||
def get_input_name(self): | |||
"""Get all field names with `is_input` as True. | |||
"""返回所有is_input被设置为True的field名称 | |||
:return field_names: a list of str | |||
:return: list, 里面的元素为被设置为input的field名称 | |||
""" | |||
return [name for name, field in self.field_arrays.items() if field.is_input] | |||
def get_target_name(self): | |||
"""Get all field names with `is_target` as True. | |||
"""返回所有is_target被设置为True的field名称 | |||
:return field_names: a list of str | |||
:return list, 里面的元素为被设置为target的field名称 | |||
""" | |||
return [name for name, field in self.field_arrays.items() if field.is_target] | |||
def apply(self, func, new_field_name=None, **kwargs): | |||
"""Apply a function to every instance of the DataSet. | |||
:param func: a function that takes an instance as input. | |||
:param str new_field_name: If not None, results of the function will be stored as a new field. | |||
:param **kwargs: Accept parameters will be | |||
(1) is_input: boolean, will be ignored if new_field is None. If True, the new field will be as input. | |||
(2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target. | |||
:return results: if new_field_name is not passed, returned values of the function over all instances. | |||
def apply_field(self, func, field_name, new_field_name=None, **kwargs): | |||
"""将DataSet中的每个instance中的`field_name`这个field传给func,并获取它的返回值. | |||
:param callable func: input是instance的`field_name`这个field. | |||
:param str field_name: 传入func的是哪个field. | |||
:param str, None new_field_name: 将func返回的内容放入到什么field中 | |||
1. str, 将func的返回值放入这个名为`new_field_name`的新field中,如果名称与已有的field相 | |||
同,则覆盖之前的field | |||
2. None, 不创建新的field | |||
:param kwargs: 合法的参数有以下三个 | |||
1. is_input: bool, 如果为True则将`new_field_name`的field设置为input | |||
2. is_target: bool, 如果为True则将`new_field_name`的field设置为target | |||
3. ignore_type: bool, 如果为True则将`new_field_name`的field的ignore_type设置为true, 忽略其类型 | |||
:return: list(Any), 里面的元素为func的返回值,所以list长度为DataSet的长度 | |||
""" | |||
assert len(self)!=0, "Null dataset cannot use .apply()." | |||
assert len(self)!=0, "Null DataSet cannot use apply()." | |||
if field_name not in self: | |||
raise KeyError("DataSet has no field named `{}`.".format(field_name)) | |||
results = [] | |||
idx = -1 | |||
try: | |||
for idx, ins in enumerate(self._inner_iter()): | |||
results.append(func(ins)) | |||
results.append(func(ins[field_name])) | |||
except Exception as e: | |||
if idx!=-1: | |||
print("Exception happens at the `{}`th instance.".format(idx)) | |||
raise e | |||
# results = [func(ins) for ins in self._inner_iter()] | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||
if new_field_name is not None: | |||
self._add_apply_field(results, new_field_name, kwargs) | |||
return results | |||
def _add_apply_field(self, results, new_field_name, kwargs): | |||
"""将results作为加入到新的field中,field名称为new_field_name | |||
:param list(str) results: 一般是apply*()之后的结果 | |||
:param str new_field_name: 新加入的field的名称 | |||
:param dict kwargs: 用户apply*()时传入的自定义参数 | |||
:return: | |||
""" | |||
extra_param = {} | |||
if 'is_input' in kwargs: | |||
extra_param['is_input'] = kwargs['is_input'] | |||
@@ -298,56 +371,91 @@ class DataSet(object): | |||
extra_param['is_target'] = kwargs['is_target'] | |||
if 'ignore_type' in kwargs: | |||
extra_param['ignore_type'] = kwargs['ignore_type'] | |||
if new_field_name is not None: | |||
if new_field_name in self.field_arrays: | |||
# overwrite the field, keep same attributes | |||
old_field = self.field_arrays[new_field_name] | |||
if 'is_input' not in extra_param: | |||
extra_param['is_input'] = old_field.is_input | |||
if 'is_target' not in extra_param: | |||
extra_param['is_target'] = old_field.is_target | |||
if 'ignore_type' not in extra_param: | |||
extra_param['ignore_type'] = old_field.ignore_type | |||
self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], | |||
is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type']) | |||
else: | |||
self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||
is_target=extra_param.get("is_target", None), | |||
ignore_type=extra_param.get("ignore_type", False)) | |||
if new_field_name in self.field_arrays: | |||
# overwrite the field, keep same attributes | |||
old_field = self.field_arrays[new_field_name] | |||
if 'is_input' not in extra_param: | |||
extra_param['is_input'] = old_field.is_input | |||
if 'is_target' not in extra_param: | |||
extra_param['is_target'] = old_field.is_target | |||
if 'ignore_type' not in extra_param: | |||
extra_param['ignore_type'] = old_field.ignore_type | |||
self.add_field(field_name=new_field_name, fields=results, is_input=extra_param["is_input"], | |||
is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type']) | |||
else: | |||
return results | |||
self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||
is_target=extra_param.get("is_target", None), | |||
ignore_type=extra_param.get("ignore_type", False)) | |||
def apply(self, func, new_field_name=None, **kwargs): | |||
"""将DataSet中每个instance传入到func中,并获取它的返回值. | |||
:param callable func: 参数是DataSet中的instance | |||
:param str, None new_field_name: 将func返回的内容放入到什么field中 | |||
1. str, 将func的返回值放入这个名为`new_field_name`的新field中,如果名称与已有的field相 | |||
同,则覆盖之前的field | |||
2. None, 不创建新的field | |||
:param kwargs: 合法的参数有以下三个 | |||
1. is_input: bool, 如果为True则将`new_field_name`的field设置为input | |||
2. is_target: bool, 如果为True则将`new_field_name`的field设置为target | |||
3. ignore_type: bool, 如果为True则将`new_field_name`的field的ignore_type设置为true, 忽略其类型 | |||
:return: List[], 里面的元素为func的返回值,所以list长度为DataSet的长度 | |||
""" | |||
assert len(self)!=0, "Null DataSet cannot use apply()." | |||
idx = -1 | |||
try: | |||
results = [] | |||
for idx, ins in enumerate(self._inner_iter()): | |||
results.append(func(ins)) | |||
except Exception as e: | |||
if idx!=-1: | |||
print("Exception happens at the `{}`th instance.".format(idx)) | |||
raise e | |||
# results = [func(ins) for ins in self._inner_iter()] | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||
if new_field_name is not None: | |||
self._add_apply_field(results, new_field_name, kwargs) | |||
return results | |||
def drop(self, func, inplace=True): | |||
"""Drop instances if a condition holds. | |||
"""func接受一个instance,返回bool值,返回值为True时,该instance会被删除。 | |||
:param func: a function that takes an Instance object as input, and returns bool. | |||
The instance will be dropped if the function returns True. | |||
:param inplace: bool, whether to drop inpalce. Otherwise a new dataset will be returned. | |||
:param callable func: 接受一个instance作为参数,返回bool值。为True时删除该instance | |||
:param bool inplace: 是否在当前DataSet中直接删除instance。如果为False,返回值为一个删除了相应instance的新的DataSet | |||
:return: DataSet | |||
""" | |||
if inplace: | |||
results = [ins for ins in self._inner_iter() if not func(ins)] | |||
for name, old_field in self.field_arrays.items(): | |||
self.field_arrays[name].content = [ins[name] for ins in results] | |||
return self | |||
else: | |||
results = [ins for ins in self if not func(ins)] | |||
data = DataSet(results) | |||
dataset = DataSet(results) | |||
for field_name, field in self.field_arrays.items(): | |||
data.field_arrays[field_name].to(field) | |||
dataset.field_arrays[field_name].to(field) | |||
return dataset | |||
def split(self, dev_ratio): | |||
"""Split the dataset into training and development(validation) set. | |||
def split(self, ratio): | |||
"""将DataSet按照ratio的比例拆分,返回两个DataSet | |||
:param float dev_ratio: the ratio of test set in all data. | |||
:return (train_set, dev_set): | |||
train_set: the training set | |||
dev_set: the development set | |||
:param float ratio: 0<ratio<1, 返回的第一个DataSet拥有ratio这么多数据,第二个DataSet拥有(1-ratio)这么多数据 | |||
:return: [DataSet, DataSet] | |||
""" | |||
assert isinstance(dev_ratio, float) | |||
assert 0 < dev_ratio < 1 | |||
assert isinstance(ratio, float) | |||
assert 0 < ratio < 1 | |||
all_indices = [_ for _ in range(len(self))] | |||
np.random.shuffle(all_indices) | |||
split = int(dev_ratio * len(self)) | |||
split = int(ratio * len(self)) | |||
dev_indices = all_indices[:split] | |||
train_indices = all_indices[split:] | |||
dev_set = DataSet() | |||
@@ -373,6 +481,9 @@ class DataSet(object): | |||
:return dataset: the read data set | |||
""" | |||
import warnings | |||
warnings.warn('read_csv is deprecated, use CSVLoader instead', | |||
category=DeprecationWarning) | |||
with open(csv_path, "r", encoding='utf-8') as f: | |||
start_idx = 0 | |||
if headers is None: | |||
@@ -398,26 +509,25 @@ class DataSet(object): | |||
_dict[header].append(content) | |||
return cls(_dict) | |||
# def read_pos(self): | |||
# return DataLoaderRegister.get_reader('read_pos') | |||
def save(self, path): | |||
"""Save the DataSet object as pickle. | |||
"""保存DataSet. | |||
:param str path: the path to the pickle | |||
:param str path: 将DataSet存在哪个路径 | |||
""" | |||
with open(path, 'wb') as f: | |||
pickle.dump(self, f) | |||
@staticmethod | |||
def load(path): | |||
"""Load a DataSet object from pickle. | |||
"""从保存的DataSet pickle路径中读取DataSet | |||
:param str path: the path to the pickle | |||
:return data_set: | |||
:param str path: 从哪里读取DataSet | |||
:return: DataSet | |||
""" | |||
with open(path, 'rb') as f: | |||
return pickle.load(f) | |||
d = pickle.load(f) | |||
assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) | |||
return d | |||
def construct_dataset(sentences): | |||
@@ -1,93 +1,8 @@ | |||
import numpy as np | |||
from copy import deepcopy | |||
class PadderBase: | |||
""" | |||
所有padder都需要继承这个类,并覆盖__call__()方法。 | |||
用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 | |||
""" | |||
def __init__(self, pad_val=0, **kwargs): | |||
self.pad_val = pad_val | |||
def set_pad_val(self, pad_val): | |||
self.pad_val = pad_val | |||
def __call__(self, contents, field_name, field_ele_dtype): | |||
""" | |||
传入的是List内容。假设有以下的DataSet。 | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
dataset = DataSet() | |||
dataset.append(Instance(word='this is a demo', length=4, | |||
chars=[['t', 'h', 'i', 's'], ['i', 's'], ['a'], ['d', 'e', 'm', 'o']])) | |||
dataset.append(Instance(word='another one', length=2, | |||
chars=[['a', 'n', 'o', 't', 'h', 'e', 'r'], ['o', 'n', 'e']])) | |||
# 如果batch_size=2, 下面只是用str的方式看起来更直观一点,但实际上可能word和chars在pad时都已经为index了。 | |||
word这个field的pad_func会接收到的内容会是 | |||
[ | |||
'this is a demo', | |||
'another one' | |||
] | |||
length这个field的pad_func会接收到的内容会是 | |||
[4, 2] | |||
chars这个field的pad_func会接收到的内容会是 | |||
[ | |||
[['t', 'h', 'i', 's'], ['i', 's'], ['a'], ['d', 'e', 'm', 'o']], | |||
[['a', 'n', 'o', 't', 'h', 'e', 'r'], ['o', 'n', 'e']] | |||
] | |||
即把每个instance中某个field的内容合成一个List传入 | |||
:param contents: List[element]。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | |||
deepcopy一份。 | |||
:param field_name: str, field的名称,帮助定位错误 | |||
:param field_ele_dtype: np.int64, np.float64, np.str. 该field的内层list元素的类型。辅助判断是否pad,大多数情况用不上 | |||
:return: List[padded_element]或np.array([padded_element]) | |||
""" | |||
raise NotImplementedError | |||
class AutoPadder(PadderBase): | |||
""" | |||
根据contents的数据自动判定是否需要做padding。 | |||
(1) 如果元素类型(元素类型是指field中最里层List的元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | |||
型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行padding | |||
(2) 如果元素类型为(np.int64, np.float64), | |||
(2.1) 如果该field的内容只有一个,比如为sequence_length, 则不进行padding | |||
(2.2) 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 | |||
如果某个instance中field为[1, 2, 3],则可以pad; 若为[[1,2], [3,4, ...]]则不能进行pad | |||
""" | |||
def __init__(self, pad_val=0): | |||
""" | |||
:param pad_val: int, padding的位置使用该index | |||
""" | |||
super().__init__(pad_val=pad_val) | |||
def _is_two_dimension(self, contents): | |||
""" | |||
判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度 | |||
:param contents: | |||
:return: | |||
""" | |||
value = contents[0] | |||
if isinstance(value , (np.ndarray, list)): | |||
value = value[0] | |||
if isinstance(value, (np.ndarray, list)): | |||
return False | |||
return True | |||
return False | |||
def __call__(self, contents, field_name, field_ele_dtype): | |||
if not is_iterable(contents[0]): | |||
array = np.array([content for content in contents], dtype=field_ele_dtype) | |||
elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): | |||
max_len = max([len(content) for content in contents]) | |||
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) | |||
for i, content in enumerate(contents): | |||
array[i][:len(content)] = content | |||
elif field_ele_dtype is None: | |||
array = contents # 当ignore_type=True时,直接返回contents | |||
else: # should only be str | |||
array = np.array([content for content in contents]) | |||
return array | |||
import numpy as np | |||
from copy import deepcopy | |||
class FieldArray(object): | |||
@@ -98,13 +13,14 @@ class FieldArray(object): | |||
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. | |||
:param bool is_target: If True, this FieldArray is used to compute loss. | |||
:param bool is_input: If True, this FieldArray is used to the model input. | |||
:param PadderBase padder: PadderBase类型。赋值给fieldarray的padder的对象会被deepcopy一份,需要修改padder参数必须通过 | |||
:param Padder padder: PadderBase类型。赋值给fieldarray的padder的对象会被deepcopy一份,需要修改padder参数必须通过 | |||
fieldarray.set_pad_val()。 | |||
默认为None,(1)如果某个field是scalar,则不进行任何padding;(2)如果为一维list, 且fieldarray的dtype为float或int类型 | |||
则会进行padding;(3)其它情况不进行padder。 | |||
假设需要对English word中character进行padding,则需要使用其他的padder。 | |||
或ignore_type为True但是需要进行padding。 | |||
:param bool ignore_type: whether to ignore type. If True, no type detection will rise for this FieldArray. (default: False) | |||
:param bool ignore_type: whether to ignore type. If True, no type detection will rise for this FieldArray. | |||
(default: False) | |||
""" | |||
def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False): | |||
@@ -147,7 +63,7 @@ class FieldArray(object): | |||
if padder is None: | |||
padder = AutoPadder(pad_val=0) | |||
else: | |||
assert isinstance(padder, PadderBase), "padder must be of type PadderBase." | |||
assert isinstance(padder, Padder), "padder must be of type Padder." | |||
padder = deepcopy(padder) | |||
self.set_padder(padder) | |||
self.ignore_type = ignore_type | |||
@@ -290,9 +206,10 @@ class FieldArray(object): | |||
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | |||
def append(self, val): | |||
"""Add a new item to the tail of FieldArray. | |||
"""将val增加到FieldArray中,若该field的ignore_type为True则直接append到这个field中;若ignore_type为False,且当前field为 | |||
input或者target,则会检查传入的content是否与之前的内容在dimension, 元素的类型上是匹配的。 | |||
:param val: int, float, str, or a list of one. | |||
:param val: Any. | |||
""" | |||
if self.ignore_type is False: | |||
if isinstance(val, list): | |||
@@ -331,18 +248,18 @@ class FieldArray(object): | |||
self.content.append(val) | |||
def __getitem__(self, indices): | |||
return self.get(indices) | |||
return self.get(indices, pad=False) | |||
def __setitem__(self, idx, val): | |||
assert isinstance(idx, int) | |||
self.content[idx] = val | |||
def get(self, indices, pad=True): | |||
"""Fetch instances based on indices. | |||
"""根据给定的indices返回内容 | |||
:param indices: an int, or a list of int. | |||
:param pad: bool, 是否对返回的结果进行padding。 | |||
:return: | |||
:param indices: (int, List[int]), 获取indices对应的内容。 | |||
:param pad: bool, 是否对返回的结果进行padding。仅对indices为List[int]时有效 | |||
:return: (single, List) | |||
""" | |||
if isinstance(indices, int): | |||
return self.content[indices] | |||
@@ -357,23 +274,26 @@ class FieldArray(object): | |||
def set_padder(self, padder): | |||
""" | |||
设置padding方式 | |||
设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 | |||
:param padder: PadderBase类型或None. 设置为None即删除padder. | |||
:param padder: (None, Padder). 设置为None即删除padder. | |||
:return: | |||
""" | |||
if padder is not None: | |||
assert isinstance(padder, PadderBase), "padder must be of type PadderBase." | |||
self.padder = deepcopy(padder) | |||
assert isinstance(padder, Padder), "padder must be of type Padder." | |||
self.padder = deepcopy(padder) | |||
else: | |||
self.padder = None | |||
def set_pad_val(self, pad_val): | |||
""" | |||
修改padder的pad_val. | |||
:param pad_val: int。 | |||
"""修改padder的pad_val. | |||
:param pad_val: int。将该field的pad值设置为该值 | |||
:return: | |||
""" | |||
if self.padder is not None: | |||
self.padder.set_pad_val(pad_val) | |||
return self | |||
def __len__(self): | |||
@@ -385,8 +305,7 @@ class FieldArray(object): | |||
def to(self, other): | |||
""" | |||
将other的属性复制给本fieldarray(必须通过fieldarray类型). 包含 is_input, is_target, padder, dtype, pytype, content_dim | |||
ignore_type | |||
将other的属性复制给本FieldArray(other必须为FieldArray类型).属性包括 is_input, is_target, padder, ignore_type | |||
:param other: FieldArray | |||
:return: | |||
@@ -396,11 +315,10 @@ class FieldArray(object): | |||
self.is_input = other.is_input | |||
self.is_target = other.is_target | |||
self.padder = other.padder | |||
self.dtype = other.dtype | |||
self.pytype = other.pytype | |||
self.content_dim = other.content_dim | |||
self.ignore_type = other.ignore_type | |||
return self | |||
def is_iterable(content): | |||
try: | |||
_ = (e for e in content) | |||
@@ -409,17 +327,136 @@ def is_iterable(content): | |||
return True | |||
class EngChar2DPadder(PadderBase): | |||
class Padder: | |||
""" | |||
所有padder都需要继承这个类,并覆盖__call__()方法。 | |||
用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 | |||
""" | |||
用于为英语执行character级别的2D padding操作。对应的field内容应该为[['T', 'h', 'i', 's'], ['a'], ['d', 'e', 'm', 'o']](这里为 | |||
了更直观,把它们写为str,但实际使用时它们应该是character的index)。 | |||
padded过后的batch内容,形状为(batch_size, max_sentence_length, max_word_length). max_sentence_length最大句子长度。 | |||
max_word_length最长的word的长度 | |||
def __init__(self, pad_val=0, **kwargs): | |||
self.pad_val = pad_val | |||
def set_pad_val(self, pad_val): | |||
self.pad_val = pad_val | |||
def __call__(self, contents, field_name, field_ele_dtype): | |||
""" | |||
传入的是List内容。假设有以下的DataSet。 | |||
:param contents: List[element]。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | |||
deepcopy一份。 | |||
:param field_name: str, field的名称。 | |||
:param field_ele_dtype: (np.int64, np.float64, np.str, None), 该field的内层元素的类型。如果该field的ignore_type | |||
为True,该这个值为None。 | |||
:return: np.array([padded_element]) | |||
Example:: | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
dataset = DataSet() | |||
dataset.append(Instance(sent='this is a demo', length=4, | |||
chars=[['t', 'h', 'i', 's'], ['i', 's'], ['a'], ['d', 'e', 'm', 'o']])) | |||
dataset.append(Instance(sent='another one', length=2, | |||
chars=[['a', 'n', 'o', 't', 'h', 'e', 'r'], ['o', 'n', 'e']])) | |||
如果调用 | |||
batch = dataset.get([0,1], pad=True) | |||
sent这个field的padder的__call__会接收到的内容会是 | |||
[ | |||
'this is a demo', | |||
'another one' | |||
] | |||
length这个field的padder的__call__会接收到的内容会是 | |||
[4, 2] | |||
chars这个field的padder的__call__会接收到的内容会是 | |||
[ | |||
[['t', 'h', 'i', 's'], ['i', 's'], ['a'], ['d', 'e', 'm', 'o']], | |||
[['a', 'n', 'o', 't', 'h', 'e', 'r'], ['o', 'n', 'e']] | |||
] | |||
即把每个instance中某个field的内容合成一个List传入 | |||
""" | |||
raise NotImplementedError | |||
class AutoPadder(Padder): | |||
""" | |||
def __init__(self, pad_val=0, pad_length=0): | |||
根据contents的数据自动判定是否需要做padding。 | |||
1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | |||
型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行pad | |||
2 如果元素类型为(np.int64, np.float64), | |||
2.1 如果该field的内容为(np.int64, np.float64),比如为seq_len, 则不进行padding | |||
2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 | |||
如果某个instance中field为[1, 2, 3],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad | |||
""" | |||
def __init__(self, pad_val=0): | |||
""" | |||
:param pad_val: int, padding的位置使用该index | |||
""" | |||
super().__init__(pad_val=pad_val) | |||
def _is_two_dimension(self, contents): | |||
""" | |||
判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度 | |||
:param contents: | |||
:return: | |||
""" | |||
value = contents[0] | |||
if isinstance(value, (np.ndarray, list)): | |||
value = value[0] | |||
if isinstance(value, (np.ndarray, list)): | |||
return False | |||
return True | |||
return False | |||
def __call__(self, contents, field_name, field_ele_dtype): | |||
if not is_iterable(contents[0]): | |||
array = np.array([content for content in contents], dtype=field_ele_dtype) | |||
elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): | |||
max_len = max([len(content) for content in contents]) | |||
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) | |||
for i, content in enumerate(contents): | |||
array[i][:len(content)] = content | |||
elif field_ele_dtype is None: | |||
array = np.array(contents) # 当ignore_type=True时,直接返回contents | |||
else: # should only be str | |||
array = np.array([content for content in contents]) | |||
return array | |||
class EngChar2DPadder(Padder): | |||
""" | |||
用于为英语执行character级别的2D padding操作。对应的field内容应该类似[['T', 'h', 'i', 's'], ['a'], ['d', 'e', 'm', 'o']], | |||
但这个Padder只能处理index为int的情况。 | |||
padded过后的batch内容,形状为(batch_size, max_sentence_length, max_word_length). max_sentence_length为这个batch中最大句 | |||
子长度;max_word_length为这个batch中最长的word的长度 | |||
Example:: | |||
from fastNLP import DataSet | |||
from fastNLP import EnChar2DPadder | |||
from fastNLP import Vocabulary | |||
dataset = DataSet({'sent': ['This is the first demo', 'This is the second demo']}) | |||
dataset.apply(lambda ins:[list(word) for word in ins['sent'].split()], new_field_name='chars') | |||
vocab = Vocabulary() | |||
vocab.from_dataset(dataset, field_name='chars') | |||
vocab.index_dataset(dataset, field_name='chars') | |||
dataset.set_input('chars') | |||
padder = EnChar2DPadder() | |||
dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder | |||
""" | |||
def __init__(self, pad_val=0, pad_length=0): | |||
""" | |||
:param pad_val: int, pad的位置使用该index | |||
:param pad_length: int, 如果为0则取一个batch中最大的单词长度作为padding长度。如果为大于0的数,则将所有单词的长度都pad或截 | |||
取到该长度. | |||
""" | |||
@@ -1,13 +1,12 @@ | |||
class Instance(object): | |||
"""An Instance is an example of data. | |||
Example:: | |||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) | |||
ins["field_1"] | |||
>>[1, 1, 1] | |||
ins.add_field("field_3", [3, 3, 3]) | |||
:param fields: a dict of (str: list). | |||
Example:: | |||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) | |||
ins["field_1"] | |||
>>[1, 1, 1] | |||
ins.add_field("field_3", [3, 3, 3]) | |||
""" | |||
def __init__(self, **fields): | |||
@@ -272,7 +272,7 @@ def squash(predict, truth, **kwargs): | |||
:param predict: Tensor, model output | |||
:param truth: Tensor, truth from dataset | |||
:param **kwargs: extra arguments | |||
:param kwargs: extra arguments | |||
:return predict , truth: predict & truth after processing | |||
""" | |||
return predict.view(-1, predict.size()[-1]), truth.view(-1, ) | |||
@@ -316,7 +316,7 @@ def mask(predict, truth, **kwargs): | |||
:param predict: Tensor, [batch_size , max_len , tag_size] | |||
:param truth: Tensor, [batch_size , max_len] | |||
:param **kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected. | |||
:param kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected. | |||
:return predict , truth: predict & truth after processing | |||
""" | |||
@@ -17,66 +17,72 @@ class MetricBase(object): | |||
"""Base class for all metrics. | |||
所有的传入到Trainer, Tester的Metric需要继承自该对象。需要覆盖写入evaluate(), get_metric()方法。 | |||
evaluate(xxx)中传入的是一个batch的数据。 | |||
get_metric(xxx)当所有数据处理完毕,调用该方法得到最终的metric值 | |||
以分类问题中,Accuracy计算为例 | |||
假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy | |||
class Model(nn.Module): | |||
def __init__(xxx): | |||
# do something | |||
def forward(self, xxx): | |||
# do something | |||
return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes | |||
假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy:: | |||
class Model(nn.Module): | |||
def __init__(xxx): | |||
# do something | |||
def forward(self, xxx): | |||
# do something | |||
return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes | |||
假设dataset中'label'这个field是需要预测的值,并且该field被设置为了target | |||
对应的AccMetric可以按如下的定义 | |||
# version1, 只使用这一次 | |||
class AccMetric(MetricBase): | |||
def __init__(self): | |||
super().__init__() | |||
# 根据你的情况自定义指标 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value | |||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||
self.total += label.size(0) | |||
self.corr_num += label.eq(pred).sum().item() | |||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||
acc = self.corr_num/self.total | |||
if reset: # 是否清零以便重新计算 | |||
对应的AccMetric可以按如下的定义, version1, 只使用这一次:: | |||
class AccMetric(MetricBase): | |||
def __init__(self): | |||
super().__init__() | |||
# 根据你的情况自定义指标 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||
# version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred | |||
class AccMetric(MetricBase): | |||
def __init__(self, label=None, pred=None): | |||
# 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时, | |||
# acc_metric = AccMetric(label='y', pred='pred_y')即可。 | |||
# 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对 | |||
# 应的的值 | |||
super().__init__() | |||
self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可 | |||
# 如果没有注册该则效果与version1就是一样的 | |||
# 根据你的情况自定义指标 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。 | |||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||
self.total += label.size(0) | |||
self.corr_num += label.eq(pred).sum().item() | |||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||
acc = self.corr_num/self.total | |||
if reset: # 是否清零以便重新计算 | |||
def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value | |||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||
self.total += label.size(0) | |||
self.corr_num += label.eq(pred).sum().item() | |||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||
acc = self.corr_num/self.total | |||
if reset: # 是否清零以便重新计算 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||
version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred:: | |||
class AccMetric(MetricBase): | |||
def __init__(self, label=None, pred=None): | |||
# 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时, | |||
# acc_metric = AccMetric(label='y', pred='pred_y')即可。 | |||
# 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对 | |||
# 应的的值 | |||
super().__init__() | |||
self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可 | |||
# 如果没有注册该则效果与version1就是一样的 | |||
# 根据你的情况自定义指标 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||
def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。 | |||
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric | |||
self.total += label.size(0) | |||
self.corr_num += label.eq(pred).sum().item() | |||
def get_metric(self, reset=True): # 在这里定义如何计算metric | |||
acc = self.corr_num/self.total | |||
if reset: # 是否清零以便重新计算 | |||
self.corr_num = 0 | |||
self.total = 0 | |||
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 | |||
``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``. | |||
@@ -84,12 +90,12 @@ class MetricBase(object): | |||
``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``. | |||
``MetricBase`` will do the following type checks: | |||
1. whether self.evaluate has varargs, which is not supported. | |||
2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. | |||
3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. | |||
1. whether self.evaluate has varargs, which is not supported. | |||
2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. | |||
3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. | |||
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | |||
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | |||
target_dict which are not used in self.evaluate. (but if kwargs presented in self.evaluate, no filtering | |||
will be conducted.) | |||
""" | |||
@@ -388,23 +394,26 @@ class SpanFPreRecMetric(MetricBase): | |||
""" | |||
在序列标注问题中,以span的方式计算F, pre, rec. | |||
比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例) | |||
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 | |||
最后得到的metric结果为 | |||
{ | |||
'f': xxx, # 这里使用f考虑以后可以计算f_beta值 | |||
'pre': xxx, | |||
'rec':xxx | |||
} | |||
若only_gross=False, 即还会返回各个label的metric统计值 | |||
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 | |||
最后得到的metric结果为:: | |||
{ | |||
'f': xxx, | |||
'pre': xxx, | |||
'rec':xxx, | |||
'f-label': xxx, | |||
'pre-label': xxx, | |||
'rec-label':xxx, | |||
... | |||
} | |||
'f': xxx, # 这里使用f考虑以后可以计算f_beta值 | |||
'pre': xxx, | |||
'rec':xxx | |||
} | |||
若only_gross=False, 即还会返回各个label的metric统计值:: | |||
{ | |||
'f': xxx, | |||
'pre': xxx, | |||
'rec':xxx, | |||
'f-label': xxx, | |||
'pre-label': xxx, | |||
'rec-label':xxx, | |||
... | |||
} | |||
""" | |||
def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None, | |||
@@ -573,13 +582,21 @@ class BMESF1PreRecMetric(MetricBase): | |||
""" | |||
按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | |||
next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | |||
+-------+---------+----------+----------+---------+---------+ | |||
| | next_B | next_M | next_E | next_S | end | | |||
|:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| | |||
| start | 合法 | next_M=B | next_E=S | 合法 | - | | |||
+=======+=========+==========+==========+=========+=========+ | |||
| start | 合法 | next_M=B | next_E=S | 合法 | -- | | |||
+-------+---------+----------+----------+---------+---------+ | |||
| cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | | |||
+-------+---------+----------+----------+---------+---------+ | |||
| cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | | |||
+-------+---------+----------+----------+---------+---------+ | |||
| cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||
+-------+---------+----------+----------+---------+---------+ | |||
| cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||
+-------+---------+----------+----------+---------+---------+ | |||
举例: | |||
prediction为BSEMS,会被认为是SSSSS. | |||
@@ -79,7 +79,7 @@ class Trainer(object): | |||
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") | |||
# check update every | |||
assert update_every>=1, "update_every must be no less than 1." | |||
assert update_every >= 1, "update_every must be no less than 1." | |||
self.update_every = int(update_every) | |||
# check save_path | |||
@@ -120,7 +120,7 @@ class Trainer(object): | |||
self.use_cuda = bool(use_cuda) | |||
self.save_path = save_path | |||
self.print_every = int(print_every) | |||
self.validate_every = int(validate_every) if validate_every!=0 else -1 | |||
self.validate_every = int(validate_every) if validate_every != 0 else -1 | |||
self.best_metric_indicator = None | |||
self.best_dev_epoch = None | |||
self.best_dev_step = None | |||
@@ -129,7 +129,7 @@ class Trainer(object): | |||
self.prefetch = prefetch | |||
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | |||
self.n_steps = (len(self.train_data) // self.batch_size + int( | |||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||
if isinstance(optimizer, torch.optim.Optimizer): | |||
self.optimizer = optimizer | |||
@@ -156,7 +156,6 @@ class Trainer(object): | |||
self.callback_manager = CallbackManager(env={"trainer": self}, | |||
callbacks=callbacks) | |||
def train(self, load_best_model=True): | |||
""" | |||
@@ -185,14 +184,15 @@ class Trainer(object): | |||
根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型 | |||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||
最好的模型参数。 | |||
:return results: 返回一个字典类型的数据, 内含以下内容:: | |||
最好的模型参数。 | |||
:return results: 返回一个字典类型的数据, | |||
内含以下内容:: | |||
seconds: float, 表示训练时长 | |||
以下三个内容只有在提供了dev_data的情况下会有。 | |||
best_eval: Dict of Dict, 表示evaluation的结果 | |||
best_epoch: int,在第几个epoch取得的最佳值 | |||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
seconds: float, 表示训练时长 | |||
以下三个内容只有在提供了dev_data的情况下会有。 | |||
best_eval: Dict of Dict, 表示evaluation的结果 | |||
best_epoch: int,在第几个epoch取得的最佳值 | |||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
""" | |||
results = {} | |||
@@ -218,8 +218,9 @@ class Trainer(object): | |||
self.callback_manager.on_exception(e) | |||
if self.dev_data is not None and hasattr(self, 'best_dev_perf'): | |||
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||
self.tester._format_eval_results(self.best_dev_perf),) | |||
print( | |||
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||
self.tester._format_eval_results(self.best_dev_perf), ) | |||
results['best_eval'] = self.best_dev_perf | |||
results['best_epoch'] = self.best_dev_epoch | |||
results['best_step'] = self.best_dev_step | |||
@@ -250,7 +251,7 @@ class Trainer(object): | |||
avg_loss = 0 | |||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||
prefetch=self.prefetch) | |||
for epoch in range(1, self.n_epochs+1): | |||
for epoch in range(1, self.n_epochs + 1): | |||
self.epoch = epoch | |||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||
# early stopping | |||
@@ -267,7 +268,7 @@ class Trainer(object): | |||
self.callback_manager.on_loss_begin(batch_y, prediction) | |||
loss = self._compute_loss(prediction, batch_y).mean() | |||
avg_loss += loss.item() | |||
loss = loss/self.update_every | |||
loss = loss / self.update_every | |||
# Is loss NaN or inf? requires_grad = False | |||
self.callback_manager.on_backward_begin(loss) | |||
@@ -277,8 +278,8 @@ class Trainer(object): | |||
self._update() | |||
self.callback_manager.on_step_end() | |||
if (self.step+1) % self.print_every == 0: | |||
avg_loss = avg_loss / self.print_every | |||
if self.step % self.print_every == 0: | |||
avg_loss = float(avg_loss) / self.print_every | |||
if self.use_tqdm: | |||
print_output = "loss:{0:<6.5f}".format(avg_loss) | |||
pbar.update(self.print_every) | |||
@@ -297,7 +298,7 @@ class Trainer(object): | |||
eval_res = self._do_validation(epoch=epoch, step=self.step) | |||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||
self.n_steps) + \ | |||
self.tester._format_eval_results(eval_res) | |||
self.tester._format_eval_results(eval_res) | |||
pbar.write(eval_str + '\n') | |||
# ================= mini-batch end ==================== # | |||
@@ -317,7 +318,7 @@ class Trainer(object): | |||
if self._better_eval_result(res): | |||
if self.save_path is not None: | |||
self._save_model(self.model, | |||
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) | |||
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) | |||
else: | |||
self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()} | |||
self.best_dev_perf = res | |||
@@ -344,7 +345,7 @@ class Trainer(object): | |||
"""Perform weight update on a model. | |||
""" | |||
if (self.step+1)%self.update_every==0: | |||
if (self.step + 1) % self.update_every == 0: | |||
self.optimizer.step() | |||
def _data_forward(self, network, x): | |||
@@ -361,7 +362,7 @@ class Trainer(object): | |||
For PyTorch, just do "loss.backward()" | |||
""" | |||
if self.step%self.update_every==0: | |||
if self.step % self.update_every == 0: | |||
self.model.zero_grad() | |||
loss.backward() | |||
@@ -437,6 +438,7 @@ class Trainer(object): | |||
DEFAULT_CHECK_BATCH_SIZE = 2 | |||
DEFAULT_CHECK_NUM_BATCH = 2 | |||
def _get_value_info(_dict): | |||
# given a dict value, return information about this dict's value. Return list of str | |||
strs = [] | |||
@@ -453,6 +455,7 @@ def _get_value_info(_dict): | |||
strs.append(_str) | |||
return strs | |||
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, | |||
dev_data=None, metric_key=None, | |||
check_level=0): | |||
@@ -463,17 +466,17 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
for batch_count, (batch_x, batch_y) in enumerate(batch): | |||
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie) | |||
# forward check | |||
if batch_count==0: | |||
if batch_count == 0: | |||
info_str = "" | |||
input_fields = _get_value_info(batch_x) | |||
target_fields = _get_value_info(batch_y) | |||
if len(input_fields)>0: | |||
if len(input_fields) > 0: | |||
info_str += "input fields after batch(if batch size is {}):\n".format(batch_size) | |||
info_str += "\n".join(input_fields) | |||
info_str += '\n' | |||
else: | |||
raise RuntimeError("There is no input field.") | |||
if len(target_fields)>0: | |||
if len(target_fields) > 0: | |||
info_str += "target fields after batch(if batch size is {}):\n".format(batch_size) | |||
info_str += "\n".join(target_fields) | |||
info_str += '\n' | |||
@@ -481,7 +484,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
info_str += 'There is no target field.' | |||
print(info_str) | |||
_check_forward_error(forward_func=model.forward, dataset=dataset, | |||
batch_x=batch_x, check_level=check_level) | |||
batch_x=batch_x, check_level=check_level) | |||
refined_batch_x = _build_args(model.forward, **batch_x) | |||
pred_dict = model(**refined_batch_x) | |||
@@ -24,7 +24,7 @@ def _prepare_cache_filepath(filepath): | |||
if not os.path.exists(cache_dir): | |||
os.makedirs(cache_dir) | |||
# TODO 可以保存下缓存时的参数,如果load的时候发现参数不一致,发出警告。 | |||
def cache_results(cache_filepath, refresh=False, verbose=1): | |||
def wrapper_(func): | |||
signature = inspect.signature(func) | |||
@@ -197,17 +197,22 @@ def get_func_signature(func): | |||
Given a function or method, return its signature. | |||
For example: | |||
(1) function | |||
1 function:: | |||
def func(a, b='a', *args): | |||
xxxx | |||
get_func_signature(func) # 'func(a, b='a', *args)' | |||
(2) method | |||
2 method:: | |||
class Demo: | |||
def __init__(self): | |||
xxx | |||
def forward(self, a, b='a', **args) | |||
demo = Demo() | |||
get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)' | |||
:param func: a function or a method | |||
:return: str or None | |||
""" | |||
@@ -26,10 +26,10 @@ class ConfigLoader(BaseLoader): | |||
:param str file_path: the path of config file | |||
:param dict sections: the dict of ``{section_name(string): ConfigSection object}`` | |||
Example:: | |||
test_args = ConfigSection() | |||
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||
Example:: | |||
test_args = ConfigSection() | |||
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||
""" | |||
assert isinstance(sections, dict) | |||
@@ -1,71 +1,13 @@ | |||
import os | |||
import json | |||
from nltk.tree import Tree | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.io.base_loader import DataLoaderRegister | |||
from fastNLP.io.file_reader import read_csv, read_json, read_conll | |||
def convert_seq_dataset(data): | |||
"""Create an DataSet instance that contains no labels. | |||
:param data: list of list of strings, [num_examples, *]. | |||
Example:: | |||
[ | |||
[word_11, word_12, ...], | |||
... | |||
] | |||
:return: a DataSet. | |||
""" | |||
dataset = DataSet() | |||
for word_seq in data: | |||
dataset.append(Instance(word_seq=word_seq)) | |||
return dataset | |||
def convert_seq2tag_dataset(data): | |||
"""Convert list of data into DataSet. | |||
:param data: list of list of strings, [num_examples, *]. | |||
Example:: | |||
[ | |||
[ [word_11, word_12, ...], label_1 ], | |||
[ [word_21, word_22, ...], label_2 ], | |||
... | |||
] | |||
:return: a DataSet. | |||
""" | |||
dataset = DataSet() | |||
for sample in data: | |||
dataset.append(Instance(word_seq=sample[0], label=sample[1])) | |||
return dataset | |||
def convert_seq2seq_dataset(data): | |||
"""Convert list of data into DataSet. | |||
:param data: list of list of strings, [num_examples, *]. | |||
Example:: | |||
[ | |||
[ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||
[ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||
... | |||
] | |||
:return: a DataSet. | |||
""" | |||
dataset = DataSet() | |||
for sample in data: | |||
dataset.append(Instance(word_seq=sample[0], label_seq=sample[1])) | |||
return dataset | |||
def download_from_url(url, path): | |||
def _download_from_url(url, path): | |||
from tqdm import tqdm | |||
import requests | |||
@@ -81,7 +23,7 @@ def download_from_url(url, path): | |||
t.update(len(chunk)) | |||
return | |||
def uncompress(src, dst): | |||
def _uncompress(src, dst): | |||
import zipfile, gzip, tarfile, os | |||
def unzip(src, dst): | |||
@@ -134,241 +76,6 @@ class DataSetLoader: | |||
raise NotImplementedError | |||
class NativeDataSetLoader(DataSetLoader): | |||
"""A simple example of DataSetLoader | |||
""" | |||
def __init__(self): | |||
super(NativeDataSetLoader, self).__init__() | |||
def load(self, path): | |||
ds = DataSet.read_csv(path, headers=("raw_sentence", "label"), sep="\t") | |||
ds.set_input("raw_sentence") | |||
ds.set_target("label") | |||
return ds | |||
DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive') | |||
class RawDataSetLoader(DataSetLoader): | |||
"""A simple example of raw data reader | |||
""" | |||
def __init__(self): | |||
super(RawDataSetLoader, self).__init__() | |||
def load(self, data_path, split=None): | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
lines = f.readlines() | |||
lines = lines if split is None else [l.split(split) for l in lines] | |||
lines = list(filter(lambda x: len(x) > 0, lines)) | |||
return self.convert(lines) | |||
def convert(self, data): | |||
return convert_seq_dataset(data) | |||
DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') | |||
class DummyPOSReader(DataSetLoader): | |||
"""A simple reader for a dummy POS tagging dataset. | |||
In these datasets, each line are divided by "\t". The first Col is the vocabulary and the second | |||
Col is the label. Different sentence are divided by an empty line. | |||
E.g:: | |||
Tom label1 | |||
and label2 | |||
Jerry label1 | |||
. label3 | |||
(separated by an empty line) | |||
Hello label4 | |||
world label5 | |||
! label3 | |||
In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label. | |||
""" | |||
def __init__(self): | |||
super(DummyPOSReader, self).__init__() | |||
def load(self, data_path): | |||
""" | |||
:return data: three-level list | |||
Example:: | |||
[ | |||
[ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||
[ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||
... | |||
] | |||
""" | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
lines = f.readlines() | |||
data = self.parse(lines) | |||
return self.convert(data) | |||
@staticmethod | |||
def parse(lines): | |||
data = [] | |||
sentence = [] | |||
for line in lines: | |||
line = line.strip() | |||
if len(line) > 1: | |||
sentence.append(line.split('\t')) | |||
else: | |||
words = [] | |||
labels = [] | |||
for tokens in sentence: | |||
words.append(tokens[0]) | |||
labels.append(tokens[1]) | |||
data.append([words, labels]) | |||
sentence = [] | |||
if len(sentence) != 0: | |||
words = [] | |||
labels = [] | |||
for tokens in sentence: | |||
words.append(tokens[0]) | |||
labels.append(tokens[1]) | |||
data.append([words, labels]) | |||
return data | |||
def convert(self, data): | |||
"""Convert lists of strings into Instances with Fields. | |||
""" | |||
return convert_seq2seq_dataset(data) | |||
DataLoaderRegister.set_reader(DummyPOSReader, 'read_pos') | |||
class DummyCWSReader(DataSetLoader): | |||
"""Load pku dataset for Chinese word segmentation. | |||
""" | |||
def __init__(self): | |||
super(DummyCWSReader, self).__init__() | |||
def load(self, data_path, max_seq_len=32): | |||
"""Load pku dataset for Chinese word segmentation. | |||
CWS (Chinese Word Segmentation) pku training dataset format: | |||
1. Each line is a sentence. | |||
2. Each word in a sentence is separated by space. | |||
This function convert the pku dataset into three-level lists with labels <BMES>. | |||
B: beginning of a word | |||
M: middle of a word | |||
E: ending of a word | |||
S: single character | |||
:param str data_path: path to the data set. | |||
:param max_seq_len: int, the maximum length of a sequence. If a sequence is longer than it, split it into | |||
several sequences. | |||
:return: three-level lists | |||
""" | |||
assert isinstance(max_seq_len, int) and max_seq_len > 0 | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
sentences = f.readlines() | |||
data = [] | |||
for sent in sentences: | |||
tokens = sent.strip().split() | |||
words = [] | |||
labels = [] | |||
for token in tokens: | |||
if len(token) == 1: | |||
words.append(token) | |||
labels.append("S") | |||
else: | |||
words.append(token[0]) | |||
labels.append("B") | |||
for idx in range(1, len(token) - 1): | |||
words.append(token[idx]) | |||
labels.append("M") | |||
words.append(token[-1]) | |||
labels.append("E") | |||
num_samples = len(words) // max_seq_len | |||
if len(words) % max_seq_len != 0: | |||
num_samples += 1 | |||
for sample_idx in range(num_samples): | |||
start = sample_idx * max_seq_len | |||
end = (sample_idx + 1) * max_seq_len | |||
seq_words = words[start:end] | |||
seq_labels = labels[start:end] | |||
data.append([seq_words, seq_labels]) | |||
return self.convert(data) | |||
def convert(self, data): | |||
return convert_seq2seq_dataset(data) | |||
class DummyClassificationReader(DataSetLoader): | |||
"""Loader for a dummy classification data set""" | |||
def __init__(self): | |||
super(DummyClassificationReader, self).__init__() | |||
def load(self, data_path): | |||
assert os.path.exists(data_path) | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
lines = f.readlines() | |||
data = self.parse(lines) | |||
return self.convert(data) | |||
@staticmethod | |||
def parse(lines): | |||
"""每行第一个token是标签,其余是字/词;由空格分隔。 | |||
:param lines: lines from dataset | |||
:return: list(list(list())): the three level of lists are words, sentence, and dataset | |||
""" | |||
dataset = list() | |||
for line in lines: | |||
line = line.strip().split() | |||
label = line[0] | |||
words = line[1:] | |||
if len(words) <= 1: | |||
continue | |||
sentence = [words, label] | |||
dataset.append(sentence) | |||
return dataset | |||
def convert(self, data): | |||
return convert_seq2tag_dataset(data) | |||
class DummyLMReader(DataSetLoader): | |||
"""A Dummy Language Model Dataset Reader | |||
""" | |||
def __init__(self): | |||
super(DummyLMReader, self).__init__() | |||
def load(self, data_path): | |||
if not os.path.exists(data_path): | |||
raise FileNotFoundError("file {} not found.".format(data_path)) | |||
with open(data_path, "r", encoding="utf=8") as f: | |||
text = " ".join(f.readlines()) | |||
tokens = text.strip().split() | |||
data = self.sentence_cut(tokens) | |||
return self.convert(data) | |||
def sentence_cut(self, tokens, sentence_length=15): | |||
start_idx = 0 | |||
data_set = [] | |||
for idx in range(len(tokens) // sentence_length): | |||
x = tokens[start_idx * idx: start_idx * idx + sentence_length] | |||
y = tokens[start_idx * idx + 1: start_idx * idx + sentence_length + 1] | |||
if start_idx * idx + sentence_length + 1 >= len(tokens): | |||
# ad hoc | |||
y.extend(["<unk>"]) | |||
data_set.append([x, y]) | |||
return data_set | |||
def convert(self, data): | |||
pass | |||
class PeopleDailyCorpusLoader(DataSetLoader): | |||
"""人民日报数据集 | |||
""" | |||
@@ -448,8 +155,9 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||
class ConllLoader: | |||
def __init__(self, headers, indexs=None): | |||
def __init__(self, headers, indexs=None, dropna=True): | |||
self.headers = headers | |||
self.dropna = dropna | |||
if indexs is None: | |||
self.indexs = list(range(len(self.headers))) | |||
else: | |||
@@ -458,33 +166,10 @@ class ConllLoader: | |||
self.indexs = indexs | |||
def load(self, path): | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
sample = [] | |||
start = next(f) | |||
if '-DOCSTART-' not in start: | |||
sample.append(start.split()) | |||
for line in f: | |||
if line.startswith('\n'): | |||
if len(sample): | |||
datalist.append(sample) | |||
sample = [] | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
sample.append(line.split()) | |||
if len(sample) > 0: | |||
datalist.append(sample) | |||
data = [self.get_one(sample) for sample in datalist] | |||
data = filter(lambda x: x is not None, data) | |||
ds = DataSet() | |||
for sample in data: | |||
ins = Instance() | |||
for name, idx in zip(self.headers, self.indexs): | |||
ins.add_field(field_name=name, field=sample[idx]) | |||
ds.append(ins) | |||
for idx, data in read_conll(path, indexes=self.indexs, dropna=self.dropna): | |||
ins = {h:data[idx] for h, idx in zip(self.headers, self.indexs)} | |||
ds.append(Instance(**ins)) | |||
return ds | |||
def get_one(self, sample): | |||
@@ -499,9 +184,7 @@ class Conll2003Loader(ConllLoader): | |||
"""Loader for conll2003 dataset | |||
More information about the given dataset cound be found on | |||
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | |||
Deprecated. Use ConllLoader for all types of conll-format files. | |||
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
@@ -510,194 +193,6 @@ class Conll2003Loader(ConllLoader): | |||
super(Conll2003Loader, self).__init__(headers=headers) | |||
class SNLIDataSetReader(DataSetLoader): | |||
"""A data set loader for SNLI data set. | |||
""" | |||
def __init__(self): | |||
super(SNLIDataSetReader, self).__init__() | |||
def load(self, path_list): | |||
""" | |||
:param list path_list: A list of file name, in the order of premise file, hypothesis file, and label file. | |||
:return: A DataSet object. | |||
""" | |||
assert len(path_list) == 3 | |||
line_set = [] | |||
for file in path_list: | |||
if not os.path.exists(file): | |||
raise FileNotFoundError("file {} NOT found".format(file)) | |||
with open(file, 'r', encoding='utf-8') as f: | |||
lines = f.readlines() | |||
line_set.append(lines) | |||
premise_lines, hypothesis_lines, label_lines = line_set | |||
assert len(premise_lines) == len(hypothesis_lines) and len(premise_lines) == len(label_lines) | |||
data_set = [] | |||
for premise, hypothesis, label in zip(premise_lines, hypothesis_lines, label_lines): | |||
p = premise.strip().split() | |||
h = hypothesis.strip().split() | |||
l = label.strip() | |||
data_set.append([p, h, l]) | |||
return self.convert(data_set) | |||
def convert(self, data): | |||
"""Convert a 3D list to a DataSet object. | |||
:param data: A 3D tensor. | |||
Example:: | |||
[ | |||
[ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ], | |||
[ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ], | |||
... | |||
] | |||
:return: A DataSet object. | |||
""" | |||
data_set = DataSet() | |||
for example in data: | |||
p, h, l = example | |||
# list, list, str | |||
instance = Instance() | |||
instance.add_field("premise", p) | |||
instance.add_field("hypothesis", h) | |||
instance.add_field("truth", l) | |||
data_set.append(instance) | |||
data_set.apply(lambda ins: len(ins["premise"]), new_field_name="premise_len") | |||
data_set.apply(lambda ins: len(ins["hypothesis"]), new_field_name="hypothesis_len") | |||
data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len") | |||
data_set.set_target("truth") | |||
return data_set | |||
class ConllCWSReader(object): | |||
"""Deprecated. Use ConllLoader for all types of conll-format files.""" | |||
def __init__(self): | |||
pass | |||
def load(self, path, cut_long_sent=False): | |||
""" | |||
返回的DataSet只包含raw_sentence这个field,内容为str。 | |||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||
:: | |||
1 编者按 编者按 NN O 11 nmod:topic | |||
2 : : PU O 11 punct | |||
3 7月 7月 NT DATE 4 compound:nn | |||
4 12日 12日 NT DATE 11 nmod:tmod | |||
5 , , PU O 11 punct | |||
1 这 这 DT O 3 det | |||
2 款 款 M O 1 mark:clf | |||
3 飞行 飞行 NN O 8 nsubj | |||
4 从 从 P O 5 case | |||
5 外型 外型 NN O 8 nmod:prep | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
sample = [] | |||
for line in f: | |||
if line.startswith('\n'): | |||
datalist.append(sample) | |||
sample = [] | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
sample.append(line.strip().split()) | |||
if len(sample) > 0: | |||
datalist.append(sample) | |||
ds = DataSet() | |||
for sample in datalist: | |||
# print(sample) | |||
res = self.get_char_lst(sample) | |||
if res is None: | |||
continue | |||
line = ' '.join(res) | |||
if cut_long_sent: | |||
sents = cut_long_sentence(line) | |||
else: | |||
sents = [line] | |||
for raw_sentence in sents: | |||
ds.append(Instance(raw_sentence=raw_sentence)) | |||
return ds | |||
def get_char_lst(self, sample): | |||
if len(sample) == 0: | |||
return None | |||
text = [] | |||
for w in sample: | |||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||
if t3 == '_': | |||
return None | |||
text.append(t1) | |||
return text | |||
class NaiveCWSReader(DataSetLoader): | |||
""" | |||
这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了 | |||
例如:: | |||
这是 fastNLP , 一个 非常 good 的 包 . | |||
或者,即每个part后面还有一个pos tag | |||
例如:: | |||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | |||
""" | |||
def __init__(self, in_word_splitter=None): | |||
super(NaiveCWSReader, self).__init__() | |||
self.in_word_splitter = in_word_splitter | |||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||
""" | |||
允许使用的情况有(默认以\t或空格作为seg) | |||
这是 fastNLP , 一个 非常 good 的 包 . | |||
和 | |||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | |||
如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0] | |||
:param filepath: | |||
:param in_word_splitter: | |||
:param cut_long_sent: | |||
:return: | |||
""" | |||
if in_word_splitter == None: | |||
in_word_splitter = self.in_word_splitter | |||
dataset = DataSet() | |||
with open(filepath, 'r') as f: | |||
for line in f: | |||
line = line.strip() | |||
if len(line.replace(' ', '')) == 0: # 不能接受空行 | |||
continue | |||
if not in_word_splitter is None: | |||
words = [] | |||
for part in line.split(): | |||
word = part.split(in_word_splitter)[0] | |||
words.append(word) | |||
line = ' '.join(words) | |||
if cut_long_sent: | |||
sents = cut_long_sentence(line) | |||
else: | |||
sents = [line] | |||
for sent in sents: | |||
instance = Instance(raw_sentence=sent) | |||
dataset.append(instance) | |||
return dataset | |||
def cut_long_sentence(sent, max_sample_length=200): | |||
""" | |||
将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length | |||
@@ -727,103 +222,6 @@ def cut_long_sentence(sent, max_sample_length=200): | |||
return cutted_sentence | |||
class ZhConllPOSReader(object): | |||
"""读取中文Conll格式。返回“字级别”的标签,使用BMES记号扩展原来的词级别标签。 | |||
Deprecated. Use ConllLoader for all types of conll-format files. | |||
""" | |||
def __init__(self): | |||
pass | |||
def load(self, path): | |||
""" | |||
返回的DataSet, 包含以下的field | |||
words:list of str, | |||
tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] | |||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||
:: | |||
1 编者按 编者按 NN O 11 nmod:topic | |||
2 : : PU O 11 punct | |||
3 7月 7月 NT DATE 4 compound:nn | |||
4 12日 12日 NT DATE 11 nmod:tmod | |||
5 , , PU O 11 punct | |||
1 这 这 DT O 3 det | |||
2 款 款 M O 1 mark:clf | |||
3 飞行 飞行 NN O 8 nsubj | |||
4 从 从 P O 5 case | |||
5 外型 外型 NN O 8 nmod:prep | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
sample = [] | |||
for line in f: | |||
if line.startswith('\n'): | |||
datalist.append(sample) | |||
sample = [] | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
sample.append(line.split('\t')) | |||
if len(sample) > 0: | |||
datalist.append(sample) | |||
ds = DataSet() | |||
for sample in datalist: | |||
# print(sample) | |||
res = self.get_one(sample) | |||
if res is None: | |||
continue | |||
char_seq = [] | |||
pos_seq = [] | |||
for word, tag in zip(res[0], res[1]): | |||
char_seq.extend(list(word)) | |||
if len(word) == 1: | |||
pos_seq.append('S-{}'.format(tag)) | |||
elif len(word) > 1: | |||
pos_seq.append('B-{}'.format(tag)) | |||
for _ in range(len(word) - 2): | |||
pos_seq.append('M-{}'.format(tag)) | |||
pos_seq.append('E-{}'.format(tag)) | |||
else: | |||
raise ValueError("Zero length of word detected.") | |||
ds.append(Instance(words=char_seq, | |||
tag=pos_seq)) | |||
return ds | |||
def get_one(self, sample): | |||
if len(sample) == 0: | |||
return None | |||
text = [] | |||
pos_tags = [] | |||
for w in sample: | |||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||
if t3 == '_': | |||
return None | |||
text.append(t1) | |||
pos_tags.append(t2) | |||
return text, pos_tags | |||
class ConllxDataLoader(ConllLoader): | |||
"""返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 | |||
Deprecated. Use ConllLoader for all types of conll-format files. | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'words', 'pos_tags', 'heads', 'labels', | |||
] | |||
indexs = [ | |||
1, 3, 6, 7, | |||
] | |||
super(ConllxDataLoader, self).__init__(headers=headers, indexs=indexs) | |||
class SSTLoader(DataSetLoader): | |||
"""load SST data in PTB tree format | |||
data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip | |||
@@ -842,10 +240,7 @@ class SSTLoader(DataSetLoader): | |||
""" | |||
:param path: str,存储数据的路径 | |||
:return: DataSet。内含field有'words', 'pos_tags', 'heads', 'labels'(parser的label) | |||
类似于拥有以下结构, 一行为一个instance(sample) | |||
words pos_tags heads labels | |||
['some', ..] ['NN', ...] [2, 3...] ['nn', 'nn'...] | |||
:return: DataSet。 | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
@@ -860,7 +255,6 @@ class SSTLoader(DataSetLoader): | |||
@staticmethod | |||
def get_one(data, subtree): | |||
from nltk.tree import Tree | |||
tree = Tree.fromstring(data) | |||
if subtree: | |||
return [(t.leaves(), t.label()) for t in tree.subtrees()] | |||
@@ -872,26 +266,72 @@ class JsonLoader(DataSetLoader): | |||
every line contains a json obj, like a dict | |||
fields is the dict key that need to be load | |||
""" | |||
def __init__(self, **fields): | |||
def __init__(self, dropna=False, fields=None): | |||
super(JsonLoader, self).__init__() | |||
self.fields = {} | |||
for k, v in fields.items(): | |||
self.fields[k] = k if v is None else v | |||
self.dropna = dropna | |||
self.fields = None | |||
self.fields_list = None | |||
if fields: | |||
self.fields = {} | |||
for k, v in fields.items(): | |||
self.fields[k] = k if v is None else v | |||
self.fields_list = list(self.fields.keys()) | |||
def load(self, path): | |||
ds = DataSet() | |||
for idx, d in read_json(path, fields=self.fields_list, dropna=self.dropna): | |||
ins = {self.fields[k]:v for k,v in d.items()} | |||
ds.append(Instance(**ins)) | |||
return ds | |||
class SNLILoader(JsonLoader): | |||
""" | |||
data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||
""" | |||
def __init__(self): | |||
fields = { | |||
'sentence1_parse': 'words1', | |||
'sentence2_parse': 'words2', | |||
'gold_label': 'target', | |||
} | |||
super(SNLILoader, self).__init__(fields=fields) | |||
def load(self, path): | |||
ds = super(SNLILoader, self).load(path) | |||
def parse_tree(x): | |||
t = Tree.fromstring(x) | |||
return t.leaves() | |||
ds.apply(lambda ins: parse_tree(ins['words1']), new_field_name='words1') | |||
ds.apply(lambda ins: parse_tree(ins['words2']), new_field_name='words2') | |||
ds.drop(lambda x: x['target'] == '-') | |||
return ds | |||
class CSVLoader(DataSetLoader): | |||
"""Load data from a CSV file and return a DataSet object. | |||
:param str csv_path: path to the CSV file | |||
:param List[str] or Tuple[str] headers: headers of the CSV file | |||
:param str sep: delimiter in CSV file. Default: "," | |||
:param bool dropna: If True, drop rows that have less entries than headers. | |||
:return dataset: the read data set | |||
""" | |||
def __init__(self, headers=None, sep=",", dropna=True): | |||
self.headers = headers | |||
self.sep = sep | |||
self.dropna = dropna | |||
def load(self, path): | |||
with open(path, 'r', encoding='utf-8') as f: | |||
datas = [json.loads(l) for l in f] | |||
ds = DataSet() | |||
for d in datas: | |||
ins = Instance() | |||
for k, v in d.items(): | |||
if k in self.fields: | |||
ins.add_field(self.fields[k], v) | |||
ds.append(ins) | |||
for idx, data in read_csv(path, headers=self.headers, | |||
sep=self.sep, dropna=self.dropna): | |||
ds.append(Instance(**data)) | |||
return ds | |||
def add_seg_tag(data): | |||
def _add_seg_tag(data): | |||
""" | |||
:param data: list of ([word], [pos], [heads], [head_tags]) | |||
@@ -132,7 +132,7 @@ class EmbedLoader(BaseLoader): | |||
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'): | |||
""" | |||
load pretraining embedding in {embed_file} based on words in vocab. Words in vocab but not in the pretraining | |||
embedding are initialized from a normal distribution which has the mean and std of the found words vectors. | |||
embedding are initialized from a normal distribution which has the mean and std of the found words vectors. | |||
The embedding type is determined automatically, support glove and word2vec(the first line only has two elements). | |||
:param embed_filepath: str, where to read pretrain embedding | |||
@@ -0,0 +1,112 @@ | |||
import json | |||
def read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | |||
""" | |||
Construct a generator to read csv items | |||
:param path: file path | |||
:param encoding: file's encoding, default: utf-8 | |||
:param headers: file's headers, if None, make file's first line as headers. default: None | |||
:param sep: separator for each column. default: ',' | |||
:param dropna: weather to ignore and drop invalid data, | |||
if False, raise ValueError when reading invalid data. default: True | |||
:return: generator, every time yield (line number, csv item) | |||
""" | |||
with open(path, 'r', encoding=encoding) as f: | |||
start_idx = 0 | |||
if headers is None: | |||
headers = f.readline().rstrip('\r\n') | |||
headers = headers.split(sep) | |||
start_idx += 1 | |||
elif not isinstance(headers, (list, tuple)): | |||
raise TypeError("headers should be list or tuple, not {}." \ | |||
.format(type(headers))) | |||
for line_idx, line in enumerate(f, start_idx): | |||
contents = line.rstrip('\r\n').split(sep) | |||
if len(contents) != len(headers): | |||
if dropna: | |||
continue | |||
else: | |||
raise ValueError("Line {} has {} parts, while header has {} parts." \ | |||
.format(line_idx, len(contents), len(headers))) | |||
_dict = {} | |||
for header, content in zip(headers, contents): | |||
_dict[header] = content | |||
yield line_idx, _dict | |||
def read_json(path, encoding='utf-8', fields=None, dropna=True): | |||
""" | |||
Construct a generator to read json items | |||
:param path: file path | |||
:param encoding: file's encoding, default: utf-8 | |||
:param fields: json object's fields that needed, if None, all fields are needed. default: None | |||
:param dropna: weather to ignore and drop invalid data, | |||
if False, raise ValueError when reading invalid data. default: True | |||
:return: generator, every time yield (line number, json item) | |||
""" | |||
if fields: | |||
fields = set(fields) | |||
with open(path, 'r', encoding=encoding) as f: | |||
for line_idx, line in enumerate(f): | |||
data = json.loads(line) | |||
if fields is None: | |||
yield line_idx, data | |||
continue | |||
_res = {} | |||
for k, v in data.items(): | |||
if k in fields: | |||
_res[k] = v | |||
if len(_res) < len(fields): | |||
if dropna: | |||
continue | |||
else: | |||
raise ValueError('invalid instance at line: {}'.format(line_idx)) | |||
yield line_idx, _res | |||
def read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||
""" | |||
Construct a generator to read conll items | |||
:param path: file path | |||
:param encoding: file's encoding, default: utf-8 | |||
:param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None | |||
:param dropna: weather to ignore and drop invalid data, | |||
if False, raise ValueError when reading invalid data. default: True | |||
:return: generator, every time yield (line number, conll item) | |||
""" | |||
def parse_conll(sample): | |||
sample = list(map(list, zip(*sample))) | |||
sample = [sample[i] for i in indexes] | |||
for f in sample: | |||
if len(f) <= 0: | |||
raise ValueError('empty field') | |||
return sample | |||
with open(path, 'r', encoding=encoding) as f: | |||
sample = [] | |||
start = next(f) | |||
if '-DOCSTART-' not in start: | |||
sample.append(start.split()) | |||
for line_idx, line in enumerate(f, 1): | |||
if line.startswith('\n'): | |||
if len(sample): | |||
try: | |||
res = parse_conll(sample) | |||
sample = [] | |||
yield line_idx, res | |||
except Exception as e: | |||
if dropna: | |||
continue | |||
raise ValueError('invalid instance at line: {}'.format(line_idx)) | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
sample.append(line.split()) | |||
if len(sample) > 0: | |||
try: | |||
res = parse_conll(sample) | |||
yield line_idx, res | |||
except Exception as e: | |||
if dropna: | |||
return | |||
raise ValueError('invalid instance at line: {}'.format(line_idx)) |
@@ -31,16 +31,18 @@ class ModelLoader(BaseLoader): | |||
class ModelSaver(object): | |||
"""Save a model | |||
Example:: | |||
:param str save_path: the path to the saving directory. | |||
Example:: | |||
saver = ModelSaver("./save/model_ckpt_100.pkl") | |||
saver.save_pytorch(model) | |||
saver = ModelSaver("./save/model_ckpt_100.pkl") | |||
saver.save_pytorch(model) | |||
""" | |||
def __init__(self, save_path): | |||
""" | |||
:param save_path: the path to the saving directory. | |||
""" | |||
self.save_path = save_path | |||
def save_pytorch(self, model, param_only=True): | |||
@@ -20,16 +20,23 @@ class Highway(nn.Module): | |||
class CharLM(nn.Module): | |||
"""CNN + highway network + LSTM | |||
# Input: | |||
# Input:: | |||
4D tensor with shape [batch_size, in_channel, height, width] | |||
# Output: | |||
# Output:: | |||
2D Tensor with shape [batch_size, vocab_size] | |||
# Arguments: | |||
# Arguments:: | |||
char_emb_dim: the size of each character's attention | |||
word_emb_dim: the size of each word's attention | |||
vocab_size: num of unique words | |||
num_char: num of characters | |||
use_gpu: True or False | |||
""" | |||
def __init__(self, char_emb_dim, word_emb_dim, | |||
@@ -65,13 +65,14 @@ class ENASTrainer(fastNLP.Trainer): | |||
""" | |||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||
最好的模型参数。 | |||
:return results: 返回一个字典类型的数据, 内含以下内容:: | |||
seconds: float, 表示训练时长 | |||
以下三个内容只有在提供了dev_data的情况下会有。 | |||
best_eval: Dict of Dict, 表示evaluation的结果 | |||
best_epoch: int,在第几个epoch取得的最佳值 | |||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
:return results: 返回一个字典类型的数据, | |||
内含以下内容:: | |||
seconds: float, 表示训练时长 | |||
以下三个内容只有在提供了dev_data的情况下会有。 | |||
best_eval: Dict of Dict, 表示evaluation的结果 | |||
best_epoch: int,在第几个epoch取得的最佳值 | |||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||
""" | |||
results = {} | |||
@@ -79,7 +79,7 @@ class SeqLabeling(BaseModel): | |||
:return prediction: list of [decode path(list)] | |||
""" | |||
max_len = x.shape[1] | |||
tag_seq = self.Crf.viterbi_decode(x, self.mask) | |||
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask) | |||
# pad prediction to equal length | |||
if pad is True: | |||
for pred in tag_seq: | |||
@@ -1,6 +1,5 @@ | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.modules import decoder as Decoder | |||
@@ -40,7 +39,7 @@ class ESIM(BaseModel): | |||
batch_first=self.batch_first, bidirectional=True | |||
) | |||
self.bi_attention = Aggregator.Bi_Attention() | |||
self.bi_attention = Aggregator.BiAttention() | |||
self.mean_pooling = Aggregator.MeanPoolWithMask() | |||
self.max_pooling = Aggregator.MaxPoolWithMask() | |||
@@ -53,23 +52,23 @@ class ESIM(BaseModel): | |||
self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout) | |||
def forward(self, premise, hypothesis, premise_len, hypothesis_len): | |||
def forward(self, words1, words2, seq_len1, seq_len2): | |||
""" Forward function | |||
:param premise: A Tensor represents premise: [batch size(B), premise seq len(PL)]. | |||
:param hypothesis: A Tensor represents hypothesis: [B, hypothesis seq len(HL)]. | |||
:param premise_len: A Tensor record which is a real word and which is a padding word in premise: [B, PL]. | |||
:param hypothesis_len: A Tensor record which is a real word and which is a padding word in hypothesis: [B, HL]. | |||
:param words1: A Tensor represents premise: [batch size(B), premise seq len(PL)]. | |||
:param words2: A Tensor represents hypothesis: [B, hypothesis seq len(HL)]. | |||
:param seq_len1: A Tensor record which is a real word and which is a padding word in premise: [B]. | |||
:param seq_len2: A Tensor record which is a real word and which is a padding word in hypothesis: [B]. | |||
:return: prediction: A Dict with Tensor of classification result: [B, n_labels(N)]. | |||
""" | |||
premise0 = self.embedding_layer(self.embedding(premise)) | |||
hypothesis0 = self.embedding_layer(self.embedding(hypothesis)) | |||
premise0 = self.embedding_layer(self.embedding(words1)) | |||
hypothesis0 = self.embedding_layer(self.embedding(words2)) | |||
_BP, _PSL, _HP = premise0.size() | |||
_BH, _HSL, _HH = hypothesis0.size() | |||
_BPL, _PLL = premise_len.size() | |||
_HPL, _HLL = hypothesis_len.size() | |||
_BPL, _PLL = seq_len1.size() | |||
_HPL, _HLL = seq_len2.size() | |||
assert _BP == _BH and _BPL == _HPL and _BP == _BPL | |||
assert _HP == _HH | |||
@@ -84,7 +83,7 @@ class ESIM(BaseModel): | |||
a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H] | |||
b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H] | |||
ai, bi = self.bi_attention(a, b, premise_len, hypothesis_len) | |||
ai, bi = self.bi_attention(a, b, seq_len1, seq_len2) | |||
ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H] | |||
mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H] | |||
@@ -98,17 +97,18 @@ class ESIM(BaseModel): | |||
va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H] | |||
vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H] | |||
va_ave = self.mean_pooling(va, premise_len, dim=1) # va_ave: [B, H] | |||
va_max, va_arg_max = self.max_pooling(va, premise_len, dim=1) # va_max: [B, H] | |||
vb_ave = self.mean_pooling(vb, hypothesis_len, dim=1) # vb_ave: [B, H] | |||
vb_max, vb_arg_max = self.max_pooling(vb, hypothesis_len, dim=1) # vb_max: [B, H] | |||
va_ave = self.mean_pooling(va, seq_len1, dim=1) # va_ave: [B, H] | |||
va_max, va_arg_max = self.max_pooling(va, seq_len1, dim=1) # va_max: [B, H] | |||
vb_ave = self.mean_pooling(vb, seq_len2, dim=1) # vb_ave: [B, H] | |||
vb_max, vb_arg_max = self.max_pooling(vb, seq_len2, dim=1) # vb_max: [B, H] | |||
v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H] | |||
prediction = F.tanh(self.output(v)) # prediction: [B, N] | |||
prediction = torch.tanh(self.output(v)) # prediction: [B, N] | |||
return {'pred': prediction} | |||
def predict(self, premise, hypothesis, premise_len, hypothesis_len): | |||
return self.forward(premise, hypothesis, premise_len, hypothesis_len) | |||
def predict(self, words1, words2, seq_len1, seq_len2): | |||
prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred'] | |||
return {'pred': torch.argmax(prediction, dim=-1)} | |||
@@ -5,6 +5,6 @@ from .avg_pool import MeanPoolWithMask | |||
from .kmax_pool import KMaxPool | |||
from .attention import Attention | |||
from .attention import Bi_Attention | |||
from .attention import BiAttention | |||
from .self_attention import SelfAttention | |||
@@ -23,9 +23,9 @@ class Attention(torch.nn.Module): | |||
raise NotImplementedError | |||
class DotAtte(nn.Module): | |||
class DotAttention(nn.Module): | |||
def __init__(self, key_size, value_size, dropout=0.1): | |||
super(DotAtte, self).__init__() | |||
super(DotAttention, self).__init__() | |||
self.key_size = key_size | |||
self.value_size = value_size | |||
self.scale = math.sqrt(key_size) | |||
@@ -48,7 +48,7 @@ class DotAtte(nn.Module): | |||
return torch.matmul(output, V) | |||
class MultiHeadAtte(nn.Module): | |||
class MultiHeadAttention(nn.Module): | |||
def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1): | |||
""" | |||
@@ -58,7 +58,7 @@ class MultiHeadAtte(nn.Module): | |||
:param num_head: int,head的数量。 | |||
:param dropout: float。 | |||
""" | |||
super(MultiHeadAtte, self).__init__() | |||
super(MultiHeadAttention, self).__init__() | |||
self.input_size = input_size | |||
self.key_size = key_size | |||
self.value_size = value_size | |||
@@ -68,7 +68,7 @@ class MultiHeadAtte(nn.Module): | |||
self.q_in = nn.Linear(input_size, in_size) | |||
self.k_in = nn.Linear(input_size, in_size) | |||
self.v_in = nn.Linear(input_size, in_size) | |||
self.attention = DotAtte(key_size=key_size, value_size=value_size) | |||
self.attention = DotAttention(key_size=key_size, value_size=value_size) | |||
self.out = nn.Linear(value_size * num_head, input_size) | |||
self.drop = TimestepDropout(dropout) | |||
self.reset_parameters() | |||
@@ -109,16 +109,34 @@ class MultiHeadAtte(nn.Module): | |||
return output | |||
class Bi_Attention(nn.Module): | |||
class BiAttention(nn.Module): | |||
"""Bi Attention module | |||
Calculate Bi Attention matrix `e` | |||
.. math:: | |||
\begin{array}{ll} \\ | |||
e_ij = {a}^{\mathbf{T}}_{i}{b}_{j} \\ | |||
a_i = | |||
b_j = | |||
\end{array} | |||
""" | |||
def __init__(self): | |||
super(Bi_Attention, self).__init__() | |||
super(BiAttention, self).__init__() | |||
self.inf = 10e12 | |||
def forward(self, in_x1, in_x2, x1_len, x2_len): | |||
# in_x1: [batch_size, x1_seq_len, hidden_size] | |||
# in_x2: [batch_size, x2_seq_len, hidden_size] | |||
# x1_len: [batch_size, x1_seq_len] | |||
# x2_len: [batch_size, x2_seq_len] | |||
""" | |||
:param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] 第一句的特征表示 | |||
:param torch.Tensor in_x2: [batch_size, x2_seq_len, hidden_size] 第二句的特征表示 | |||
:param torch.Tensor x1_len: [batch_size, x1_seq_len] 第一句的0/1mask矩阵 | |||
:param torch.Tensor x2_len: [batch_size, x2_seq_len] 第二句的0/1mask矩阵 | |||
:return: torch.Tensor out_x1: [batch_size, x1_seq_len, hidden_size] 第一句attend到的特征表示 | |||
torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] 第一句attend到的特征表示 | |||
""" | |||
assert in_x1.size()[0] == in_x2.size()[0] | |||
assert in_x1.size()[2] == in_x2.size()[2] | |||
@@ -2,12 +2,7 @@ import torch | |||
from torch import nn | |||
from fastNLP.modules.utils import initial_parameter | |||
def log_sum_exp(x, dim=-1): | |||
max_value, _ = x.max(dim=dim, keepdim=True) | |||
res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value | |||
return res.squeeze(dim) | |||
from fastNLP.modules.decoder.utils import log_sum_exp | |||
def seq_len_to_byte_mask(seq_lens): | |||
@@ -20,22 +15,27 @@ def seq_len_to_byte_mask(seq_lens): | |||
return mask | |||
def allowed_transitions(id2label, encoding_type='bio'): | |||
def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): | |||
""" | |||
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | |||
:param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 | |||
:param id2label: Dict, key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()得到id2label。 | |||
:param encoding_type: str, 支持"bio", "bmes", "bmeso"。 | |||
:return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 | |||
位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). | |||
start_idx=len(id2label), end_idx=len(id2label)+1。 | |||
:param include_start_end: bool, 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | |||
为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | |||
start_idx=len(id2label), end_idx=len(id2label)+1。 | |||
为False, 返回的结果中不含与开始结尾相关的内容 | |||
:return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 | |||
""" | |||
num_tags = len(id2label) | |||
start_idx = num_tags | |||
end_idx = num_tags + 1 | |||
encoding_type = encoding_type.lower() | |||
allowed_trans = [] | |||
id_label_lst = list(id2label.items()) + [(start_idx, 'start'), (end_idx, 'end')] | |||
id_label_lst = list(id2label.items()) | |||
if include_start_end: | |||
id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] | |||
def split_tag_label(from_label): | |||
from_label = from_label.lower() | |||
if from_label in ['start', 'end']: | |||
@@ -54,12 +54,12 @@ def allowed_transitions(id2label, encoding_type='bio'): | |||
if to_label in ['<pad>', '<unk>']: | |||
continue | |||
to_tag, to_label = split_tag_label(to_label) | |||
if is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
allowed_trans.append((from_id, to_id)) | |||
return allowed_trans | |||
def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
""" | |||
:param encoding_type: str, 支持"BIO", "BMES", "BEMSO"。 | |||
@@ -140,20 +140,22 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label) | |||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | |||
else: | |||
raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) | |||
raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type)) | |||
class ConditionalRandomField(nn.Module): | |||
""" | |||
:param int num_tags: 标签的数量。 | |||
:param bool include_start_end_trans: 是否包含起始tag | |||
:param list allowed_transitions: ``List[Tuple[from_tag_id(int), to_tag_id(int)]]``. 允许的跃迁,可以通过allowed_transitions()得到。 | |||
如果为None,则所有跃迁均为合法 | |||
:param str initial_method: | |||
""" | |||
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None): | |||
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, | |||
initial_method=None): | |||
"""条件随机场。 | |||
提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 | |||
:param num_tags: int, 标签的数量 | |||
:param include_start_end_trans: bool, 是否考虑各个tag作为开始以及结尾的分数。 | |||
:param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]], 内部的Tuple[from_tag_id(int), | |||
to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过 | |||
allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 | |||
:param initial_method: str, 初始化方法。见initial_parameter | |||
""" | |||
super(ConditionalRandomField, self).__init__() | |||
self.include_start_end_trans = include_start_end_trans | |||
@@ -168,18 +170,12 @@ class ConditionalRandomField(nn.Module): | |||
if allowed_transitions is None: | |||
constrain = torch.zeros(num_tags + 2, num_tags + 2) | |||
else: | |||
constrain = torch.ones(num_tags + 2, num_tags + 2) * -1000 | |||
constrain = torch.new_full((num_tags+2, num_tags+2), fill_value=-10000.0, dtype=torch.float) | |||
for from_tag_id, to_tag_id in allowed_transitions: | |||
constrain[from_tag_id, to_tag_id] = 0 | |||
self._constrain = nn.Parameter(constrain, requires_grad=False) | |||
# self.reset_parameter() | |||
initial_parameter(self, initial_method) | |||
def reset_parameter(self): | |||
nn.init.xavier_normal_(self.trans_m) | |||
if self.include_start_end_trans: | |||
nn.init.normal_(self.start_scores) | |||
nn.init.normal_(self.end_scores) | |||
def _normalizer_likelihood(self, logits, mask): | |||
"""Computes the (batch_size,) denominator term for the log-likelihood, which is the | |||
@@ -239,10 +235,11 @@ class ConditionalRandomField(nn.Module): | |||
def forward(self, feats, tags, mask): | |||
""" | |||
Calculate the neg log likelihood | |||
:param feats:FloatTensor, batch_size x max_len x num_tags | |||
:param tags:LongTensor, batch_size x max_len | |||
:param mask:ByteTensor batch_size x max_len | |||
用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 | |||
:param feats:FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||
:param tags:LongTensor, batch_size x max_len,标签矩阵。 | |||
:param mask:ByteTensor batch_size x max_len,为0的位置认为是padding。 | |||
:return:FloatTensor, batch_size | |||
""" | |||
feats = feats.transpose(0, 1) | |||
@@ -253,28 +250,27 @@ class ConditionalRandomField(nn.Module): | |||
return all_path_score - gold_path_score | |||
def viterbi_decode(self, data, mask, get_score=False, unpad=False): | |||
"""Given a feats matrix, return best decode path and best score. | |||
def viterbi_decode(self, feats, mask, unpad=False): | |||
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||
:param data:FloatTensor, batch_size x max_len x num_tags | |||
:param mask:ByteTensor batch_size x max_len | |||
:param get_score: bool, whether to output the decode score. | |||
:param unpad: bool, 是否将结果unpad, | |||
如果False, 返回的是batch_size x max_len的tensor, | |||
如果True,返回的是List[List[int]], List[int]为每个sequence的label,已经unpadding了,即每个 | |||
List[int]的长度是这个sample的有效长度 | |||
:return: 如果get_score为False,返回结果根据unpadding变动 | |||
如果get_score为True, 返回 (paths, List[float], )。第一个仍然是解码后的路径(根据unpad变化),第二个List[Float] | |||
为每个seqence的解码分数。 | |||
:param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
:param unpad: bool, 是否将结果删去padding, | |||
False, 返回的是batch_size x max_len的tensor, | |||
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int] | |||
的长度是这个sample的有效长度。 | |||
:return: 返回 (paths, scores)。 | |||
paths: 是解码后的路径, 其值参照unpad参数. | |||
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||
""" | |||
batch_size, seq_len, n_tags = data.size() | |||
data = data.transpose(0, 1).data # L, B, H | |||
batch_size, seq_len, n_tags = feats.size() | |||
feats = feats.transpose(0, 1).data # L, B, H | |||
mask = mask.transpose(0, 1).data.byte() # L, B | |||
# dp | |||
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
vscore = data[0] | |||
vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
vscore = feats[0] | |||
transitions = self._constrain.data.clone() | |||
transitions[:n_tags, :n_tags] += self.trans_m.data | |||
if self.include_start_end_trans: | |||
@@ -285,23 +281,24 @@ class ConditionalRandomField(nn.Module): | |||
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||
for i in range(1, seq_len): | |||
prev_score = vscore.view(batch_size, n_tags, 1) | |||
cur_score = data[i].view(batch_size, 1, n_tags) | |||
cur_score = feats[i].view(batch_size, 1, n_tags) | |||
score = prev_score + trans_score + cur_score | |||
best_score, best_dst = score.max(1) | |||
vpath[i] = best_dst | |||
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | |||
vscore.masked_fill(mask[i].view(batch_size, 1), 0) | |||
vscore += transitions[:n_tags, n_tags+1].view(1, -1) | |||
if self.include_start_end_trans: | |||
vscore += transitions[:n_tags, n_tags+1].view(1, -1) | |||
# backtrace | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device) | |||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device) | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device) | |||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device) | |||
lens = (mask.long().sum(0) - 1) | |||
# idxes [L, B], batched idx from seq_len-1 to 0 | |||
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len | |||
ans = data.new_empty((seq_len, batch_size), dtype=torch.long) | |||
ans = feats.new_empty((seq_len, batch_size), dtype=torch.long) | |||
ans_score, last_tags = vscore.max(1) | |||
ans[idxes[0], batch_idx] = last_tags | |||
for i in range(seq_len - 1): | |||
@@ -0,0 +1,70 @@ | |||
import torch | |||
def log_sum_exp(x, dim=-1): | |||
max_value, _ = x.max(dim=dim, keepdim=True) | |||
res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value | |||
return res.squeeze(dim) | |||
def viterbi_decode(feats, transitions, mask=None, unpad=False): | |||
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||
:param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||
:param transitions: FloatTensor, n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。 | |||
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
:param unpad: bool, 是否将结果删去padding, | |||
False, 返回的是batch_size x max_len的tensor, | |||
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是 | |||
这个sample的有效长度。 | |||
:return: 返回 (paths, scores)。 | |||
paths: 是解码后的路径, 其值参照unpad参数. | |||
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||
""" | |||
batch_size, seq_len, n_tags = feats.size() | |||
assert n_tags==transitions.size(0) and n_tags==transitions.size(1), "The shapes of transitions and feats are not " \ | |||
"compatible." | |||
feats = feats.transpose(0, 1).data # L, B, H | |||
if mask is not None: | |||
mask = mask.transpose(0, 1).data.byte() # L, B | |||
else: | |||
mask = feats.new_ones((seq_len, batch_size), dtype=torch.uint8) | |||
# dp | |||
vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
vscore = feats[0] | |||
vscore += transitions[n_tags, :n_tags] | |||
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||
for i in range(1, seq_len): | |||
prev_score = vscore.view(batch_size, n_tags, 1) | |||
cur_score = feats[i].view(batch_size, 1, n_tags) | |||
score = prev_score + trans_score + cur_score | |||
best_score, best_dst = score.max(1) | |||
vpath[i] = best_dst | |||
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | |||
vscore.masked_fill(mask[i].view(batch_size, 1), 0) | |||
# backtrace | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device) | |||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device) | |||
lens = (mask.long().sum(0) - 1) | |||
# idxes [L, B], batched idx from seq_len-1 to 0 | |||
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len | |||
ans = feats.new_empty((seq_len, batch_size), dtype=torch.long) | |||
ans_score, last_tags = vscore.max(1) | |||
ans[idxes[0], batch_idx] = last_tags | |||
for i in range(seq_len - 1): | |||
last_tags = vpath[idxes[i], batch_idx, last_tags] | |||
ans[idxes[i + 1], batch_idx] = last_tags | |||
ans = ans.transpose(0, 1) | |||
if unpad: | |||
paths = [] | |||
for idx, seq_len in enumerate(lens): | |||
paths.append(ans[idx, :seq_len + 1].tolist()) | |||
else: | |||
paths = ans | |||
return paths, ans_score |
@@ -1,4 +1,6 @@ | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.utils.rnn as rnn | |||
from fastNLP.modules.utils import initial_parameter | |||
@@ -19,21 +21,44 @@ class LSTM(nn.Module): | |||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | |||
bidirectional=False, bias=True, initial_method=None, get_hidden=False): | |||
super(LSTM, self).__init__() | |||
self.batch_first = batch_first | |||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | |||
dropout=dropout, bidirectional=bidirectional) | |||
self.get_hidden = get_hidden | |||
initial_parameter(self, initial_method) | |||
def forward(self, x, h0=None, c0=None): | |||
def forward(self, x, seq_lens=None, h0=None, c0=None): | |||
if h0 is not None and c0 is not None: | |||
x, (ht, ct) = self.lstm(x, (h0, c0)) | |||
hx = (h0, c0) | |||
else: | |||
x, (ht, ct) = self.lstm(x) | |||
if self.get_hidden: | |||
return x, (ht, ct) | |||
hx = None | |||
if seq_lens is not None and not isinstance(x, rnn.PackedSequence): | |||
print('padding') | |||
sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) | |||
if self.batch_first: | |||
x = x[sort_idx] | |||
else: | |||
x = x[:, sort_idx] | |||
x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first) | |||
output, hx = self.lstm(x, hx) # -> [N,L,C] | |||
output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first) | |||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||
if self.batch_first: | |||
output = output[unsort_idx] | |||
else: | |||
output = output[:, unsort_idx] | |||
else: | |||
return x | |||
output, hx = self.lstm(x, hx) | |||
if self.get_hidden: | |||
return output, hx | |||
return output | |||
if __name__ == "__main__": | |||
lstm = LSTM(10) | |||
lstm = LSTM(input_size=2, hidden_size=2, get_hidden=False) | |||
x = torch.randn((3, 5, 2)) | |||
seq_lens = torch.tensor([5,1,2]) | |||
y = lstm(x, seq_lens) | |||
print(x) | |||
print(y) | |||
print(x.size(), y.size(), ) |
@@ -1,6 +1,6 @@ | |||
from torch import nn | |||
from ..aggregator.attention import MultiHeadAtte | |||
from ..aggregator.attention import MultiHeadAttention | |||
from ..dropout import TimestepDropout | |||
@@ -18,7 +18,7 @@ class TransformerEncoder(nn.Module): | |||
class SubLayer(nn.Module): | |||
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | |||
super(TransformerEncoder.SubLayer, self).__init__() | |||
self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout) | |||
self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout) | |||
self.norm1 = nn.LayerNorm(model_size) | |||
self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), | |||
nn.ReLU(), | |||
@@ -183,7 +183,7 @@ class CWSBiLSTMCRF(BaseModel): | |||
masks = seq_lens_to_mask(seq_lens) | |||
feats = self.encoder_model(chars, bigrams, seq_lens) | |||
feats = self.decoder_model(feats) | |||
probs = self.crf.viterbi_decode(feats, masks, get_score=False) | |||
paths, _ = self.crf.viterbi_decode(feats, masks) | |||
return {'pred': probs, 'seq_lens':seq_lens} | |||
return {'pred': paths, 'seq_lens':seq_lens} | |||
@@ -145,9 +145,9 @@ class TransformerDilatedCWS(nn.Module): | |||
feats = self.transformer(x, masks) | |||
feats = self.fc2(feats) | |||
probs = self.crf.viterbi_decode(feats, masks, get_score=False) | |||
paths, _ = self.crf.viterbi_decode(feats, masks) | |||
return {'pred': probs, 'seq_lens':seq_lens} | |||
return {'pred': paths, 'seq_lens':seq_lens} | |||
@@ -163,6 +163,11 @@ class TestDataSetMethods(unittest.TestCase): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target]) | |||
def test_split(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
d1, d2 = ds.split(0.1) | |||
def test_apply2(self): | |||
def split_sent(ins): | |||
return ins['raw_sentence'].split() | |||
@@ -202,20 +207,6 @@ class TestDataSetMethods(unittest.TestCase): | |||
self.assertTrue(isinstance(ans, FieldArray)) | |||
self.assertEqual(ans.content, [[5, 6]] * 10) | |||
def test_reader(self): | |||
# 跑通即可 | |||
ds = DataSet().read_naive("test/data_for_tests/tutorial_sample_dataset.csv") | |||
self.assertTrue(isinstance(ds, DataSet)) | |||
self.assertTrue(len(ds) > 0) | |||
ds = DataSet().read_rawdata("test/data_for_tests/people_daily_raw.txt") | |||
self.assertTrue(isinstance(ds, DataSet)) | |||
self.assertTrue(len(ds) > 0) | |||
ds = DataSet().read_pos("test/data_for_tests/people.txt") | |||
self.assertTrue(isinstance(ds, DataSet)) | |||
self.assertTrue(len(ds) > 0) | |||
def test_add_null(self): | |||
# TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError' | |||
ds = DataSet() | |||
@@ -0,0 +1,3 @@ | |||
{"annotator_labels": ["neutral"], "captionID": "3416050480.jpg#4", "gold_label": "neutral", "pairID": "3416050480.jpg#4r1n", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is training his horse for a competition.", "sentence2_binary_parse": "( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (VP (VBG training) (NP (PRP$ his) (NN horse)) (PP (IN for) (NP (DT a) (NN competition))))) (. .)))"} | |||
{"annotator_labels": ["contradiction"], "captionID": "3416050480.jpg#4", "gold_label": "contradiction", "pairID": "3416050480.jpg#4r1c", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is at a diner, ordering an omelette.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is ( at ( a diner ) ) ) , ) ( ordering ( an omelette ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (PP (IN at) (NP (DT a) (NN diner))) (, ,) (S (VP (VBG ordering) (NP (DT an) (NN omelette))))) (. .)))"} | |||
{"annotator_labels": ["entailment"], "captionID": "3416050480.jpg#4", "gold_label": "entailment", "pairID": "3416050480.jpg#4r1e", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is outdoors, on a horse.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is outdoors ) , ) ( on ( a horse ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (ADVP (RB outdoors)) (, ,) (PP (IN on) (NP (DT a) (NN horse)))) (. .)))"} |
@@ -1,8 +1,7 @@ | |||
import unittest | |||
from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, ConllCWSReader, \ | |||
ZhConllPOSReader, ConllxDataLoader | |||
from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, \ | |||
CSVLoader, SNLILoader | |||
class TestDatasetLoader(unittest.TestCase): | |||
@@ -17,3 +16,11 @@ class TestDatasetLoader(unittest.TestCase): | |||
def test_PeopleDailyCorpusLoader(self): | |||
data_set = PeopleDailyCorpusLoader().load("test/data_for_tests/people_daily_raw.txt") | |||
def test_CSVLoader(self): | |||
ds = CSVLoader(sep='\t', headers=['words', 'label'])\ | |||
.load('test/data_for_tests/tutorial_sample_dataset.csv') | |||
assert len(ds) > 0 | |||
def test_SNLILoader(self): | |||
ds = SNLILoader().load('test/data_for_tests/sample_snli.jsonl') | |||
assert len(ds) == 3 |
@@ -1,9 +0,0 @@ | |||
import unittest | |||
class TestUtils(unittest.TestCase): | |||
def test_case_1(self): | |||
pass | |||
def test_case_2(self): | |||
pass |
@@ -379,6 +379,14 @@ class TestTutorial(unittest.TestCase): | |||
dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') | |||
train_data_2[-1], dev_data_2[-1] | |||
for data in [train_data, dev_data, test_data]: | |||
data.rename_field('premise', 'words1') | |||
data.rename_field('hypothesis', 'words2') | |||
data.rename_field('premise_len', 'seq_len1') | |||
data.rename_field('hypothesis_len', 'seq_len2') | |||
data.set_input('words1', 'words2', 'seq_len1', 'seq_len2') | |||
# step 1:加载模型参数(非必选) | |||
from fastNLP.io.config_io import ConfigSection, ConfigLoader | |||
args = ConfigSection() | |||