Browse Source

update construct_graph (#393)

tags/v1.0.0alpha
hw GitHub 2 years ago
parent
commit
c18b205bc0
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 287 additions and 6 deletions
  1. +5
    -0
      fastNLP/io/file_utils.py
  2. +14
    -6
      fastNLP/io/pipe/__init__.py
  3. +268
    -0
      fastNLP/io/pipe/construct_graph.py

+ 5
- 0
fastNLP/io/file_utils.py View File

@@ -103,6 +103,11 @@ DATASET_DIR = {
"yelp-review-polarity": "yelp_review_polarity.tar.gz",
"sst-2": "SST-2.zip",
"sst": "SST.zip",
'mr': 'mr.zip',
"R8": "R8.zip",
"R52": "R52.zip",
"20ng": "20ng.zip",
"ohsumed": "ohsumed.zip",

# Classification, Chinese
"chn-senti-corp": "chn_senti_corp.zip",


+ 14
- 6
fastNLP/io/pipe/__init__.py View File

@@ -23,15 +23,15 @@ __all__ = [
"ChnSentiCorpPipe",
"THUCNewsPipe",
"WeiboSenti100kPipe",
"MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Loader",
"MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Pipe",
"Conll2003NERPipe",
"OntoNotesNERPipe",
"MsraNERPipe",
"WeiboNERPipe",
"PeopleDailyPipe",
"Conll2003Pipe",
"MatchingBertPipe",
"RTEBertPipe",
"SNLIBertPipe",
@@ -53,14 +53,20 @@ __all__ = [
"RenamePipe",
"GranularizePipe",
"MachingTruncatePipe",
"CoReferencePipe",

"CMRC2018BertPipe"
"CMRC2018BertPipe",

"R52PmiGraphPipe",
"R8PmiGraphPipe",
"OhsumedPmiGraphPipe",
"NG20PmiGraphPipe",
"MRPmiGraphPipe"
]

from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \
WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Loader
WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Pipe
from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
from .conll import Conll2003Pipe
from .coreference import CoReferencePipe
@@ -70,3 +76,5 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe
LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe
from .pipe import Pipe
from .qa import CMRC2018BertPipe

from .construct_graph import MRPmiGraphPipe, R8PmiGraphPipe, R52PmiGraphPipe, NG20PmiGraphPipe, OhsumedPmiGraphPipe

+ 268
- 0
fastNLP/io/pipe/construct_graph.py View File

@@ -0,0 +1,268 @@

# Public names exported by this module: one PMI graph builder class per
# dataset loader imported below (MR, R8, R52, Ohsumed, 20NG).
__all__ =[
'MRPmiGraphPipe',
'R8PmiGraphPipe',
'R52PmiGraphPipe',
'OhsumedPmiGraphPipe',
'NG20PmiGraphPipe'
]
try:
import networkx as nx
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.pipeline import Pipeline
except:
pass
from collections import defaultdict
import itertools
import math
from tqdm import tqdm
import numpy as np

from ..data_bundle import DataBundle
from ...core.const import Const
from ..loader.classification import MRLoader, OhsumedLoader, R52Loader, R8Loader, NG20Loader


def _get_windows(content_lst: list, window_size:int):
r"""
滑动窗口处理文本,获取词频和共现词语的词频
:param content_lst:
:param window_size:
:return: 词频,共现词频,窗口化后文本段的数量
"""
word_window_freq = defaultdict(int) # w(i) 单词在窗口单位内出现的次数
word_pair_count = defaultdict(int) # w(i, j)
windows_len = 0
for words in tqdm(content_lst, desc="Split by window"):
windows = list()

if isinstance(words, str):
words = words.split()
length = len(words)

if length <= window_size:
windows.append(words)
else:
for j in range(length - window_size + 1):
window = words[j: j + window_size]
windows.append(list(set(window)))

for window in windows:
for word in window:
word_window_freq[word] += 1

for word_pair in itertools.combinations(window, 2):
word_pair_count[word_pair] += 1

windows_len += len(windows)
return word_window_freq, word_pair_count, windows_len

def _cal_pmi(W_ij, W, word_freq_i, word_freq_j):
r"""
params: w_ij:为词语i,j的共现词频
w:文本数量
word_freq_i: 词语i的词频
word_freq_j: 词语j的词频
return: 词语i,j的tfidf值
"""
p_i = word_freq_i / W
p_j = word_freq_j / W
p_i_j = W_ij / W
pmi = math.log(p_i_j / (p_i * p_j))

return pmi

def _count_pmi(windows_len, word_pair_count, word_window_freq, threshold):
    r"""
    Score every co-occurring word pair with PMI and keep the pairs above a threshold.

    :param windows_len: total number of windows
    :param word_pair_count: mapping from a word pair to its co-occurrence count
    :param word_window_freq: mapping from a word to its occurrence count
    :param threshold: pairs whose PMI is less than or equal to this value are dropped
    :return: list of ``[word1, word2, pmi]`` entries, one per kept pair
    """
    kept_pairs = []
    progress = tqdm(word_pair_count.items(), desc="Calculate pmi between words")
    for word_pair, co_occurrence in progress:
        first_word, second_word = word_pair[0], word_pair[1]
        first_freq = word_window_freq[first_word]
        second_freq = word_window_freq[second_word]

        score = _cal_pmi(co_occurrence, windows_len, first_freq, second_freq)
        if score <= threshold:
            continue
        kept_pairs.append([first_word, second_word, score])
    return kept_pairs

class GraphBuilderBase:
    r"""
    Base class for builders of a document-word graph.

    Document nodes are numbered ``0 .. doc_node_num - 1`` in the order train,
    dev, test; the node of a word is ``doc_node_num + word2id[word]``.
    Document-word edges are weighted with TF-IDF and word-word edges with PMI.

    :param graph_type: stored as ``self.graph_type``; it is not read anywhere
        in this class
    :param widow_size: size of the sliding window used for the word
        co-occurrence statistics (the misspelt parameter name is kept for
        backward compatibility; the value is stored as ``self.window_size``)
    :param threshold: a word pair only gets an edge when its PMI is strictly
        greater than this value
    """

    def __init__(self, graph_type='pmi', widow_size=10, threshold=0.):
        self.graph = nx.Graph()
        self.word2id = dict()
        self.graph_type = graph_type
        self.window_size = widow_size
        self.doc_node_num = 0
        self.tr_doc_index = None
        self.te_doc_index = None
        self.dev_doc_index = None
        self.doc = None
        self.threshold = threshold

    def _to_scipy_sparse(self, nodelist=None, dtype=None):
        r"""
        Export ``self.graph`` as a CSR adjacency matrix weighted by the ``weight`` edge attribute.

        ``networkx.to_scipy_sparse_matrix`` was removed in networkx 3.0, so it
        is used when present and otherwise replaced by
        ``to_scipy_sparse_array`` whose result is converted back to a matrix,
        keeping the return type identical for callers.

        :param nodelist: row/column order of the matrix; ``None`` uses the graph's own node order
        :param dtype: dtype of the matrix; ``None`` lets networkx choose
        :return: ``scipy.sparse.csr_matrix``
        """
        if hasattr(nx, 'to_scipy_sparse_matrix'):
            return nx.to_scipy_sparse_matrix(self.graph, nodelist=nodelist, weight='weight',
                                             dtype=dtype, format='csr')
        from scipy import sparse  # scipy is already required by the networkx conversion itself
        adjacency = nx.to_scipy_sparse_array(self.graph, nodelist=nodelist, weight='weight',
                                             dtype=dtype, format='csr')
        return sparse.csr_matrix(adjacency)

    def _get_doc_edge(self, data_bundle: DataBundle):
        r"""
        Rebuild the graph from the bundle's documents and add the TF-IDF weighted document-word edges.

        The raw-word field of the "train", "dev" and "test" datasets is read,
        so all three must exist in ``data_bundle``. Each document is a
        whitespace separated string, e.g. ``'This is the first document.'`` or
        ``'他 喜欢 吃 苹果'``.

        :param data_bundle: the DataBundle to process
        :return: CSR adjacency matrix of the graph holding the document-word edges
        """
        tr_doc = list(data_bundle.get_dataset("train").get_field(Const.RAW_WORD))
        val_doc = list(data_bundle.get_dataset("dev").get_field(Const.RAW_WORD))
        te_doc = list(data_bundle.get_dataset("test").get_field(Const.RAW_WORD))
        doc = tr_doc + val_doc + te_doc
        self.doc = doc

        dev_start = len(tr_doc)
        te_start = dev_start + len(val_doc)
        self.tr_doc_index = list(range(dev_start))
        self.dev_doc_index = list(range(dev_start, te_start))
        self.te_doc_index = list(range(te_start, len(doc)))

        # Start from a clean state so that building twice with the same
        # instance does not mix nodes and vocabulary of different bundles.
        self.graph = nx.Graph()
        self.word2id = dict()

        # lowercase=False keeps the vocabulary identical to the raw tokens that
        # _get_word_edge looks up in word2id; with the default lower-casing any
        # token containing an upper-case letter would raise KeyError there.
        text_tfidf = Pipeline([
            ('count', CountVectorizer(token_pattern=r'\S+', lowercase=False, min_df=1, max_df=1.0)),
            ('tfidf', TfidfTransformer(norm=None, use_idf=True, smooth_idf=False, sublinear_tf=False)),
        ])
        tfidf_vec = text_tfidf.fit_transform(doc)
        self.doc_node_num = tfidf_vec.shape[0]

        # vocabulary_ maps term -> column index; reading it directly avoids
        # get_feature_names(), which newer scikit-learn releases removed.
        for word, col_index in text_tfidf.named_steps['count'].vocabulary_.items():
            self.word2id[word] = int(col_index)

        # Register every node up front: a document without any token has no
        # edge and would otherwise be missing from the graph, which shifts the
        # node count that the final adjacency matrix is laid out with.
        self.graph.add_nodes_from(range(self.doc_node_num + tfidf_vec.shape[1]))

        entries = tfidf_vec.tocoo()
        for doc_index, col_index, value in zip(entries.row, entries.col, entries.data):
            self.graph.add_edge(int(doc_index), self.doc_node_num + int(col_index), weight=value)
        return self._to_scipy_sparse()

    def _get_word_edge(self):
        r"""
        Add the PMI weighted word-word edges, computed over sliding windows of ``self.doc``, to the graph.

        Must run after :meth:`_get_doc_edge`, which fills ``self.doc``,
        ``self.word2id`` and ``self.doc_node_num``.
        """
        word_window_freq, word_pair_count, windows_len = _get_windows(self.doc, self.window_size)
        pmi_edge_lst = _count_pmi(windows_len, word_pair_count, word_window_freq, self.threshold)
        for first_word, second_word, pmi in pmi_edge_lst:
            word_indx1 = self.doc_node_num + self.word2id[first_word]
            word_indx2 = self.doc_node_num + self.word2id[second_word]
            if word_indx1 == word_indx2:
                # never create self-loops
                continue
            self.graph.add_edge(word_indx1, word_indx2, weight=pmi)

    def _build_pmi_graph(self, data_bundle: DataBundle):
        r"""
        Shared implementation of ``build_graph`` for the PMI graph builders.

        :param data_bundle: the DataBundle to process
        :return: ``(adjacency, (train_index, dev_index, test_index))`` where
            ``adjacency`` is a float32 CSR matrix whose rows/columns follow the
            node numbering and the index lists hold the document node ids of
            each split
        """
        self._get_doc_edge(data_bundle)
        self._get_word_edge()
        adjacency = self._to_scipy_sparse(nodelist=list(range(self.graph.number_of_nodes())),
                                          dtype=np.float32)
        return adjacency, (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

    def build_graph(self, data_bundle: DataBundle):
        r"""
        Process the given DataBundle and return its adjacency matrix as a scipy sparse matrix.

        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process
        :return:
        """
        raise NotImplementedError

    def build_graph_from_file(self, path: str):
        r"""
        Load the data found at ``path`` and build the graph from it. The supported
        path forms are described in :meth:`fastNLP.io.Loader.load()`

        :param path: location of the dataset
        :return: scipy_sparse_matrix
        """
        raise NotImplementedError


class MRPmiGraphPipe(GraphBuilderBase):
    r"""PMI graph builder whose ``build_graph_from_file`` reads data with ``MRLoader``."""

    def __init__(self, graph_type='pmi', widow_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r"""
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process
        :return: the graph as a CSR sparse matrix, and the indices of the
            train, dev and test documents in that graph
        """
        return self._build_pmi_graph(data_bundle)

    def build_graph_from_file(self, path: str):
        r"""
        :param path: location of the dataset
        :return: same as :meth:`build_graph`
        """
        return self.build_graph(MRLoader().load(path))


class R8PmiGraphPipe(GraphBuilderBase):
    r"""PMI graph builder whose ``build_graph_from_file`` reads data with ``R8Loader``."""

    def __init__(self, graph_type='pmi', widow_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r"""
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process
        :return: the graph as a CSR sparse matrix, and the indices of the
            train, dev and test documents in that graph
        """
        return self._build_pmi_graph(data_bundle)

    def build_graph_from_file(self, path: str):
        r"""
        :param path: location of the dataset
        :return: same as :meth:`build_graph`
        """
        return self.build_graph(R8Loader().load(path))


class R52PmiGraphPipe(GraphBuilderBase):
    r"""PMI graph builder whose ``build_graph_from_file`` reads data with ``R52Loader``."""

    def __init__(self, graph_type='pmi', widow_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r"""
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process
        :return: the graph as a CSR sparse matrix, and the indices of the
            train, dev and test documents in that graph
        """
        return self._build_pmi_graph(data_bundle)

    def build_graph_from_file(self, path: str):
        r"""
        :param path: location of the dataset
        :return: same as :meth:`build_graph`
        """
        return self.build_graph(R52Loader().load(path))


class OhsumedPmiGraphPipe(GraphBuilderBase):
    r"""PMI graph builder whose ``build_graph_from_file`` reads data with ``OhsumedLoader``."""

    def __init__(self, graph_type='pmi', widow_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r"""
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process
        :return: the graph as a CSR sparse matrix, and the indices of the
            train, dev and test documents in that graph
        """
        return self._build_pmi_graph(data_bundle)

    def build_graph_from_file(self, path: str):
        r"""
        :param path: location of the dataset
        :return: same as :meth:`build_graph`
        """
        return self.build_graph(OhsumedLoader().load(path))


class NG20PmiGraphPipe(GraphBuilderBase):
    r"""PMI graph builder whose ``build_graph_from_file`` reads data with ``NG20Loader``."""

    def __init__(self, graph_type='pmi', widow_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r"""
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process
        :return: the graph as a CSR sparse matrix, and the indices of the
            train, dev and test documents in that graph
        """
        return self._build_pmi_graph(data_bundle)

    def build_graph_from_file(self, path: str):
        r"""
        :param path: location of the dataset
        :return: the graph as a CSR sparse matrix, and the indices of the
            train, dev and test documents in that graph
        """
        return self.build_graph(NG20Loader().load(path))

Loading…
Cancel
Save