
update construct_graph (#393)

tags/v1.0.0alpha
hw (GitHub), 2 years ago
commit c18b205bc0
3 changed files with 287 additions and 6 deletions:

  1. fastNLP/io/file_utils.py (+5, -0)
  2. fastNLP/io/pipe/__init__.py (+14, -6)
  3. fastNLP/io/pipe/construct_graph.py (+268, -0)

fastNLP/io/file_utils.py (+5, -0)

@@ -103,6 +103,11 @@ DATASET_DIR = {
     "yelp-review-polarity": "yelp_review_polarity.tar.gz",
     "sst-2": "SST-2.zip",
     "sst": "SST.zip",
+    'mr': 'mr.zip',
+    "R8": "R8.zip",
+    "R52": "R52.zip",
+    "20ng": "20ng.zip",
+    "ohsumed": "ohsumed.zip",

     # Classification, Chinese
     "chn-senti-corp": "chn_senti_corp.zip",


fastNLP/io/pipe/__init__.py (+14, -6)

@@ -23,15 +23,15 @@ __all__ = [
     "ChnSentiCorpPipe",
     "THUCNewsPipe",
     "WeiboSenti100kPipe",
-    "MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Loader",
+    "MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Pipe",
     "Conll2003NERPipe",
     "OntoNotesNERPipe",
     "MsraNERPipe",
     "WeiboNERPipe",
     "PeopleDailyPipe",
     "Conll2003Pipe",
     "MatchingBertPipe",
     "RTEBertPipe",
     "SNLIBertPipe",
@@ -53,14 +53,20 @@ __all__ = [
     "RenamePipe",
     "GranularizePipe",
     "MachingTruncatePipe",
     "CoReferencePipe",

-    "CMRC2018BertPipe"
+    "CMRC2018BertPipe",
+
+    "R52PmiGraphPipe",
+    "R8PmiGraphPipe",
+    "OhsumedPmiGraphPipe",
+    "NG20PmiGraphPipe",
+    "MRPmiGraphPipe"
 ]


 from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \
-    WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Loader
+    WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Pipe
 from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
 from .conll import Conll2003Pipe
 from .coreference import CoReferencePipe
@@ -70,3 +76,5 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, \
     LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe
 from .pipe import Pipe
 from .qa import CMRC2018BertPipe
+
+from .construct_graph import MRPmiGraphPipe, R8PmiGraphPipe, R52PmiGraphPipe, NG20PmiGraphPipe, OhsumedPmiGraphPipe
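With the new names exported here, downstream code can import the graph pipes directly from the package; a minimal sketch:

from fastNLP.io.pipe import R8PmiGraphPipe

pipe = R8PmiGraphPipe()  # PMI word-word edges, tf-idf doc-word edges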

fastNLP/io/pipe/construct_graph.py (+268, -0)

@@ -0,0 +1,268 @@

__all__ = [
    'MRPmiGraphPipe',
    'R8PmiGraphPipe',
    'R52PmiGraphPipe',
    'OhsumedPmiGraphPipe',
    'NG20PmiGraphPipe'
]

try:
    import networkx as nx
    from sklearn.feature_extraction.text import CountVectorizer
    from sklearn.feature_extraction.text import TfidfTransformer
    from sklearn.pipeline import Pipeline
except ImportError:
    pass
from collections import defaultdict
import itertools
import math
from tqdm import tqdm
import numpy as np

from ..data_bundle import DataBundle
from ...core.const import Const
from ..loader.classification import MRLoader, OhsumedLoader, R52Loader, R8Loader, NG20Loader


def _get_windows(content_lst: list, window_size: int):
    r"""
    Slide a fixed-size window over each text and collect word frequencies
    and co-occurrence counts.

    :param content_lst: list of documents (each a string or a list of words)
    :param window_size: size of the sliding window
    :return: word frequencies, co-occurrence counts, total number of windows
    """
    word_window_freq = defaultdict(int)  # w(i): number of windows containing word i
    word_pair_count = defaultdict(int)  # w(i, j): number of windows containing both i and j
    windows_len = 0
    for words in tqdm(content_lst, desc="Split by window"):
        windows = list()

        if isinstance(words, str):
            words = words.split()
        length = len(words)

        if length <= window_size:
            windows.append(words)
        else:
            for j in range(length - window_size + 1):
                window = words[j: j + window_size]
                windows.append(list(set(window)))

        for window in windows:
            for word in window:
                word_window_freq[word] += 1

            for word_pair in itertools.combinations(window, 2):
                word_pair_count[word_pair] += 1

        windows_len += len(windows)
    return word_window_freq, word_pair_count, windows_len
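# A worked example of the statistics above (hypothetical input): for
# _get_windows(["a b c"], window_size=2) the windows are ["a", "b"] and
# ["b", "c"], so word_window_freq == {"a": 1, "b": 2, "c": 1}, each of the
# unordered pairs (a, b) and (b, c) is counted once, and windows_len == 2.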

def _cal_pmi(W_ij, W, word_freq_i, word_freq_j):
    r"""
    :param W_ij: co-occurrence count of words i and j
    :param W: total number of windows
    :param word_freq_i: frequency of word i
    :param word_freq_j: frequency of word j
    :return: the PMI of words i and j
    """
    p_i = word_freq_i / W
    p_j = word_freq_j / W
    p_i_j = W_ij / W
    pmi = math.log(p_i_j / (p_i * p_j))

    return pmi
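# Numerically: with W = 100 windows, W_ij = 10, word_freq_i = 20 and
# word_freq_j = 25, we get p_i = 0.2, p_j = 0.25 and p_i_j = 0.1, so
# pmi = log(0.1 / (0.2 * 0.25)) = log(2) ≈ 0.693, i.e. the pair co-occurs
# twice as often as independence would predict.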

def _count_pmi(windows_len, word_pair_count, word_window_freq, threshold):
    r"""
    :param windows_len: total number of windows
    :param word_pair_count: dict of word co-occurrence counts
    :param word_window_freq: dict of word frequencies
    :param threshold: PMI threshold; pairs at or below it are dropped
    :return: list of [word1, word2, pmi] entries
    """
    word_pmi_lst = list()
    for word_pair, W_i_j in tqdm(word_pair_count.items(), desc="Calculate pmi between words"):
        word_freq_1 = word_window_freq[word_pair[0]]
        word_freq_2 = word_window_freq[word_pair[1]]

        pmi = _cal_pmi(W_i_j, windows_len, word_freq_1, word_freq_2)
        if pmi <= threshold:
            continue
        word_pmi_lst.append([word_pair[0], word_pair[1], pmi])
    return word_pmi_lst

class GraphBuilderBase:
    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
        self.graph = nx.Graph()
        self.word2id = dict()
        self.graph_type = graph_type
        self.window_size = window_size
        self.doc_node_num = 0
        self.tr_doc_index = None
        self.te_doc_index = None
        self.dev_doc_index = None
        self.doc = None
        self.threshold = threshold

    def _get_doc_edge(self, data_bundle: DataBundle):
        r'''
        Process the input DataBundle and compute the document-word tf-idf values.

        :param data_bundle: English texts have the form ['This is the first document.'];
            Chinese texts are pre-tokenized and space-separated, e.g. ['他 喜欢 吃 苹果']
        :return: a sparse document-word matrix whose edges carry tf-idf weights
        '''
        tr_doc = list(data_bundle.get_dataset("train").get_field(Const.RAW_WORD))
        val_doc = list(data_bundle.get_dataset("dev").get_field(Const.RAW_WORD))
        te_doc = list(data_bundle.get_dataset("test").get_field(Const.RAW_WORD))
        doc = tr_doc + val_doc + te_doc
        self.doc = doc
        self.tr_doc_index = [ind for ind in range(len(tr_doc))]
        self.dev_doc_index = [ind + len(tr_doc) for ind in range(len(val_doc))]
        self.te_doc_index = [ind + len(tr_doc) + len(val_doc) for ind in range(len(te_doc))]
        text_tfidf = Pipeline([('count', CountVectorizer(token_pattern=r'\S+', min_df=1, max_df=1.0)),
                               ('tfidf', TfidfTransformer(norm=None, use_idf=True, smooth_idf=False, sublinear_tf=False))])

        tfidf_vec = text_tfidf.fit_transform(doc)
        self.doc_node_num = tfidf_vec.shape[0]
        vocab_lst = text_tfidf['count'].get_feature_names()
        for ind, word in enumerate(vocab_lst):
            self.word2id[word] = ind
        # document nodes occupy ids [0, doc_node_num); word nodes follow at doc_node_num + word_id
        for ind, row in enumerate(tfidf_vec):
            for col_index, value in zip(row.indices, row.data):
                self.graph.add_edge(ind, self.doc_node_num + col_index, weight=value)
        return nx.to_scipy_sparse_matrix(self.graph)

    def _get_word_edge(self):
        word_window_freq, word_pair_count, windows_len = _get_windows(self.doc, self.window_size)
        pmi_edge_lst = _count_pmi(windows_len, word_pair_count, word_window_freq, self.threshold)
        for edge_item in pmi_edge_lst:
            word_indx1 = self.doc_node_num + self.word2id[edge_item[0]]
            word_indx2 = self.doc_node_num + self.word2id[edge_item[1]]
            if word_indx1 == word_indx2:
                continue
            self.graph.add_edge(word_indx1, word_indx2, weight=edge_item[2])

    def build_graph(self, data_bundle: DataBundle):
        r"""
        Process the input DataBundle and return the adjacency matrix as a scipy sparse matrix.

        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process
        :return:
        """
        raise NotImplementedError

    def build_graph_from_file(self, path: str):
        r"""
        Build the processed scipy sparse matrix from a file path. The supported
        forms of ``path`` are the same as for :meth:`fastNLP.io.Loader.load`.

        :param path: path to the dataset
        :return: scipy sparse matrix
        """
        raise NotImplementedError


class MRPmiGraphPipe(GraphBuilderBase):

    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r'''
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
        :return: the graph as a csr sparse matrix, plus the indices of the train,
            dev and test documents within the graph.
        '''
        self._get_doc_edge(data_bundle)
        self._get_word_edge()
        return nx.to_scipy_sparse_matrix(self.graph,
                                         nodelist=list(range(self.graph.number_of_nodes())),
                                         weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

    def build_graph_from_file(self, path: str):
        data_bundle = MRLoader().load(path)
        return self.build_graph(data_bundle)

class R8PmiGraphPipe(GraphBuilderBase):

    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r'''
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
        :return: the graph as a csr sparse matrix, plus the indices of the train,
            dev and test documents within the graph.
        '''
        self._get_doc_edge(data_bundle)
        self._get_word_edge()
        return nx.to_scipy_sparse_matrix(self.graph,
                                         nodelist=list(range(self.graph.number_of_nodes())),
                                         weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

    def build_graph_from_file(self, path: str):
        data_bundle = R8Loader().load(path)
        return self.build_graph(data_bundle)

class R52PmiGraphPipe(GraphBuilderBase):

    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r'''
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
        :return: the graph as a csr sparse matrix, plus the indices of the train,
            dev and test documents within the graph.
        '''
        self._get_doc_edge(data_bundle)
        self._get_word_edge()
        return nx.to_scipy_sparse_matrix(self.graph,
                                         nodelist=list(range(self.graph.number_of_nodes())),
                                         weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

    def build_graph_from_file(self, path: str):
        data_bundle = R52Loader().load(path)
        return self.build_graph(data_bundle)

class OhsumedPmiGraphPipe(GraphBuilderBase):

    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r'''
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
        :return: the graph as a csr sparse matrix, plus the indices of the train,
            dev and test documents within the graph.
        '''
        self._get_doc_edge(data_bundle)
        self._get_word_edge()
        return nx.to_scipy_sparse_matrix(self.graph,
                                         nodelist=list(range(self.graph.number_of_nodes())),
                                         weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

    def build_graph_from_file(self, path: str):
        data_bundle = OhsumedLoader().load(path)
        return self.build_graph(data_bundle)


class NG20PmiGraphPipe(GraphBuilderBase):

    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r'''
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
        :return: the graph as a csr sparse matrix, plus the indices of the train,
            dev and test documents within the graph.
        '''
        self._get_doc_edge(data_bundle)
        self._get_word_edge()
        return nx.to_scipy_sparse_matrix(self.graph,
                                         nodelist=list(range(self.graph.number_of_nodes())),
                                         weight='weight', dtype=np.float32, format='csr'), (
                   self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

    def build_graph_from_file(self, path: str):
        r'''
        :param path: path to the dataset.
        :return: the graph as a csr sparse matrix, plus the indices of the train,
            dev and test documents within the graph.
        '''
        data_bundle = NG20Loader().load(path)
        return self.build_graph(data_bundle)
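For reference, a minimal end-to-end sketch of the new pipes (the dataset path is hypothetical; build_graph_from_file loads it with the matching loader and returns the adjacency matrix plus the document index lists):

from fastNLP.io.pipe.construct_graph import MRPmiGraphPipe

pipe = MRPmiGraphPipe(window_size=10, threshold=0.)
adj, (train_idx, dev_idx, test_idx) = pipe.build_graph_from_file('path/to/mr')  # hypothetical path

# adj is a csr matrix over n_docs + n_words nodes; the index lists locate
# the train/dev/test documents among the first n_docs rows.
print(adj.shape, len(train_idx), len(dev_idx), len(test_idx))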
