|
@@ -7,6 +7,7 @@ from fastNLP import Instance |
|
|
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader |
|
|
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader |
|
|
import csv |
|
|
import csv |
|
|
from typing import Union, Dict |
|
|
from typing import Union, Dict |
|
|
|
|
|
from reproduction.utils import check_dataloader_paths, get_tokenizer |
|
|
|
|
|
|
|
|
class SSTLoader(DataSetLoader): |
|
|
class SSTLoader(DataSetLoader): |
|
|
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' |
|
|
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' |
|
@@ -104,6 +105,7 @@ class sst2Loader(DataSetLoader): |
|
|
''' |
|
|
''' |
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
super(sst2Loader, self).__init__() |
|
|
super(sst2Loader, self).__init__() |
|
|
|
|
|
self.tokenizer = get_tokenizer() |
|
|
|
|
|
|
|
|
def _load(self, path: str) -> DataSet: |
|
|
def _load(self, path: str) -> DataSet: |
|
|
ds = DataSet() |
|
|
ds = DataSet() |
|
@@ -114,7 +116,7 @@ class sst2Loader(DataSetLoader): |
|
|
if idx<=skip_row: |
|
|
if idx<=skip_row: |
|
|
continue |
|
|
continue |
|
|
target = row[1] |
|
|
target = row[1] |
|
|
words = row[0].split() |
|
|
|
|
|
|
|
|
words=self.tokenizer(words) |
|
|
ds.append(Instance(words=words,target=target)) |
|
|
ds.append(Instance(words=words,target=target)) |
|
|
all_count+=1 |
|
|
all_count+=1 |
|
|
print("all count:", all_count) |
|
|
print("all count:", all_count) |
|
|