From 6f58ec34b4357e5df3c7cb467b9906a823a8ca26 Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Mon, 3 Dec 2018 19:53:34 +0800
Subject: [PATCH] Updates:
 * DataSet: revise __repr__ to improve the output of print(dataset)
 * Instance: revise __repr__ to improve the output of print(instance)
 * Optimizer: replace the hand-rolled argument parsing with explicit
   keyword parameters for clearer error reporting
 * Trainer: remove the unused **kwargs parameter
 * losses.py: add the missing return in LossInForward.get_loss
 * update the corresponding test code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/dataset.py     |  9 ++++++
 fastNLP/core/instance.py    |  5 ++-
 fastNLP/core/losses.py      |  1 +
 fastNLP/core/optimizer.py   | 54 ++------------------------------
 fastNLP/core/trainer.py     |  3 +-
 test/core/test_dataset.py   | 61 +++++++++++++++++++++++++++++++++++++
 test/core/test_instance.py  |  6 ++++
 test/core/test_optimizer.py |  8 -----
 8 files changed, 82 insertions(+), 65 deletions(-)

diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 749d3e74..40ea0aab 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -110,6 +110,15 @@ class DataSet(object):
         field = iter(self.field_arrays.values()).__next__()
         return len(field)
 
+    def __inner_repr__(self):
+        if len(self) < 20:
+            return ",\n".join([ins.__repr__() for ins in self])
+        else:
+            return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__()
+
+    def __repr__(self):
+        return "DataSet(" + self.__inner_repr__() + ")"
+
     def append(self, ins):
         """Add an instance to the DataSet.
         If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet.
diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py
index 9dfe8fb8..dc65fa82 100644
--- a/fastNLP/core/instance.py
+++ b/fastNLP/core/instance.py
@@ -1,5 +1,3 @@
-
-
 class Instance(object):
     """An Instance is an example of data. It is the collection of Fields.
 
@@ -33,4 +31,5 @@ class Instance(object):
         return self.add_field(name, field)
 
     def __repr__(self):
-        return self.fields.__repr__()
+        return "{" + ",\n".join(
+            "\'" + field_name + "\': " + str(self.fields[field_name]) for field_name in self.fields) + "}"
diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 64ad8e23..5f05eab1 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -202,6 +202,7 @@ class LossInForward(LossBase):
                                  all_needed=[],
                                  varargs=[])
             raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))
+        return kwargs[self.loss_key]
 
     def __call__(self, pred_dict, target_dict, check=False):
diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py
index 5075fa02..692ff003 100644
--- a/fastNLP/core/optimizer.py
+++ b/fastNLP/core/optimizer.py
@@ -10,34 +10,7 @@ class Optimizer(object):
 
 
 class SGD(Optimizer):
-    def __init__(self, *args, **kwargs):
-        model_params, lr, momentum = None, 0.01, 0.9
-        if len(args) == 0 and len(kwargs) == 0:
-            # SGD()
-            pass
-        elif len(args) == 1 and len(kwargs) == 0:
-            if isinstance(args[0], float) or isinstance(args[0], int):
-                # SGD(0.001)
-                lr = args[0]
-            elif hasattr(args[0], "__next__"):
-                # SGD(model.parameters()) args[0] is a generator
-                model_params = args[0]
-            else:
-                raise RuntimeError("Not supported type {}.".format(type(args[0])))
-        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
-            # SGD(lr=0.01), SGD(lr=0.01, momentum=0.9), SGD(model.parameters(), lr=0.1, momentum=0.9)
-            if len(args) == 1:
-                if hasattr(args[0], "__next__"):
-                    model_params = args[0]
-                else:
-                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
-            if not all(key in ("lr", "momentum") for key in kwargs):
-                raise RuntimeError("Invalid SGD arguments. Expect {}, got {}.".format(("lr", "momentum"), kwargs))
-            lr = kwargs.get("lr", 0.01)
-            momentum = kwargs.get("momentum", 0.9)
-        else:
-            raise RuntimeError("SGD only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))
-
+    def __init__(self, model_params=None, lr=0.01, momentum=0):
         super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)
 
     def construct_from_pytorch(self, model_params):
@@ -49,30 +22,7 @@ class SGD(Optimizer):
 
 
 class Adam(Optimizer):
-    def __init__(self, *args, **kwargs):
-        model_params, lr, weight_decay = None, 0.01, 0.9
-        if len(args) == 0 and len(kwargs) == 0:
-            pass
-        elif len(args) == 1 and len(kwargs) == 0:
-            if isinstance(args[0], float) or isinstance(args[0], int):
-                lr = args[0]
-            elif hasattr(args[0], "__next__"):
-                model_params = args[0]
-            else:
-                raise RuntimeError("Not supported type {}.".format(type(args[0])))
-        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
-            if len(args) == 1:
-                if hasattr(args[0], "__next__"):
-                    model_params = args[0]
-                else:
-                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
-            if not all(key in ("lr", "weight_decay") for key in kwargs):
-                raise RuntimeError("Invalid Adam arguments. Expect {}, got {}.".format(("lr", "weight_decay"), kwargs))
Expect {}, got {}.".format(("lr", "weight_decay"), kwargs)) - lr = kwargs.get("lr", 0.01) - weight_decay = kwargs.get("weight_decay", 0.9) - else: - raise RuntimeError("Adam only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args)) - + def __init__(self, model_params=None, lr=0.01, weight_decay=0): super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) def construct_from_pytorch(self, model_params): diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index b24af193..5223bbab 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -32,8 +32,7 @@ class Trainer(object): validate_every=-1, dev_data=None, use_cuda=False, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, - metric_key=None, - **kwargs): + metric_key=None): """ :param DataSet train_data: the training data diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index 786e7248..fa3e1ea3 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -44,6 +44,9 @@ class TestDataSet(unittest.TestCase): self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) + with self.assertRaises(RuntimeError): + dd.add_field("??", [[1, 2]] * 40) + def test_delete_field(self): dd = DataSet() dd.add_field("x", [[1, 2, 3]] * 10) @@ -65,8 +68,66 @@ class TestDataSet(unittest.TestCase): self.assertTrue(isinstance(sub_ds, DataSet)) self.assertEqual(len(sub_ds), 10) + def test_get_item_error(self): + with self.assertRaises(RuntimeError): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + _ = ds[40:] + + with self.assertRaises(KeyError): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + _ = ds["kom"] + + def test_len_(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) + self.assertEqual(len(ds), 40) + + ds = DataSet() + self.assertEqual(len(ds), 0) + def test_apply(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx") self.assertTrue("rx" in ds.field_arrays) self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) + + def test_contains(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) + self.assertTrue("x" in ds) + self.assertTrue("y" in ds) + self.assertFalse("z" in ds) + + def test_rename_field(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + ds.rename_field("x", "xx") + self.assertTrue("xx" in ds) + self.assertFalse("x" in ds) + + with self.assertRaises(KeyError): + ds.rename_field("yyy", "oo") + + def test_input_target(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + ds.set_input("x") + ds.set_target("y") + self.assertTrue(ds.field_arrays["x"].is_input) + self.assertTrue(ds.field_arrays["y"].is_target) + + with self.assertRaises(KeyError): + ds.set_input("xxx") + with self.assertRaises(KeyError): + ds.set_input("yyy") + + def test_get_input_name(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + self.assertEqual(ds.get_input_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_input]) + + def test_get_target_name(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target]) + + +class TestDataSetIter(unittest.TestCase): + def test__repr__(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + for iter in ds: + 
+            self.assertEqual(ins.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}")
diff --git a/test/core/test_instance.py b/test/core/test_instance.py
index abe6b7f7..1342ba2c 100644
--- a/test/core/test_instance.py
+++ b/test/core/test_instance.py
@@ -27,3 +27,9 @@ class TestCase(unittest.TestCase):
         self.assertEqual(ins["x"], [1, 2, 3])
         self.assertEqual(ins["y"], [4, 5, 6])
         self.assertEqual(ins["z"], [1, 1, 1])
+
+    def test_repr(self):
+        fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
+        ins = Instance(**fields)
+        # A simple print is enough to exercise __repr__.
+        print(ins)
diff --git a/test/core/test_optimizer.py b/test/core/test_optimizer.py
index ab18b9be..7b29b826 100644
--- a/test/core/test_optimizer.py
+++ b/test/core/test_optimizer.py
@@ -11,9 +11,6 @@ class TestOptim(unittest.TestCase):
         self.assertTrue("lr" in optim.__dict__["settings"])
         self.assertTrue("momentum" in optim.__dict__["settings"])
 
-        optim = SGD(0.001)
-        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
-
         optim = SGD(lr=0.001)
         self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
 
@@ -25,17 +22,12 @@ class TestOptim(unittest.TestCase):
             _ = SGD("???")
         with self.assertRaises(RuntimeError):
             _ = SGD(0.001, lr=0.002)
-        with self.assertRaises(RuntimeError):
-            _ = SGD(lr=0.009, shit=9000)
 
     def test_Adam(self):
         optim = Adam(torch.nn.Linear(10, 3).parameters())
         self.assertTrue("lr" in optim.__dict__["settings"])
         self.assertTrue("weight_decay" in optim.__dict__["settings"])
 
-        optim = Adam(0.001)
-        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
-
         optim = Adam(lr=0.001)
         self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
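-- 
A minimal usage sketch of the interfaces changed above (assumes fastNLP at
this commit; the torch.nn.Linear model and the variable names are
illustrative, not part of the patch):

    import torch

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.optimizer import SGD, Adam

    model = torch.nn.Linear(10, 3)

    # The optimizers now take explicit keyword parameters instead of the
    # hand-parsed *args/**kwargs, so a misspelled option such as shit=9000
    # now fails immediately with a plain TypeError.
    sgd = SGD(model.parameters(), lr=0.01, momentum=0.9)
    adam = Adam(model.parameters(), lr=0.001, weight_decay=0)

    # The new __repr__ prints one field per line for each instance, and a
    # DataSet with 20 or more instances is elided to its first and last five.
    ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
    print(ds[0])  # {'x': [1, 2, 3, 4],
                  # 'y': [5, 6]}
    print(ds)     # DataSet(...) with "..." replacing the middle 30 instances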