* DataSet修改__repr__,优化print(datset)的输出 * Instance修改__repr__,优化print的输出 * Optimizer优化传参提示 * Trainer去除kwargs参数 * losses.py加个参数 * 对应test code的修改tags/v0.2.0^2
@@ -110,6 +110,15 @@ class DataSet(object): | |||||
field = iter(self.field_arrays.values()).__next__() | field = iter(self.field_arrays.values()).__next__() | ||||
return len(field) | return len(field) | ||||
def __inner_repr__(self): | |||||
if len(self) < 20: | |||||
return ",\n".join([ins.__repr__() for ins in self]) | |||||
else: | |||||
return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__() | |||||
def __repr__(self): | |||||
return "DataSet(" + self.__inner_repr__() + ")" | |||||
def append(self, ins): | def append(self, ins): | ||||
"""Add an instance to the DataSet. | """Add an instance to the DataSet. | ||||
If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet. | If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet. | ||||
@@ -1,5 +1,3 @@ | |||||
class Instance(object): | class Instance(object): | ||||
"""An Instance is an example of data. It is the collection of Fields. | """An Instance is an example of data. It is the collection of Fields. | ||||
@@ -33,4 +31,5 @@ class Instance(object): | |||||
return self.add_field(name, field) | return self.add_field(name, field) | ||||
def __repr__(self): | def __repr__(self): | ||||
return self.fields.__repr__() | |||||
return "{" + ",\n".join( | |||||
"\'" + field_name + "\': " + str(self.fields[field_name]) for field_name in self.fields) + "}" |
@@ -202,6 +202,7 @@ class LossInForward(LossBase): | |||||
all_needed=[], | all_needed=[], | ||||
varargs=[]) | varargs=[]) | ||||
raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) | raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) | ||||
return kwargs[self.loss_key] | |||||
def __call__(self, pred_dict, target_dict, check=False): | def __call__(self, pred_dict, target_dict, check=False): | ||||
@@ -10,34 +10,7 @@ class Optimizer(object): | |||||
class SGD(Optimizer): | class SGD(Optimizer): | ||||
def __init__(self, *args, **kwargs): | |||||
model_params, lr, momentum = None, 0.01, 0.9 | |||||
if len(args) == 0 and len(kwargs) == 0: | |||||
# SGD() | |||||
pass | |||||
elif len(args) == 1 and len(kwargs) == 0: | |||||
if isinstance(args[0], float) or isinstance(args[0], int): | |||||
# SGD(0.001) | |||||
lr = args[0] | |||||
elif hasattr(args[0], "__next__"): | |||||
# SGD(model.parameters()) args[0] is a generator | |||||
model_params = args[0] | |||||
else: | |||||
raise RuntimeError("Not supported type {}.".format(type(args[0]))) | |||||
elif 2 >= len(kwargs) > 0 and len(args) <= 1: | |||||
# SGD(lr=0.01), SGD(lr=0.01, momentum=0.9), SGD(model.parameters(), lr=0.1, momentum=0.9) | |||||
if len(args) == 1: | |||||
if hasattr(args[0], "__next__"): | |||||
model_params = args[0] | |||||
else: | |||||
raise RuntimeError("Not supported type {}.".format(type(args[0]))) | |||||
if not all(key in ("lr", "momentum") for key in kwargs): | |||||
raise RuntimeError("Invalid SGD arguments. Expect {}, got {}.".format(("lr", "momentum"), kwargs)) | |||||
lr = kwargs.get("lr", 0.01) | |||||
momentum = kwargs.get("momentum", 0.9) | |||||
else: | |||||
raise RuntimeError("SGD only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args)) | |||||
def __init__(self, model_params=None, lr=0.01, momentum=0): | |||||
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | ||||
def construct_from_pytorch(self, model_params): | def construct_from_pytorch(self, model_params): | ||||
@@ -49,30 +22,7 @@ class SGD(Optimizer): | |||||
class Adam(Optimizer): | class Adam(Optimizer): | ||||
def __init__(self, *args, **kwargs): | |||||
model_params, lr, weight_decay = None, 0.01, 0.9 | |||||
if len(args) == 0 and len(kwargs) == 0: | |||||
pass | |||||
elif len(args) == 1 and len(kwargs) == 0: | |||||
if isinstance(args[0], float) or isinstance(args[0], int): | |||||
lr = args[0] | |||||
elif hasattr(args[0], "__next__"): | |||||
model_params = args[0] | |||||
else: | |||||
raise RuntimeError("Not supported type {}.".format(type(args[0]))) | |||||
elif 2 >= len(kwargs) > 0 and len(args) <= 1: | |||||
if len(args) == 1: | |||||
if hasattr(args[0], "__next__"): | |||||
model_params = args[0] | |||||
else: | |||||
raise RuntimeError("Not supported type {}.".format(type(args[0]))) | |||||
if not all(key in ("lr", "weight_decay") for key in kwargs): | |||||
raise RuntimeError("Invalid Adam arguments. Expect {}, got {}.".format(("lr", "weight_decay"), kwargs)) | |||||
lr = kwargs.get("lr", 0.01) | |||||
weight_decay = kwargs.get("weight_decay", 0.9) | |||||
else: | |||||
raise RuntimeError("Adam only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args)) | |||||
def __init__(self, model_params=None, lr=0.01, weight_decay=0): | |||||
super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) | super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) | ||||
def construct_from_pytorch(self, model_params): | def construct_from_pytorch(self, model_params): | ||||
@@ -32,8 +32,7 @@ class Trainer(object): | |||||
validate_every=-1, | validate_every=-1, | ||||
dev_data=None, use_cuda=False, save_path=None, | dev_data=None, use_cuda=False, save_path=None, | ||||
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, | optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, | ||||
metric_key=None, | |||||
**kwargs): | |||||
metric_key=None): | |||||
""" | """ | ||||
:param DataSet train_data: the training data | :param DataSet train_data: the training data | ||||
@@ -44,6 +44,9 @@ class TestDataSet(unittest.TestCase): | |||||
self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) | self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) | ||||
self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) | self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) | ||||
with self.assertRaises(RuntimeError): | |||||
dd.add_field("??", [[1, 2]] * 40) | |||||
def test_delete_field(self): | def test_delete_field(self): | ||||
dd = DataSet() | dd = DataSet() | ||||
dd.add_field("x", [[1, 2, 3]] * 10) | dd.add_field("x", [[1, 2, 3]] * 10) | ||||
@@ -65,8 +68,66 @@ class TestDataSet(unittest.TestCase): | |||||
self.assertTrue(isinstance(sub_ds, DataSet)) | self.assertTrue(isinstance(sub_ds, DataSet)) | ||||
self.assertEqual(len(sub_ds), 10) | self.assertEqual(len(sub_ds), 10) | ||||
def test_get_item_error(self): | |||||
with self.assertRaises(RuntimeError): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
_ = ds[40:] | |||||
with self.assertRaises(KeyError): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
_ = ds["kom"] | |||||
def test_len_(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
self.assertEqual(len(ds), 40) | |||||
ds = DataSet() | |||||
self.assertEqual(len(ds), 0) | |||||
def test_apply(self): | def test_apply(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ||||
ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx") | ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx") | ||||
self.assertTrue("rx" in ds.field_arrays) | self.assertTrue("rx" in ds.field_arrays) | ||||
self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) | self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) | ||||
def test_contains(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
self.assertTrue("x" in ds) | |||||
self.assertTrue("y" in ds) | |||||
self.assertFalse("z" in ds) | |||||
def test_rename_field(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
ds.rename_field("x", "xx") | |||||
self.assertTrue("xx" in ds) | |||||
self.assertFalse("x" in ds) | |||||
with self.assertRaises(KeyError): | |||||
ds.rename_field("yyy", "oo") | |||||
def test_input_target(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
self.assertTrue(ds.field_arrays["x"].is_input) | |||||
self.assertTrue(ds.field_arrays["y"].is_target) | |||||
with self.assertRaises(KeyError): | |||||
ds.set_input("xxx") | |||||
with self.assertRaises(KeyError): | |||||
ds.set_input("yyy") | |||||
def test_get_input_name(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
self.assertEqual(ds.get_input_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_input]) | |||||
def test_get_target_name(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target]) | |||||
class TestDataSetIter(unittest.TestCase): | |||||
def test__repr__(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||||
for iter in ds: | |||||
self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4], 'y': [5, 6]}") |
@@ -27,3 +27,9 @@ class TestCase(unittest.TestCase): | |||||
self.assertEqual(ins["x"], [1, 2, 3]) | self.assertEqual(ins["x"], [1, 2, 3]) | ||||
self.assertEqual(ins["y"], [4, 5, 6]) | self.assertEqual(ins["y"], [4, 5, 6]) | ||||
self.assertEqual(ins["z"], [1, 1, 1]) | self.assertEqual(ins["z"], [1, 1, 1]) | ||||
def test_repr(self): | |||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||||
ins = Instance(**fields) | |||||
# simple print, that is enough. | |||||
print(ins) |
@@ -11,9 +11,6 @@ class TestOptim(unittest.TestCase): | |||||
self.assertTrue("lr" in optim.__dict__["settings"]) | self.assertTrue("lr" in optim.__dict__["settings"]) | ||||
self.assertTrue("momentum" in optim.__dict__["settings"]) | self.assertTrue("momentum" in optim.__dict__["settings"]) | ||||
optim = SGD(0.001) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||||
optim = SGD(lr=0.001) | optim = SGD(lr=0.001) | ||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | ||||
@@ -25,17 +22,12 @@ class TestOptim(unittest.TestCase): | |||||
_ = SGD("???") | _ = SGD("???") | ||||
with self.assertRaises(RuntimeError): | with self.assertRaises(RuntimeError): | ||||
_ = SGD(0.001, lr=0.002) | _ = SGD(0.001, lr=0.002) | ||||
with self.assertRaises(RuntimeError): | |||||
_ = SGD(lr=0.009, shit=9000) | |||||
def test_Adam(self): | def test_Adam(self): | ||||
optim = Adam(torch.nn.Linear(10, 3).parameters()) | optim = Adam(torch.nn.Linear(10, 3).parameters()) | ||||
self.assertTrue("lr" in optim.__dict__["settings"]) | self.assertTrue("lr" in optim.__dict__["settings"]) | ||||
self.assertTrue("weight_decay" in optim.__dict__["settings"]) | self.assertTrue("weight_decay" in optim.__dict__["settings"]) | ||||
optim = Adam(0.001) | |||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||||
optim = Adam(lr=0.001) | optim = Adam(lr=0.001) | ||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | ||||