Browse Source

Merge branch 'dev'

# Conflicts:
#	setup.py
tags/v1.0.0alpha
WillQvQ 2 years ago
parent
commit
fb645d370f
14 changed files with 325 additions and 43 deletions
  1. +15
    -15
      README.md
  2. +4
    -8
      docs/requirements.txt
  3. +5
    -1
      docs/source/index.rst
  4. +3
    -3
      docs/source/tutorials/extend_3_fitlog.rst
  5. +2
    -2
      docs/source/tutorials/文本分类.rst
  6. +1
    -1
      fastNLP/core/dataset.py
  7. +1
    -1
      fastNLP/core/tester.py
  8. +5
    -0
      fastNLP/io/file_utils.py
  9. +14
    -6
      fastNLP/io/pipe/__init__.py
  10. +268
    -0
      fastNLP/io/pipe/construct_graph.py
  11. +1
    -1
      fastNLP/models/biaffine_parser.py
  12. +2
    -2
      fastNLP/modules/encoder/_elmo.py
  13. +3
    -2
      readthedocs.yml
  14. +1
    -1
      setup.py

+ 15
- 15
README.md View File

@@ -39,30 +39,30 @@ python -m spacy download en


## fastNLP教程
中文[文档](https://fastnlp.readthedocs.io/)、[教程](https://fastnlp.readthedocs.io/zh/latest/user/tutorials.html)
中文[文档](http://www.fastnlp.top/docs/fastNLP/)、 [教程](http://www.fastnlp.top/docs/fastNLP/user/quickstart.html)

### 快速入门

- [0. 快速入门](https://fastnlp.readthedocs.io/zh/latest/user/quickstart.html)
- [Quick-1. 文本分类](http://www.fastnlp.top/docs/fastNLP/tutorials/%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB.html)
- [Quick-2. 序列标注](http://www.fastnlp.top/docs/fastNLP/tutorials/%E5%BA%8F%E5%88%97%E6%A0%87%E6%B3%A8.html)

### 详细使用教程

- [1. 使用DataSet预处理文本](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_1_data_preprocess.html)
- [2. 使用Vocabulary转换文本与index](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_vocabulary.html)
- [3. 使用Embedding模块将文本转成向量](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html)
- [4. 使用Loader和Pipe加载并处理数据集](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_load_dataset.html)
- [5. 动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_5_loss_optimizer.html)
- [6. 动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_6_datasetiter.html)
- [7. 使用Metric快速评测你的模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_7_metrics.html)
- [8. 使用Modules和Models快速搭建自定义模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_8_modules_models.html)
- [9. 快速实现序列标注模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_9_seq_labeling.html)
- [10. 使用Callback自定义你的训练过程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_10_callback.html)
- [1. 使用DataSet预处理文本](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_1_data_preprocess.html)
- [2. 使用Vocabulary转换文本与index](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_2_vocabulary.html)
- [3. 使用Embedding模块将文本转成向量](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_3_embedding.html)
- [4. 使用Loader和Pipe加载并处理数据集](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_4_load_dataset.html)
- [5. 动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_5_loss_optimizer.html)
- [6. 动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_6_datasetiter.html)
- [7. 使用Metric快速评测你的模型](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_7_metrics.html)
- [8. 使用Modules和Models快速搭建自定义模型](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_8_modules_models.html)
- [9. 使用Callback自定义你的训练过程](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_9_callback.html)

### 扩展教程

- [Extend-1. BertEmbedding的各种用法](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_1_bert_embedding.html)
- [Extend-2. 分布式训练简介](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_2_dist.html)
- [Extend-3. 使用fitlog 辅助 fastNLP 进行科研](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_3_fitlog.html)
- [Extend-1. BertEmbedding的各种用法](http://www.fastnlp.top/docs/fastNLP/tutorials/extend_1_bert_embedding.html)
- [Extend-2. 分布式训练简介](http://www.fastnlp.top/docs/fastNLP/tutorials/extend_2_dist.html)
- [Extend-3. 使用fitlog 辅助 fastNLP 进行科研](http://www.fastnlp.top/docs/fastNLP/tutorials/extend_3_fitlog.html)


## 内置组件


+ 4
- 8
docs/requirements.txt View File

@@ -1,8 +1,4 @@
numpy>=1.14.2
http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-linux_x86_64.whl
torchvision>=0.1.8
sphinx-rtd-theme==0.4.1
tensorboardX>=1.4
tqdm>=4.28.1
ipython>=6.4.0
ipython-genutils>=0.2.0
sphinx==3.2.1
docutils==0.16
sphinx-rtd-theme==0.5.0
readthedocs-sphinx-search==0.1.0rc3

+ 5
- 1
docs/source/index.rst View File

@@ -4,6 +4,10 @@ fastNLP 中文文档
`fastNLP <https://github.com/fastnlp/fastNLP/>`_ 是一款轻量级的自然语言处理(NLP)工具包。你既可以用它来快速地完成一个NLP任务,
也可以用它在研究中快速构建更复杂的模型。

.. hint::

如果你是从 readthedocs 访问的该文档,请跳转到我们的 `最新网站 <http://www.fastnlp.top/docs/fastNLP/>`_

fastNLP具有如下的特性:

- 统一的Tabular式数据容器,简化数据预处理过程;
@@ -41,7 +45,7 @@ API 文档
fitlog文档
----------

您可以 `点此 <https://fitlog.readthedocs.io/zh/latest/>`_ 查看fitlog的文档。
您可以 `点此 <http://www.fastnlp.top/docs/fitlog/>`_ 查看fitlog的文档。
fitlog 是由我们团队开发的日志记录+代码管理的工具。

索引与搜索


+ 3
- 3
docs/source/tutorials/extend_3_fitlog.rst View File

@@ -4,7 +4,7 @@

本文介绍结合使用 fastNLP 和 fitlog 进行科研的方法。

首先,我们需要安装 `fitlog <https://fitlog.readthedocs.io/>`_ 。你需要确认你的电脑中没有其它名为 `fitlog` 的命令。
首先,我们需要安装 `fitlog <http://www.fastnlp.top/docs/fitlog/>`_ 。你需要确认你的电脑中没有其它名为 `fitlog` 的命令。

我们从命令行中进入到一个文件夹,现在我们要在文件夹中创建我们的 fastNLP 项目。你可以在命令行输入 `fitlog init test1` ,
然后你会看到如下提示::
@@ -15,7 +15,7 @@
Fitlog project test1 is initialized.

这表明你已经创建成功了项目文件夹,并且在项目文件夹中已经初始化了 Git。如果你不想初始化 Git,
可以参考文档 `命令行工具 <https://fitlog.readthedocs.io/zh/latest/user/command_line.html>`_
可以参考文档 `命令行工具 <http://www.fastnlp.top/docs/fitlog/user/command_line.html>`_

现在我们进入你创建的项目文件夹 test1 中,可以看到有一个名为 logs 的文件夹,后面我们将会在里面存放你的实验记录。
同时也有一个名为 main.py 的文件,这是我们推荐你使用的训练入口文件。文件的内容如下::
@@ -37,7 +37,7 @@
fitlog.finish() # finish the logging

我们推荐你保留除注释外的四行代码,它们有助于你的实验,
他们的具体用处参见文档 `用户 API <https://fitlog.readthedocs.io/zh/latest/fitlog.html>`_
他们的具体用处参见文档 `用户 API <http://www.fastnlp.top/docs/fitlog/>`_

我们假定你要进行前两个教程中的实验,并已经把数据复制到了项目根目录下的 tutorial_sample_dataset.csv 文件中。
现在我们编写如下的训练代码,使用 :class:`~fastNLP.core.callback.FitlogCallback` 进行实验记录保存::


+ 2
- 2
docs/source/tutorials/文本分类.rst View File

@@ -291,7 +291,7 @@ fastNLP提供了Trainer对象来组织训练过程,包括完成loss计算(所


PS: 使用Bert进行文本分类
~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code-block:: python

@@ -368,7 +368,7 @@ PS: 使用Bert进行文本分类


PS: 基于词进行文本分类
~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

由于汉字中没有显示的字与字的边界,一般需要通过分词器先将句子进行分词操作。
下面的例子演示了如何不基于fastNLP已有的数据读取、预处理代码进行文本分类。


+ 1
- 1
fastNLP/core/dataset.py View File

@@ -53,7 +53,7 @@ r"""
from fastNLP import DataSet
from fastNLP import Instance
instances = []
winstances.append(Instance(sentence="This is the first instance",
instances.append(Instance(sentence="This is the first instance",
ords=['this', 'is', 'the', 'first', 'instance', '.'],
seq_len=6))
instances.append(Instance(sentence="Second instance .",


+ 1
- 1
fastNLP/core/tester.py View File

@@ -148,7 +148,7 @@ class Tester(object):
self._predict_func = self._model.predict
self._predict_func_wrapper = self._model.predict
else:
if _model_contains_inner_module(model):
if _model_contains_inner_module(self._model):
self._predict_func_wrapper = self._model.forward
self._predict_func = self._model.module.forward
else:


+ 5
- 0
fastNLP/io/file_utils.py View File

@@ -103,6 +103,11 @@ DATASET_DIR = {
"yelp-review-polarity": "yelp_review_polarity.tar.gz",
"sst-2": "SST-2.zip",
"sst": "SST.zip",
'mr': 'mr.zip',
"R8": "R8.zip",
"R52": "R52.zip",
"20ng": "20ng.zip",
"ohsumed": "ohsumed.zip",

# Classification, Chinese
"chn-senti-corp": "chn_senti_corp.zip",


+ 14
- 6
fastNLP/io/pipe/__init__.py View File

@@ -23,15 +23,15 @@ __all__ = [
"ChnSentiCorpPipe",
"THUCNewsPipe",
"WeiboSenti100kPipe",
"MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Loader",
"MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Pipe",
"Conll2003NERPipe",
"OntoNotesNERPipe",
"MsraNERPipe",
"WeiboNERPipe",
"PeopleDailyPipe",
"Conll2003Pipe",
"MatchingBertPipe",
"RTEBertPipe",
"SNLIBertPipe",
@@ -53,14 +53,20 @@ __all__ = [
"RenamePipe",
"GranularizePipe",
"MachingTruncatePipe",
"CoReferencePipe",

"CMRC2018BertPipe"
"CMRC2018BertPipe",

"R52PmiGraphPipe",
"R8PmiGraphPipe",
"OhsumedPmiGraphPipe",
"NG20PmiGraphPipe",
"MRPmiGraphPipe"
]

from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \
WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Loader
WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Pipe
from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
from .conll import Conll2003Pipe
from .coreference import CoReferencePipe
@@ -70,3 +76,5 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe
LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe
from .pipe import Pipe
from .qa import CMRC2018BertPipe

from .construct_graph import MRPmiGraphPipe, R8PmiGraphPipe, R52PmiGraphPipe, NG20PmiGraphPipe, OhsumedPmiGraphPipe

+ 268
- 0
fastNLP/io/pipe/construct_graph.py View File

@@ -0,0 +1,268 @@

__all__ =[
'MRPmiGraphPipe',
'R8PmiGraphPipe',
'R52PmiGraphPipe',
'OhsumedPmiGraphPipe',
'NG20PmiGraphPipe'
]
try:
import networkx as nx
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.pipeline import Pipeline
except:
pass
from collections import defaultdict
import itertools
import math
from tqdm import tqdm
import numpy as np

from ..data_bundle import DataBundle
from ...core.const import Const
from ..loader.classification import MRLoader, OhsumedLoader, R52Loader, R8Loader, NG20Loader


def _get_windows(content_lst: list, window_size:int):
r"""
滑动窗口处理文本,获取词频和共现词语的词频
:param content_lst:
:param window_size:
:return: 词频,共现词频,窗口化后文本段的数量
"""
word_window_freq = defaultdict(int) # w(i) 单词在窗口单位内出现的次数
word_pair_count = defaultdict(int) # w(i, j)
windows_len = 0
for words in tqdm(content_lst, desc="Split by window"):
windows = list()

if isinstance(words, str):
words = words.split()
length = len(words)

if length <= window_size:
windows.append(words)
else:
for j in range(length - window_size + 1):
window = words[j: j + window_size]
windows.append(list(set(window)))

for window in windows:
for word in window:
word_window_freq[word] += 1

for word_pair in itertools.combinations(window, 2):
word_pair_count[word_pair] += 1

windows_len += len(windows)
return word_window_freq, word_pair_count, windows_len

def _cal_pmi(W_ij, W, word_freq_i, word_freq_j):
r"""
params: w_ij:为词语i,j的共现词频
w:文本数量
word_freq_i: 词语i的词频
word_freq_j: 词语j的词频
return: 词语i,j的tfidf值
"""
p_i = word_freq_i / W
p_j = word_freq_j / W
p_i_j = W_ij / W
pmi = math.log(p_i_j / (p_i * p_j))

return pmi

def _count_pmi(windows_len, word_pair_count, word_window_freq, threshold):
r"""
params: windows_len: 文本段数量
word_pair_count: 词共现频率字典
word_window_freq: 词频率字典
threshold: 阈值
return 词语pmi的list列表,其中元素为[word1, word2, pmi]
"""
word_pmi_lst = list()
for word_pair, W_i_j in tqdm(word_pair_count.items(), desc="Calculate pmi between words"):
word_freq_1 = word_window_freq[word_pair[0]]
word_freq_2 = word_window_freq[word_pair[1]]

pmi = _cal_pmi(W_i_j, windows_len, word_freq_1, word_freq_2)
if pmi <= threshold:
continue
word_pmi_lst.append([word_pair[0], word_pair[1], pmi])
return word_pmi_lst

class GraphBuilderBase:
def __init__(self, graph_type='pmi', widow_size=10, threshold=0.):
self.graph = nx.Graph()
self.word2id = dict()
self.graph_type = graph_type
self.window_size = widow_size
self.doc_node_num = 0
self.tr_doc_index = None
self.te_doc_index = None
self.dev_doc_index = None
self.doc = None
self.threshold = threshold

def _get_doc_edge(self, data_bundle: DataBundle):
r'''
对输入的DataBundle进行处理,然后生成文档-单词的tfidf值
:param: data_bundle中的文本若为英文,形式为[ 'This is the first document.'],若为中文则为['他 喜欢 吃 苹果']
: return 返回带有具有tfidf边文档-单词稀疏矩阵
'''
tr_doc = list(data_bundle.get_dataset("train").get_field(Const.RAW_WORD))
val_doc = list(data_bundle.get_dataset("dev").get_field(Const.RAW_WORD))
te_doc = list(data_bundle.get_dataset("test").get_field(Const.RAW_WORD))
doc = tr_doc + val_doc + te_doc
self.doc = doc
self.tr_doc_index = [ind for ind in range(len(tr_doc))]
self.dev_doc_index = [ind+len(tr_doc) for ind in range(len(val_doc))]
self.te_doc_index = [ind+len(tr_doc)+len(val_doc) for ind in range(len(te_doc))]
text_tfidf = Pipeline([('count', CountVectorizer(token_pattern=r'\S+', min_df=1, max_df=1.0)),
('tfidf', TfidfTransformer(norm=None, use_idf=True, smooth_idf=False, sublinear_tf=False))])

tfidf_vec = text_tfidf.fit_transform(doc)
self.doc_node_num = tfidf_vec.shape[0]
vocab_lst = text_tfidf['count'].get_feature_names()
for ind, word in enumerate(vocab_lst):
self.word2id[word] = ind
for ind, row in enumerate(tfidf_vec):
for col_index, value in zip(row.indices, row.data):
self.graph.add_edge(ind, self.doc_node_num+col_index, weight=value)
return nx.to_scipy_sparse_matrix(self.graph)

def _get_word_edge(self):
word_window_freq, word_pair_count, windows_len = _get_windows(self.doc, self.window_size)
pmi_edge_lst = _count_pmi(windows_len, word_pair_count, word_window_freq, self.threshold)
for edge_item in pmi_edge_lst:
word_indx1 = self.doc_node_num + self.word2id[edge_item[0]]
word_indx2 = self.doc_node_num + self.word2id[edge_item[1]]
if word_indx1 == word_indx2:
continue
self.graph.add_edge(word_indx1, word_indx2, weight=edge_item[2])

def build_graph(self, data_bundle: DataBundle):
r"""
对输入的DataBundle进行处理,然后返回该scipy_sparse_matrix类型的邻接矩阵。

:param ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象
:return:
"""
raise NotImplementedError

def build_graph_from_file(self, path: str):
r"""
传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()`

:param paths:
:return: scipy_sparse_matrix
"""
raise NotImplementedError


class MRPmiGraphPipe(GraphBuilderBase):

def __init__(self, graph_type='pmi', widow_size=10, threshold=0.):
super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold)

def build_graph(self, data_bundle: DataBundle):
r'''
params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象.
return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index.
'''
self._get_doc_edge(data_bundle)
self._get_word_edge()
return nx.to_scipy_sparse_matrix(self.graph,
nodelist=list(range(self.graph.number_of_nodes())),
weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

def build_graph_from_file(self, path: str):
data_bundle = MRLoader().load(path)
return self.build_graph(data_bundle)

class R8PmiGraphPipe(GraphBuilderBase):

def __init__(self, graph_type='pmi', widow_size=10, threshold=0.):
super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold)

def build_graph(self, data_bundle: DataBundle):
r'''
params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象.
return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index.
'''
self._get_doc_edge(data_bundle)
self._get_word_edge()
return nx.to_scipy_sparse_matrix(self.graph,
nodelist=list(range(self.graph.number_of_nodes())),
weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

def build_graph_from_file(self, path: str):
data_bundle = R8Loader().load(path)
return self.build_graph(data_bundle)

class R52PmiGraphPipe(GraphBuilderBase):

def __init__(self, graph_type='pmi', widow_size=10, threshold=0.):
super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold)

def build_graph(self, data_bundle: DataBundle):
r'''
params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象.
return 返回csr类型的稀疏矩阵;训练集,验证集,测试集,在图中的index.
'''
self._get_doc_edge(data_bundle)
self._get_word_edge()
return nx.to_scipy_sparse_matrix(self.graph,
nodelist=list(range(self.graph.number_of_nodes())),
weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

def build_graph_from_file(self, path: str):
data_bundle = R52Loader().load(path)
return self.build_graph(data_bundle)

class OhsumedPmiGraphPipe(GraphBuilderBase):

def __init__(self, graph_type='pmi', widow_size=10, threshold=0.):
super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold)

def build_graph(self, data_bundle: DataBundle):
r'''
params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象.
return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index.
'''
self._get_doc_edge(data_bundle)
self._get_word_edge()
return nx.to_scipy_sparse_matrix(self.graph,
nodelist=list(range(self.graph.number_of_nodes())),
weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

def build_graph_from_file(self, path: str):
data_bundle = OhsumedLoader().load(path)
return self.build_graph(data_bundle)


class NG20PmiGraphPipe(GraphBuilderBase):

def __init__(self, graph_type='pmi', widow_size=10, threshold=0.):
super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold)

def build_graph(self, data_bundle: DataBundle):
r'''
params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象.
return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index.
'''
self._get_doc_edge(data_bundle)
self._get_word_edge()
return nx.to_scipy_sparse_matrix(self.graph,
nodelist=list(range(self.graph.number_of_nodes())),
weight='weight', dtype=np.float32, format='csr'), (
self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

def build_graph_from_file(self, path: str):
r'''
param: path->数据集的路径.
return: 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index.
'''
data_bundle = NG20Loader().load(path)
return self.build_graph(data_bundle)

+ 1
- 1
fastNLP/models/biaffine_parser.py View File

@@ -376,7 +376,7 @@ class BiaffineParser(GraphParser):
if self.encoder_name.endswith('lstm'):
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
x = x[sort_idx]
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=True)
feat, _ = self.encoder(x) # -> [N,L,C]
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)


+ 2
- 2
fastNLP/modules/encoder/_elmo.py View File

@@ -251,7 +251,7 @@ class LstmbiLm(nn.Module):
def forward(self, inputs, seq_len):
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
inputs = inputs[sort_idx]
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=self.batch_first)
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=self.batch_first)
output, hx = self.encoder(inputs, None) # -> [N,L,C]
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
@@ -316,7 +316,7 @@ class ElmobiLm(torch.nn.Module):
max_len = inputs.size(1)
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
inputs = inputs[sort_idx]
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True)
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=True)
output, _ = self._lstm_forward(inputs, None)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
output = output[:, unsort_idx]


+ 3
- 2
readthedocs.yml View File

@@ -7,10 +7,11 @@ build:
image: latest

python:
version: 3.6
version: 3.8
install:
- requirements: docs/requirements.txt
- method: setuptools
path: .

formats:
- htmlzip
- htmlzip

+ 1
- 1
setup.py View File

@@ -23,7 +23,7 @@ setup(
long_description_content_type='text/markdown',
license='Apache License',
author='Fudan FastNLP Team',
python_requires='>=3.6',
python_requires='>=3.7',
packages=pkgs,
install_requires=reqs.strip().split('\n'),
)

Loading…
Cancel
Save