Browse Source

* final clean up

* remove conflicts
* all tests passed
tags/v0.2.0^2
FengZiYjun 5 years ago
parent
commit
db0a789d61
6 changed files with 38 additions and 26 deletions
  1. +2
    -2
      fastNLP/core/dataset.py
  2. +1
    -1
      fastNLP/core/trainer.py
  3. +0
    -16
      fastNLP/io/base_loader.py
  4. +7
    -3
      fastNLP/io/dataset_loader.py
  5. +14
    -0
      test/core/test_dataset.py
  6. +14
    -4
      test/core/test_optimizer.py

+ 2
- 2
fastNLP/core/dataset.py View File

@@ -98,10 +98,10 @@ class DataSet(object):
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))


def __getattr__(self, item): def __getattr__(self, item):
# Not tested. Don't use !!
if item == "field_arrays": if item == "field_arrays":
raise AttributeError raise AttributeError
# TODO dataset.x
if item in self.field_arrays:
if isinstance(item, str) and item in self.field_arrays:
return self.field_arrays[item] return self.field_arrays[item]
try: try:
reader = DataLoaderRegister.get_reader(item) reader = DataLoaderRegister.get_reader(item)


+ 1
- 1
fastNLP/core/trainer.py View File

@@ -85,7 +85,7 @@ class Trainer(object):
if metric_key is not None: if metric_key is not None:
self.increase_better = False if metric_key[0] == "-" else True self.increase_better = False if metric_key[0] == "-" else True
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
elif metrics is not None:
elif len(metrics) > 0:
self.metric_key = metrics[0].__class__.__name__.lower().strip('metric') self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')


# prepare loss # prepare loss


+ 0
- 16
fastNLP/io/base_loader.py View File

@@ -31,22 +31,6 @@ class BaseLoader(object):
return obj return obj




class ToyLoader0(BaseLoader):
"""
For CharLM
"""

def __init__(self, data_path):
super(ToyLoader0, self).__init__(data_path)

def load(self):
with open(self.data_path, 'r') as f:
corpus = f.read().lower()
import re
corpus = re.sub(r"<unk>", "unk", corpus)
return corpus.split()


class DataLoaderRegister: class DataLoaderRegister:
""""register for data sets""" """"register for data sets"""
_readers = {} _readers = {}


+ 7
- 3
fastNLP/io/dataset_loader.py View File

@@ -75,7 +75,6 @@ class DataSetLoader:
raise NotImplementedError raise NotImplementedError




@DataSet.set_reader("read_naive")
class NativeDataSetLoader(DataSetLoader): class NativeDataSetLoader(DataSetLoader):
def __init__(self): def __init__(self):
super(NativeDataSetLoader, self).__init__() super(NativeDataSetLoader, self).__init__()
@@ -87,7 +86,9 @@ class NativeDataSetLoader(DataSetLoader):
return ds return ds




@DataSet.set_reader('read_raw')
DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive')


class RawDataSetLoader(DataSetLoader): class RawDataSetLoader(DataSetLoader):
def __init__(self): def __init__(self):
super(RawDataSetLoader, self).__init__() super(RawDataSetLoader, self).__init__()
@@ -101,6 +102,8 @@ class RawDataSetLoader(DataSetLoader):


def convert(self, data): def convert(self, data):
return convert_seq_dataset(data) return convert_seq_dataset(data)


DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')




@@ -171,6 +174,8 @@ class POSDataSetLoader(DataSetLoader):
"""Convert lists of strings into Instances with Fields. """Convert lists of strings into Instances with Fields.
""" """
return convert_seq2seq_dataset(data) return convert_seq2seq_dataset(data)


DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos') DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')




@@ -348,7 +353,6 @@ class LMDataSetLoader(DataSetLoader):
pass pass




@DataSet.set_reader('read_people_daily')
class PeopleDailyCorpusLoader(DataSetLoader): class PeopleDailyCorpusLoader(DataSetLoader):
""" """
People Daily Corpus: Chinese word segmentation, POS tag, NER People Daily Corpus: Chinese word segmentation, POS tag, NER


+ 14
- 0
test/core/test_dataset.py View File

@@ -178,6 +178,20 @@ class TestDataSet(unittest.TestCase):
self.assertTrue(isinstance(ans, FieldArray)) self.assertTrue(isinstance(ans, FieldArray))
self.assertEqual(ans.content, [[5, 6]] * 10) self.assertEqual(ans.content, [[5, 6]] * 10)


def test_reader(self):
# 跑通即可
ds = DataSet().read_naive("test/data_for_tests/tutorial_sample_dataset.csv")
self.assertTrue(isinstance(ds, DataSet))
self.assertTrue(len(ds) > 0)

ds = DataSet().read_rawdata("test/data_for_tests/people_daily_raw.txt")
self.assertTrue(isinstance(ds, DataSet))
self.assertTrue(len(ds) > 0)

ds = DataSet().read_pos("test/data_for_tests/people.txt")
self.assertTrue(isinstance(ds, DataSet))
self.assertTrue(len(ds) > 0)



class TestDataSetIter(unittest.TestCase): class TestDataSetIter(unittest.TestCase):
def test__repr__(self): def test__repr__(self):


+ 14
- 4
test/core/test_optimizer.py View File

@@ -7,7 +7,7 @@ from fastNLP.core.optimizer import SGD, Adam


class TestOptim(unittest.TestCase): class TestOptim(unittest.TestCase):
def test_SGD(self): def test_SGD(self):
optim = SGD(torch.nn.Linear(10, 3).parameters())
optim = SGD(model_params=torch.nn.Linear(10, 3).parameters())
self.assertTrue("lr" in optim.__dict__["settings"]) self.assertTrue("lr" in optim.__dict__["settings"])
self.assertTrue("momentum" in optim.__dict__["settings"]) self.assertTrue("momentum" in optim.__dict__["settings"])
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
@@ -22,13 +22,18 @@ class TestOptim(unittest.TestCase):
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989) self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989)


with self.assertRaises(RuntimeError):
optim = SGD(0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
self.assertTrue(isinstance(res, torch.optim.SGD))

with self.assertRaises(TypeError):
_ = SGD("???") _ = SGD("???")
with self.assertRaises(RuntimeError):
with self.assertRaises(TypeError):
_ = SGD(0.001, lr=0.002) _ = SGD(0.001, lr=0.002)


def test_Adam(self): def test_Adam(self):
optim = Adam(torch.nn.Linear(10, 3).parameters())
optim = Adam(model_params=torch.nn.Linear(10, 3).parameters())
self.assertTrue("lr" in optim.__dict__["settings"]) self.assertTrue("lr" in optim.__dict__["settings"])
self.assertTrue("weight_decay" in optim.__dict__["settings"]) self.assertTrue("weight_decay" in optim.__dict__["settings"])
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
@@ -42,3 +47,8 @@ class TestOptim(unittest.TestCase):
optim = Adam(lr=0.002, weight_decay=0.989) optim = Adam(lr=0.002, weight_decay=0.989)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989) self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989)

optim = Adam(0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
self.assertTrue(isinstance(res, torch.optim.Adam))

Loading…
Cancel
Save