diff --git a/docs/source/figures/text_classification.png b/docs/source/figures/text_classification.png index 183aaba9..0d36a2a1 100644 Binary files a/docs/source/figures/text_classification.png and b/docs/source/figures/text_classification.png differ diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 76b9584d..fedf8058 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -417,7 +417,7 @@ class PeopleDailyCorpusLoader(DataSetLoader): data_set.set_input("seq_len") return data_set - + class Conll2003Loader(DataSetLoader): """Self-defined loader of conll2003 dataset @@ -425,14 +425,14 @@ class Conll2003Loader(DataSetLoader): https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data """ - + def __init__(self): super(Conll2003Loader, self).__init__() - + def load(self, dataset_path): with open(dataset_path, "r", encoding="utf-8") as f: lines = f.readlines() - + ##Parse the dataset line by line parsed_data = [] sentence = [] @@ -444,13 +444,13 @@ class Conll2003Loader(DataSetLoader): sentence = [] tokens = [] continue - + temp = line.strip().split(" ") - sentence.append(temp[0]) + sentence.append(temp[0]) tokens.append(temp[1:4]) - + return self.convert(parsed_data) - + def convert(self, parsed_data): dataset = DataSet() for sample in parsed_data: @@ -460,11 +460,11 @@ class Conll2003Loader(DataSetLoader): lambda labels: labels[1], sample[1])) label2_list = list(map( lambda labels: labels[2], sample[1])) - dataset.append(Instance(token_list=sample[0], - label0_list=label0_list, + dataset.append(Instance(token_list=sample[0], + label0_list=label0_list, label1_list=label1_list, label2_list=label2_list)) - + return dataset class SNLIDataSetLoader(DataSetLoader): diff --git a/test/io/test_dataset_loader.py b/test/io/test_dataset_loader.py index 9bee175b..cf38c973 100644 --- a/test/io/test_dataset_loader.py +++ b/test/io/test_dataset_loader.py @@ -1,9 +1,10 @@ -import os import unittest from fastNLP.io.dataset_loader import Conll2003Loader + + class TestDatasetLoader(unittest.TestCase): - + def test_case_1(self): ''' Test the the loader of Conll2003 dataset @@ -12,7 +13,7 @@ class TestDatasetLoader(unittest.TestCase): dataset_path = "test/data_for_tests/conll_2003_example.txt" loader = Conll2003Loader() dataset_2003 = loader.load(dataset_path) - + for item in dataset_2003: len0 = len(item["label0_list"]) len1 = len(item["label1_list"]) @@ -20,4 +21,4 @@ class TestDatasetLoader(unittest.TestCase): lentoken = len(item["token_list"]) self.assertNotEqual(len0, 0) self.assertEqual(len0, len1) - self.assertEqual(len1, len2) \ No newline at end of file + self.assertEqual(len1, len2)