* DataSet修改__repr__,优化print(datset)的输出 * Instance修改__repr__,优化print的输出 * Optimizer优化传参提示 * Trainer去除kwargs参数 * losses.py加个参数 * 对应test code的修改tags/v0.2.0^2
@@ -110,6 +110,15 @@ class DataSet(object): | |||
field = iter(self.field_arrays.values()).__next__() | |||
return len(field) | |||
def __inner_repr__(self):
    """Render the instances of this DataSet as a comma/newline-joined string.

    Small datasets (fewer than 20 instances) are rendered in full; larger
    ones show the first five and last five instances around an ellipsis
    line, by recursing on the head and tail slices.
    """
    if len(self) >= 20:
        head = self[:5].__inner_repr__()
        tail = self[-5:].__inner_repr__()
        return head + "\n...\n" + tail
    return ",\n".join(repr(instance) for instance in self)
def __repr__(self):
    """Wrap the rendered instances as ``DataSet(...)`` for print()."""
    return "DataSet({})".format(self.__inner_repr__())
def append(self, ins): | |||
"""Add an instance to the DataSet. | |||
If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet. | |||
@@ -1,5 +1,3 @@ | |||
class Instance(object): | |||
"""An Instance is an example of data. It is the collection of Fields. | |||
@@ -33,4 +31,5 @@ class Instance(object): | |||
return self.add_field(name, field) | |||
def __repr__(self):
    """Render the Instance as ``{'field': value, ...}`` with one field per line.

    Fix: diff residue left the removed ``return self.fields.__repr__()``
    in front of the formatted return, making the intended output
    unreachable; the stale line is dropped so the new formatting is used.
    """
    return "{" + ",\n".join(
        "\'" + field_name + "\': " + str(self.fields[field_name])
        for field_name in self.fields) + "}"
@@ -202,6 +202,7 @@ class LossInForward(LossBase): | |||
all_needed=[], | |||
varargs=[]) | |||
raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) | |||
return kwargs[self.loss_key] | |||
def __call__(self, pred_dict, target_dict, check=False): | |||
@@ -10,34 +10,7 @@ class Optimizer(object): | |||
class SGD(Optimizer): | |||
def __init__(self, model_params=None, lr=0.01, momentum=0):
    """Construct an SGD optimizer wrapper.

    Fix: the block contained two ``__init__`` definitions — the old
    ``*args``/``**kwargs`` argument parser followed by the new explicit
    keyword signature. The second definition silently overrode the
    first, so the parser was dead code; only the keyword version the
    commit introduces is kept.

    :param model_params: iterable of model parameters (e.g. the result
        of ``model.parameters()``); may be ``None`` and supplied later
        via ``construct_from_pytorch``.
    :param float lr: learning rate. Default: 0.01.
    :param float momentum: momentum factor. Default: 0.
    """
    super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)
def construct_from_pytorch(self, model_params): | |||
@@ -49,30 +22,7 @@ class SGD(Optimizer): | |||
class Adam(Optimizer): | |||
def __init__(self, model_params=None, lr=0.01, weight_decay=0):
    """Construct an Adam optimizer wrapper.

    Fix: the block contained two ``__init__`` definitions — the old
    ``*args``/``**kwargs`` argument parser followed by the new explicit
    keyword signature. The second definition silently overrode the
    first, so the parser was dead code; only the keyword version the
    commit introduces is kept (note the new ``weight_decay`` default of
    0, replacing the old 0.9).

    :param model_params: iterable of model parameters (e.g. the result
        of ``model.parameters()``); may be ``None`` and supplied later
        via ``construct_from_pytorch``.
    :param float lr: learning rate. Default: 0.01.
    :param float weight_decay: L2 penalty coefficient. Default: 0.
    """
    super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)
def construct_from_pytorch(self, model_params): | |||
@@ -32,8 +32,7 @@ class Trainer(object): | |||
validate_every=-1, | |||
dev_data=None, use_cuda=False, save_path=None, | |||
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, | |||
metric_key=None, | |||
**kwargs): | |||
metric_key=None): | |||
""" | |||
:param DataSet train_data: the training data | |||
@@ -44,6 +44,9 @@ class TestDataSet(unittest.TestCase): | |||
self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) | |||
self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) | |||
with self.assertRaises(RuntimeError): | |||
dd.add_field("??", [[1, 2]] * 40) | |||
def test_delete_field(self): | |||
dd = DataSet() | |||
dd.add_field("x", [[1, 2, 3]] * 10) | |||
@@ -65,8 +68,66 @@ class TestDataSet(unittest.TestCase): | |||
self.assertTrue(isinstance(sub_ds, DataSet)) | |||
self.assertEqual(len(sub_ds), 10) | |||
def test_get_item_error(self):
    """Out-of-range slices raise RuntimeError; unknown field names raise KeyError."""
    with self.assertRaises(RuntimeError):
        dataset = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
        _ = dataset[40:]
    with self.assertRaises(KeyError):
        dataset = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
        _ = dataset["kom"]
def test_len_(self):
    """len() reports the instance count; an empty DataSet has length 0."""
    populated = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
    self.assertEqual(len(populated), 40)
    empty = DataSet()
    self.assertEqual(len(empty), 0)
def test_apply(self):
    """apply() with new_field_name derives a new field from each instance."""
    dataset = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
    dataset.apply(lambda ins: ins["x"][::-1], new_field_name="rx")
    self.assertIn("rx", dataset.field_arrays)
    self.assertEqual(dataset.field_arrays["rx"].content[0], [4, 3, 2, 1])
def test_contains(self):
    """The ``in`` operator reflects field membership of the DataSet."""
    dataset = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
    for present in ("x", "y"):
        self.assertTrue(present in dataset)
    self.assertFalse("z" in dataset)
def test_rename_field(self):
    """rename_field replaces a name; renaming a missing field raises KeyError."""
    dataset = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
    dataset.rename_field("x", "xx")
    self.assertTrue("xx" in dataset)
    self.assertFalse("x" in dataset)
    with self.assertRaises(KeyError):
        dataset.rename_field("yyy", "oo")
def test_input_target(self):
    """set_input/set_target flag fields; unknown names raise KeyError.

    Fix: the second ``assertRaises(KeyError)`` previously called
    ``set_input("yyy")`` a second time, leaving ``set_target`` error
    handling untested; it now calls ``set_target`` as the symmetric
    check clearly intended.
    """
    ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
    ds.set_input("x")
    ds.set_target("y")
    self.assertTrue(ds.field_arrays["x"].is_input)
    self.assertTrue(ds.field_arrays["y"].is_target)
    with self.assertRaises(KeyError):
        ds.set_input("xxx")
    with self.assertRaises(KeyError):
        ds.set_target("yyy")
def test_get_input_name(self):
    """get_input_name() lists exactly the fields flagged as input."""
    dataset = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
    expected = [name for name in dataset.field_arrays
                if dataset.field_arrays[name].is_input]
    self.assertEqual(dataset.get_input_name(), expected)
def test_get_target_name(self):
    """get_target_name() lists exactly the fields flagged as target."""
    dataset = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
    expected = [name for name in dataset.field_arrays
                if dataset.field_arrays[name].is_target]
    self.assertEqual(dataset.get_target_name(), expected)
class TestDataSetIter(unittest.TestCase):
    """Tests for iterating a DataSet and the repr of the yielded items."""

    def test__repr__(self):
        """Each item yielded by iterating a DataSet reprs its fields.

        Fix: the loop variable was named ``iter``, shadowing the
        builtin ``iter()``; renamed to ``ins``.
        """
        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
        for ins in ds:
            self.assertEqual(ins.__repr__(), "{'x': [1, 2, 3, 4], 'y': [5, 6]}")
@@ -27,3 +27,9 @@ class TestCase(unittest.TestCase): | |||
self.assertEqual(ins["x"], [1, 2, 3]) | |||
self.assertEqual(ins["y"], [4, 5, 6]) | |||
self.assertEqual(ins["z"], [1, 1, 1]) | |||
def test_repr(self):
    """Smoke-test Instance.__repr__ by printing; a simple print is enough."""
    field_values = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
    instance = Instance(**field_values)
    print(instance)
@@ -11,9 +11,6 @@ class TestOptim(unittest.TestCase): | |||
self.assertTrue("lr" in optim.__dict__["settings"]) | |||
self.assertTrue("momentum" in optim.__dict__["settings"]) | |||
optim = SGD(0.001) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
optim = SGD(lr=0.001) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
@@ -25,17 +22,12 @@ class TestOptim(unittest.TestCase): | |||
_ = SGD("???") | |||
with self.assertRaises(RuntimeError): | |||
_ = SGD(0.001, lr=0.002) | |||
with self.assertRaises(RuntimeError): | |||
_ = SGD(lr=0.009, shit=9000) | |||
def test_Adam(self): | |||
optim = Adam(torch.nn.Linear(10, 3).parameters()) | |||
self.assertTrue("lr" in optim.__dict__["settings"]) | |||
self.assertTrue("weight_decay" in optim.__dict__["settings"]) | |||
optim = Adam(0.001) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
optim = Adam(lr=0.001) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||