class yelpLoader(JsonLoader):
    """Load the Yelp review dataset for text classification.

    Every record yielded by ``_read_json`` becomes one ``Instance``. The raw
    ``text`` and ``stars`` entries are replaced by ``words`` and ``target``;
    all remaining entries are passed through unchanged. According to the
    dataset description the resulting fields are:

        review_id: str, 22 character unique review id
        user_id: str, 22 character unique user id
        business_id: str, 22 character business id
        useful: int, number of useful votes received
        funny: int, number of funny votes received
        cool: int, number of cool votes received
        date: str, date formatted YYYY-MM-DD
        words: list(str), whitespace-split review text to classify
        target: str, sentiment label derived from the star rating

    Data source: https://www.yelp.com/dataset/download

    :param fine_grained: if ``True`` keep five labels (``very negative``,
        ``negative``, ``neutral``, ``positive``, ``very positive``). If
        ``False`` (default) 1 and 2 stars both map to ``negative`` and 4 and 5
        stars both map to ``positive``; 3 stars stays ``neutral``, so three
        labels remain.
    """

    def __init__(self, fine_grained=False):
        super(yelpLoader, self).__init__()
        # Keys are the string form of the float star rating.
        tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
                 '4.0': 'positive', '5.0': 'very positive'}
        if not fine_grained:
            # Collapse the two extreme ratings onto their milder neighbours.
            tag_v['1.0'] = tag_v['2.0']
            tag_v['5.0'] = tag_v['4.0']
        self.fine_grained = fine_grained
        self.tag_v = tag_v

    def _load(self, path):
        """Read the file at ``path`` and return a ``DataSet`` of reviews.

        :param path: location of the review file, forwarded to ``_read_json``.
        :return: ``DataSet`` with one ``Instance`` per record, carrying
            ``words`` and ``target`` plus the record's other entries.
        :raises KeyError: if a record lacks ``text`` or ``stars``, or the
            rating is not one of 1-5.
        """
        ds = DataSet()
        for _, record in _read_json(path, fields=self.fields_list, dropna=self.dropna):
            # Records may arrive as the textual form of a dict; only those
            # need evaluating. literal_eval rejects anything that is not a
            # string or AST node, so an already-decoded dict is used as is.
            if isinstance(record, str):
                record = ast.literal_eval(record)
            record["words"] = record.pop("text").split()
            # Normalise through float so 5, 5.0 and "5" all look up '5.0'.
            stars_key = str(float(record.pop("stars")))
            record["target"] = self.tag_v[stars_key]
            ds.append(Instance(**record))
        return ds

    def process(self, paths: Union[str, Dict[str, str]], vocab_opt: VocabularyOption = None,
                embed_opt: EmbeddingOption = None):
        """Load every split, build a shared word vocabulary and pack the result.

        :param paths: a single path or a mapping of split name to path; it is
            passed through ``check_dataloader_paths`` and then iterated as a
            mapping of name to path.
        :param vocab_opt: keyword arguments for ``Vocabulary``; when ``None``
            a vocabulary with ``min_freq=2`` is used.
        :param embed_opt: keyword arguments for
            ``EmbedLoader.load_with_vocab``; when ``None`` no embedding is
            loaded.
        :return: ``DataInfo`` holding the datasets, the vocabulary and, if
            requested, the embedding stored under ``'words'``.
        """
        paths = check_dataloader_paths(paths)
        vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt)

        datasets = {}
        for name, path in paths.items():
            dataset = self.load(path)
            datasets[name] = dataset
            # NOTE(review): every split, not only the training one, feeds the
            # vocabulary - confirm this is intended.
            vocab.from_dataset(dataset, field_name="words")

        info = DataInfo()
        # NOTE(review): embeddings below is keyed by field name while vocabs
        # receives the bare Vocabulary - confirm what DataInfo consumers expect.
        info.vocabs = vocab
        info.datasets = datasets
        if embed_opt is not None:
            info.embeddings['words'] = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab)
        return info