@@ -7,6 +7,7 @@ from fastNLP import Instance
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 import csv
 from typing import Union, Dict
+from reproduction.utils import check_dataloader_paths, get_tokenizer
 
 class SSTLoader(DataSetLoader):
     URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'

@@ -104,6 +105,7 @@ class sst2Loader(DataSetLoader):
     '''
     def __init__(self):
         super(sst2Loader, self).__init__()
+        self.tokenizer = get_tokenizer()
 
     def _load(self, path: str) -> DataSet:
         ds = DataSet()
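
# --- Illustrative sketch, not part of the diff ---
# The added line resolves the tokenizer once in __init__ and reuses it for every
# row in _load. reproduction.utils.get_tokenizer is not shown in this diff, so the
# stub below assumes a plain whitespace split purely for illustration.
from typing import Callable, List

def get_tokenizer_stub() -> Callable[[str], List[str]]:
    # Assumed fallback behaviour: split on whitespace.
    return lambda text: text.split()

class TokenizingLoaderSketch:
    def __init__(self):
        # Build the tokenizer once instead of once per sentence.
        self.tokenizer = get_tokenizer_stub()

    def tokenize(self, sentence: str) -> List[str]:
        return self.tokenizer(sentence)

print(TokenizingLoaderSketch().tokenize("a quietly moving portrait"))
# ['a', 'quietly', 'moving', 'portrait']
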
@@ -114,7 +116,7 @@ class sst2Loader(DataSetLoader):
             if idx<=skip_row:
                 continue
             target = row[1]
-            words = row[0].split()
+            words=self.tokenizer(row[0])
             ds.append(Instance(words=words,target=target))
             all_count+=1
         print("all count:", all_count)
@@ -135,11 +137,13 @@ class sst2Loader(DataSetLoader):
             datasets[name] = dataset
 
         def wordtochar(words):
-            chars=[]
+            chars = []
             for word in words:
-                word=word.lower()
+                word = word.lower()
                 for char in word:
                     chars.append(char)
+                chars.append('')
+            chars.pop()
             return chars
 
         input_name, target_name = 'words', 'target'
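
# --- Illustrative check, not part of the diff ---
# Module-level copy of wordtochar as it reads after this hunk: each word is
# lower-cased and flattened into characters, '' is appended as a word separator,
# and the final pop() drops the trailing separator.
def wordtochar(words):
    chars = []
    for word in words:
        word = word.lower()
        for char in word:
            chars.append(char)
        chars.append('')
    chars.pop()
    return chars

print(wordtochar(['Movie', 'Rocks']))
# ['m', 'o', 'v', 'i', 'e', '', 'r', 'o', 'c', 'k', 's']
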