diff --git a/reproduction/text_classification/data/sstLoader.py b/reproduction/text_classification/data/sstLoader.py index d8403b7a..e1907d8f 100644 --- a/reproduction/text_classification/data/sstLoader.py +++ b/reproduction/text_classification/data/sstLoader.py @@ -7,6 +7,7 @@ from fastNLP import Instance from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader import csv from typing import Union, Dict +from reproduction.utils import check_dataloader_paths, get_tokenizer class SSTLoader(DataSetLoader): URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' @@ -104,6 +105,7 @@ class sst2Loader(DataSetLoader): ''' def __init__(self): super(sst2Loader, self).__init__() + self.tokenizer = get_tokenizer() def _load(self, path: str) -> DataSet: ds = DataSet() @@ -114,7 +116,7 @@ class sst2Loader(DataSetLoader): if idx<=skip_row: continue target = row[1] - words = row[0].split() + words=self.tokenizer(words) ds.append(Instance(words=words,target=target)) all_count+=1 print("all count:", all_count)