Browse Source

added yelpLoader

tags/v0.4.10
Violet Yao 5 years ago
parent
commit
2edb2a1a00
2 changed files with 70 additions and 1 deletions
  1. +2
    -1
      fastNLP/io/dataset_loader.py
  2. +68
    -0
      reproduction/text_classification/data/yelpLoader.py

+ 2
- 1
fastNLP/io/dataset_loader.py View File

@@ -29,6 +29,7 @@ from .file_reader import _read_csv, _read_json, _read_conll
from .base_loader import DataSetLoader
from .data_loader.sst import SSTLoader
from ..core.const import Const
import ast


class PeopleDailyCorpusLoader(DataSetLoader):
@@ -239,7 +240,7 @@ class JsonLoader(DataSetLoader):
if self.fields:
ins = {self.fields[k]: v for k, v in d.items()}
else:
ins = d
ins = ast.literal_eval(d)
ds.append(Instance(**ins))
return ds



+ 68
- 0
reproduction/text_classification/data/yelpLoader.py View File

@@ -0,0 +1,68 @@
import ast
from fastNLP import DataSet, Instance, Vocabulary
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io import JsonLoader
from fastNLP.io.base_loader import DataInfo
from fastNLP.io.embed_loader import EmbeddingOption
from fastNLP.io.file_reader import _read_json
from typing import Union, Dict
from reproduction.Star_transformer.datasets import EmbedLoader
from reproduction.utils import check_dataloader_paths


class yelpLoader(JsonLoader):
"""
读取Yelp数据集, DataSet包含fields:
review_id: str, 22 character unique review id
user_id: str, 22 character unique user id
business_id: str, 22 character business id
useful: int, number of useful votes received
funny: int, number of funny votes received
cool: int, number of cool votes received
date: str, date formatted YYYY-MM-DD
words: list(str), 需要分类的文本
target: str, 文本的标签
数据来源: https://www.yelp.com/dataset/download
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
"""
def __init__(self, fine_grained=False):
super(yelpLoader, self).__init__()
tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
'4.0': 'positive', '5.0': 'very positive'}
if not fine_grained:
tag_v['1.0'] = tag_v['2.0']
tag_v['5.0'] = tag_v['4.0']
self.fine_grained = fine_grained
self.tag_v = tag_v
def _load(self, path):
ds = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
d = ast.literal_eval(d)
d["words"] = d.pop("text").split()
d["target"] = self.tag_v[str(d.pop("stars"))]
ds.append(Instance(**d))
return ds

def process(self, paths: Union[str, Dict[str, str]], vocab_opt: VocabularyOption = None,
embed_opt: EmbeddingOption = None):
paths = check_dataloader_paths(paths)
datasets = {}
info = DataInfo()
vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt)
for name, path in paths.items():
dataset = self.load(path)
datasets[name] = dataset
vocab.from_dataset(dataset, field_name="words")
info.vocabs = vocab
info.datasets = datasets
if embed_opt is not None:
embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab)
info.embeddings['words'] = embed
return info


Loading…
Cancel
Save