Browse Source

* fix README figure

* refine code style
tags/v0.3.0
FengZiYjun 6 years ago
parent
commit
5d8f6960a7
3 changed files with 16 additions and 15 deletions
  1. BIN
      docs/source/figures/text_classification.png
  2. +11
    -11
      fastNLP/io/dataset_loader.py
  3. +5
    -4
      test/io/test_dataset_loader.py

BIN
docs/source/figures/text_classification.png View File

Before After
Width: 1699  |  Height: 747  |  Size: 73 kB Width: 1699  |  Height: 722  |  Size: 74 kB

+ 11
- 11
fastNLP/io/dataset_loader.py View File

@@ -417,7 +417,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
data_set.set_input("seq_len")
return data_set

class Conll2003Loader(DataSetLoader):
"""Self-defined loader of conll2003 dataset
@@ -425,14 +425,14 @@ class Conll2003Loader(DataSetLoader):
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
"""
def __init__(self):
super(Conll2003Loader, self).__init__()
def load(self, dataset_path):
with open(dataset_path, "r", encoding="utf-8") as f:
lines = f.readlines()
##Parse the dataset line by line
parsed_data = []
sentence = []
@@ -444,13 +444,13 @@ class Conll2003Loader(DataSetLoader):
sentence = []
tokens = []
continue
temp = line.strip().split(" ")
sentence.append(temp[0])
sentence.append(temp[0])
tokens.append(temp[1:4])
return self.convert(parsed_data)
def convert(self, parsed_data):
dataset = DataSet()
for sample in parsed_data:
@@ -460,11 +460,11 @@ class Conll2003Loader(DataSetLoader):
lambda labels: labels[1], sample[1]))
label2_list = list(map(
lambda labels: labels[2], sample[1]))
dataset.append(Instance(token_list=sample[0],
label0_list=label0_list,
dataset.append(Instance(token_list=sample[0],
label0_list=label0_list,
label1_list=label1_list,
label2_list=label2_list))
return dataset

class SNLIDataSetLoader(DataSetLoader):


+ 5
- 4
test/io/test_dataset_loader.py View File

@@ -1,9 +1,10 @@
import os
import unittest

from fastNLP.io.dataset_loader import Conll2003Loader


class TestDatasetLoader(unittest.TestCase):
def test_case_1(self):
'''
Test the the loader of Conll2003 dataset
@@ -12,7 +13,7 @@ class TestDatasetLoader(unittest.TestCase):
dataset_path = "test/data_for_tests/conll_2003_example.txt"
loader = Conll2003Loader()
dataset_2003 = loader.load(dataset_path)
for item in dataset_2003:
len0 = len(item["label0_list"])
len1 = len(item["label1_list"])
@@ -20,4 +21,4 @@ class TestDatasetLoader(unittest.TestCase):
lentoken = len(item["token_list"])
self.assertNotEqual(len0, 0)
self.assertEqual(len0, len1)
self.assertEqual(len1, len2)
self.assertEqual(len1, len2)

Loading…
Cancel
Save