|
|
@@ -21,6 +21,8 @@ from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_insta |
|
|
|
from ..data_bundle import DataBundle |
|
|
|
from ..loader.classification import ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader |
|
|
|
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader |
|
|
|
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader, \ |
|
|
|
AGsNewsLoader |
|
|
|
from ...core._logger import logger |
|
|
|
from ...core.const import Const |
|
|
|
from ...core.dataset import DataSet |
|
|
@@ -272,6 +274,49 @@ class YelpPolarityPipe(_CLSPipe): |
|
|
|
return self.process(data_bundle=data_bundle) |
|
|
|
|
|
|
|
|
|
|
|
class AGsNewsPipe(YelpFullPipe):
    """
    Pipe for the AG's News topic-classification dataset (4 classes).
    After processing, the DataSet contains the following fields:

    .. csv-table:: Fields of a DataSet produced by AGsNewsPipe
       :header: "raw_words", "target", "words", "seq_len"

       "Wall St. Bears Claw Back Into the Bla...", 2, "[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160
       "Carlyle Looks Toward Commercial Aeros...", 2, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40
       "...", ., "[...]", .

    (The example rows above are illustrative truncations; actual token ids and
    lengths depend on the vocabulary built during processing.)

    The input/target status of each field, as reported by the dataset's
    ``print_field_meta()``, is::

        +-------------+-----------+--------+-------+---------+
        | field_names | raw_words | target | words | seq_len |
        +-------------+-----------+--------+-------+---------+
        |   is_input  |   False   | False  |  True |   True  |
        |  is_target  |   False   |  True  | False |  False  |
        | ignore_type |           | False  | False |  False  |
        |  pad_value  |           |   0    |   0   |    0    |
        +-------------+-----------+--------+-------+---------+

    """

    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
        """
        :param bool lower: whether to lowercase the input text.
        :param str tokenizer: which tokenization method to use to split the
            data into words. Supports 'spacy' and 'raw'; 'raw' splits on
            whitespace.
        """
        super().__init__(lower=lower, tokenizer=tokenizer)
        # AG's News labels are the raw strings "1".."4"; map them to
        # contiguous 0-based class ids for training.
        self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3}

    def process_from_file(self, paths=None):
        """
        Load the AG's News data from ``paths`` and run :meth:`process` on it.

        :param str paths: path(s) to the raw AG's News files; ``None`` lets
            the loader fall back to its default lookup behaviour.
        :return: DataBundle
        """
        data_bundle = AGsNewsLoader().load(paths)
        return self.process(data_bundle=data_bundle)
|
|
|
|
|
|
|
|
|
|
|
class SSTPipe(_CLSPipe): |
|
|
|
""" |
|
|
|
经过该Pipe之后,DataSet中具备的field如下所示 |
|
|
|