@@ -6,7 +6,9 @@ __all__ = [
"SSTPipe",
"SST2Pipe",
    "IMDBPipe",
"ChnSentiCorpPipe"
"ChnSentiCorpPipe",
"THUCNewsPipe",
"WeiboSenti100kPipe"
]
import re
@@ -17,7 +19,7 @@ from nltk import Tree
from .pipe import Pipe
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance, _add_chars_field
from ..data_bundle import DataBundle
from ..loader.classification import ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
from ...core.const import Const
from ...core.dataset import DataSet
@@ -580,4 +582,200 @@ class ChnSentiCorpPipe(Pipe):
data_bundle = ChnSentiCorpLoader().load(paths)
data_bundle = self.process(data_bundle)
return data_bundle


class THUCNewsPipe(_CLSPipe):
"""
    The processed DataSet has the following structure:

    .. csv-table::
        :header: "raw_chars", "chars", "target", "seq_len"

        "马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道...", "[409, 1197, 2146, 213, ...]", 0, 746
        "..."

    where chars and seq_len are input fields and target is the target field.

    :param bool bigrams: whether to add a bigrams column, built as ['复', '旦', '大', '学', ...]->["复旦", "旦大", ...].
        If True, the returned DataSet has a bigrams column that is already converted to indices and set as input;
        the corresponding vocabulary is available via data_bundle.get_vocab('bigrams').
    :param bool trigrams: whether to add a trigrams column, built as ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...].
        If True, the returned DataSet has a trigrams column that is already converted to indices and set as input;
        the corresponding vocabulary is available via data_bundle.get_vocab('trigrams').
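
    Example (a minimal usage sketch; the data path below is a placeholder)::

        pipe = THUCNewsPipe(bigrams=True)
        data_bundle = pipe.process_from_file('path/to/thucnews')
        print(data_bundle.get_dataset('train'))
        print(data_bundle.get_vocab('bigrams'))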
"""

    def __init__(self, bigrams=False, trigrams=False):
super().__init__()
self.bigrams = bigrams
self.trigrams = trigrams

    def _character_split(self, sent):
        # character-level tokenization: split the sentence into a list of characters
        return list(sent)

    def _raw_split(self, sent):
return sent.split()

    def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
        # split the given field of every dataset into characters, writing the result to new_field_name
        new_field_name = new_field_name or field_name
        for name, dataset in data_bundle.datasets.items():
            dataset.apply_field(self._character_split, field_name=field_name, new_field_name=new_field_name)
        return data_bundle

    def process(self, data_bundle: DataBundle):
"""
可处理的DataSet应具备如下的field
.. csv-table::
:header: "raw_words", "target"
"马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 ... ", "体育"
"...", "..."
:param data_bundle:
:return:
"""
        # map the raw category labels to integer class tags
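        # (体育=sports, 财经=finance, 房产=real estate, 家居=home, 教育=education,
        #  科技=technology, 时尚=fashion, 时政=politics, 游戏=gaming, 娱乐=entertainment)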
tag_map = {'体育': 0, '财经': 1, '房产': 2, '家居': 3, '教育': 4, '科技': 5, '时尚': 6, '时政': 7, '游戏': 8, '娱乐': 9}
data_bundle = self._granularize(data_bundle=data_bundle, tag_map=tag_map)
        # tokenization: split raw_chars into a list of characters
data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars')
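        # track every field that will be set as model input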
input_field_names = [Const.CHAR_INPUT]
# n-grams
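        # each n-gram column stays aligned with chars: the tail is padded with '<eos>'
        # so that len(bigrams) == len(trigrams) == len(chars)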
if self.bigrams:
for name, dataset in data_bundle.iter_datasets():
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
field_name=Const.CHAR_INPUT, new_field_name='bigrams')
input_field_names.append('bigrams')
if self.trigrams:
for name, dataset in data_bundle.iter_datasets():
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
field_name=Const.CHAR_INPUT, new_field_name='trigrams')
input_field_names.append('trigrams')
        # build a vocabulary for each input field (chars, and bigrams/trigrams if enabled)
        # and convert tokens to indices
        data_bundle = _indexize(data_bundle=data_bundle, input_field_names=input_field_names)
# add length
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(field_name=Const.CHAR_INPUT, new_field_name=Const.INPUT_LEN)
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
target_fields = [Const.TARGET]
data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)
return data_bundle

    def process_from_file(self, paths=None):
"""
        :param paths: supported path formats are described in the ``load`` method of :class:`fastNLP.io.loader.Loader`.
        :return: the processed DataBundle
"""
        data_loader = THUCNewsLoader()  # instantiate a loader; load() handles paths=None by falling back to its defaults
data_bundle = data_loader.load(paths)
data_bundle = self.process(data_bundle)
return data_bundle


class WeiboSenti100kPipe(_CLSPipe):
"""
    The processed DataSet has the following structure:

    .. csv-table::
        :header: "raw_chars", "chars", "target", "seq_len"

        "六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", "[0, 690, 18, ...]", 0, 56
        "..."

    where chars and seq_len are input fields and target is the target field.

    :param bool bigrams: whether to add a bigrams column, built as ['复', '旦', '大', '学', ...]->["复旦", "旦大", ...].
        If True, the returned DataSet has a bigrams column that is already converted to indices and set as input;
        the corresponding vocabulary is available via data_bundle.get_vocab('bigrams').
    :param bool trigrams: whether to add a trigrams column, built as ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...].
        If True, the returned DataSet has a trigrams column that is already converted to indices and set as input;
        the corresponding vocabulary is available via data_bundle.get_vocab('trigrams').
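
    Example (a minimal usage sketch; the data path below is a placeholder)::

        pipe = WeiboSenti100kPipe(bigrams=True)
        data_bundle = pipe.process_from_file('path/to/weibo_senti_100k')
        print(data_bundle.get_dataset('train'))
        print(data_bundle.get_vocab('bigrams'))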
"""

    def __init__(self, bigrams=False, trigrams=False):
super().__init__()
self.bigrams = bigrams
self.trigrams = trigrams

    def _character_split(self, sent):
        # character-level tokenization: split the sentence into a list of characters
        return list(sent)

    def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
        # split the given field of every dataset into characters, writing the result to new_field_name
        new_field_name = new_field_name or field_name
        for name, dataset in data_bundle.datasets.items():
            dataset.apply_field(self._character_split, field_name=field_name, new_field_name=new_field_name)
        return data_bundle

    def process(self, data_bundle: DataBundle):
"""
        The input DataSet must have the following fields:

        .. csv-table::
            :header: "raw_chars", "target"

            "六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", "0"
            "...", "..."

        :param data_bundle: the DataBundle to process
        :return: the processed DataBundle
"""
        # tokenization: split raw_chars into a list of characters
data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars')
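        # track every field that will be set as model input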
input_field_names = [Const.CHAR_INPUT]
# n-grams
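        # each n-gram column stays aligned with chars: the tail is padded with '<eos>'
        # so that len(bigrams) == len(trigrams) == len(chars)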
if self.bigrams:
for name, dataset in data_bundle.iter_datasets():
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
field_name=Const.CHAR_INPUT, new_field_name='bigrams')
input_field_names.append('bigrams')
if self.trigrams:
for name, dataset in data_bundle.iter_datasets():
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
field_name=Const.CHAR_INPUT, new_field_name='trigrams')
input_field_names.append('trigrams')
        # build a vocabulary for each input field (chars, and bigrams/trigrams if enabled)
        # and convert tokens to indices
        data_bundle = _indexize(data_bundle=data_bundle, input_field_names=input_field_names)
# add length
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(field_name=Const.CHAR_INPUT, new_field_name=Const.INPUT_LEN)
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
target_fields = [Const.TARGET]
data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)
return data_bundle

    def process_from_file(self, paths=None):
"""
        :param paths: supported path formats are described in the ``load`` method of :class:`fastNLP.io.loader.Loader`.
        :return: the processed DataBundle
"""
        data_loader = WeiboSenti100kLoader()  # instantiate a loader; load() handles paths=None by falling back to its defaults
data_bundle = data_loader.load(paths)
data_bundle = self.process(data_bundle)
return data_bundle