@@ -98,10 +98,10 @@ class DataSet(object): | |||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | ||||
def __getattr__(self, item): | def __getattr__(self, item): | ||||
# Not tested. Don't use !! | |||||
if item == "field_arrays": | if item == "field_arrays": | ||||
raise AttributeError | raise AttributeError | ||||
# TODO dataset.x | |||||
if item in self.field_arrays: | |||||
if isinstance(item, str) and item in self.field_arrays: | |||||
return self.field_arrays[item] | return self.field_arrays[item] | ||||
try: | try: | ||||
reader = DataLoaderRegister.get_reader(item) | reader = DataLoaderRegister.get_reader(item) | ||||
@@ -85,7 +85,7 @@ class Trainer(object): | |||||
if metric_key is not None: | if metric_key is not None: | ||||
self.increase_better = False if metric_key[0] == "-" else True | self.increase_better = False if metric_key[0] == "-" else True | ||||
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | ||||
elif metrics is not None: | |||||
elif len(metrics) > 0: | |||||
self.metric_key = metrics[0].__class__.__name__.lower().strip('metric') | self.metric_key = metrics[0].__class__.__name__.lower().strip('metric') | ||||
# prepare loss | # prepare loss | ||||
@@ -31,22 +31,6 @@ class BaseLoader(object): | |||||
return obj | return obj | ||||
class ToyLoader0(BaseLoader): | |||||
""" | |||||
For CharLM | |||||
""" | |||||
def __init__(self, data_path): | |||||
super(ToyLoader0, self).__init__(data_path) | |||||
def load(self): | |||||
with open(self.data_path, 'r') as f: | |||||
corpus = f.read().lower() | |||||
import re | |||||
corpus = re.sub(r"<unk>", "unk", corpus) | |||||
return corpus.split() | |||||
class DataLoaderRegister: | class DataLoaderRegister: | ||||
""""register for data sets""" | """"register for data sets""" | ||||
_readers = {} | _readers = {} | ||||
@@ -75,7 +75,6 @@ class DataSetLoader: | |||||
raise NotImplementedError | raise NotImplementedError | ||||
@DataSet.set_reader("read_naive") | |||||
class NativeDataSetLoader(DataSetLoader): | class NativeDataSetLoader(DataSetLoader): | ||||
def __init__(self): | def __init__(self): | ||||
super(NativeDataSetLoader, self).__init__() | super(NativeDataSetLoader, self).__init__() | ||||
@@ -87,7 +86,9 @@ class NativeDataSetLoader(DataSetLoader): | |||||
return ds | return ds | ||||
@DataSet.set_reader('read_raw') | |||||
DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive') | |||||
class RawDataSetLoader(DataSetLoader): | class RawDataSetLoader(DataSetLoader): | ||||
def __init__(self): | def __init__(self): | ||||
super(RawDataSetLoader, self).__init__() | super(RawDataSetLoader, self).__init__() | ||||
@@ -101,6 +102,8 @@ class RawDataSetLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
return convert_seq_dataset(data) | return convert_seq_dataset(data) | ||||
DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') | DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') | ||||
@@ -171,6 +174,8 @@ class POSDataSetLoader(DataSetLoader): | |||||
"""Convert lists of strings into Instances with Fields. | """Convert lists of strings into Instances with Fields. | ||||
""" | """ | ||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos') | DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos') | ||||
@@ -348,7 +353,6 @@ class LMDataSetLoader(DataSetLoader): | |||||
pass | pass | ||||
@DataSet.set_reader('read_people_daily') | |||||
class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
""" | """ | ||||
People Daily Corpus: Chinese word segmentation, POS tag, NER | People Daily Corpus: Chinese word segmentation, POS tag, NER | ||||
@@ -178,6 +178,20 @@ class TestDataSet(unittest.TestCase): | |||||
self.assertTrue(isinstance(ans, FieldArray)) | self.assertTrue(isinstance(ans, FieldArray)) | ||||
self.assertEqual(ans.content, [[5, 6]] * 10) | self.assertEqual(ans.content, [[5, 6]] * 10) | ||||
def test_reader(self): | |||||
# 跑通即可 | |||||
ds = DataSet().read_naive("test/data_for_tests/tutorial_sample_dataset.csv") | |||||
self.assertTrue(isinstance(ds, DataSet)) | |||||
self.assertTrue(len(ds) > 0) | |||||
ds = DataSet().read_rawdata("test/data_for_tests/people_daily_raw.txt") | |||||
self.assertTrue(isinstance(ds, DataSet)) | |||||
self.assertTrue(len(ds) > 0) | |||||
ds = DataSet().read_pos("test/data_for_tests/people.txt") | |||||
self.assertTrue(isinstance(ds, DataSet)) | |||||
self.assertTrue(len(ds) > 0) | |||||
class TestDataSetIter(unittest.TestCase): | class TestDataSetIter(unittest.TestCase): | ||||
def test__repr__(self): | def test__repr__(self): | ||||
@@ -7,7 +7,7 @@ from fastNLP.core.optimizer import SGD, Adam | |||||
class TestOptim(unittest.TestCase): | class TestOptim(unittest.TestCase): | ||||
def test_SGD(self): | def test_SGD(self): | ||||
optim = SGD(torch.nn.Linear(10, 3).parameters()) | |||||
optim = SGD(model_params=torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue("lr" in optim.__dict__["settings"]) | self.assertTrue("lr" in optim.__dict__["settings"]) | ||||
self.assertTrue("momentum" in optim.__dict__["settings"]) | self.assertTrue("momentum" in optim.__dict__["settings"]) | ||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | ||||
@@ -22,13 +22,18 @@ class TestOptim(unittest.TestCase): | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | ||||
self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989) | self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989) | ||||
with self.assertRaises(RuntimeError): | |||||
optim = SGD(0.001) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.SGD)) | |||||
with self.assertRaises(TypeError): | |||||
_ = SGD("???") | _ = SGD("???") | ||||
with self.assertRaises(RuntimeError): | |||||
with self.assertRaises(TypeError): | |||||
_ = SGD(0.001, lr=0.002) | _ = SGD(0.001, lr=0.002) | ||||
def test_Adam(self): | def test_Adam(self): | ||||
optim = Adam(torch.nn.Linear(10, 3).parameters()) | |||||
optim = Adam(model_params=torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue("lr" in optim.__dict__["settings"]) | self.assertTrue("lr" in optim.__dict__["settings"]) | ||||
self.assertTrue("weight_decay" in optim.__dict__["settings"]) | self.assertTrue("weight_decay" in optim.__dict__["settings"]) | ||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | ||||
@@ -42,3 +47,8 @@ class TestOptim(unittest.TestCase): | |||||
optim = Adam(lr=0.002, weight_decay=0.989) | optim = Adam(lr=0.002, weight_decay=0.989) | ||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | ||||
self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989) | self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989) | ||||
optim = Adam(0.001) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||||
self.assertTrue(isinstance(res, torch.optim.Adam)) |