From f62060339edd1da3c3e1092057e014757714d28a Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Mon, 3 Dec 2018 12:37:33 +0800 Subject: [PATCH 1/6] =?UTF-8?q?All=20tests=20pass.=20Ready=20to=20merge.?= =?UTF-8?q?=20*=20=E6=9B=B4=E6=96=B0Loss=E7=9A=84=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E5=BD=A2=E5=8F=82=E8=B7=9Fmetric=E4=BF=9D=E6=8C=81=E4=B8=80?= =?UTF-8?q?=E8=87=B4=20*=20=E6=B7=BB=E5=8A=A0=E5=AF=B9=E5=87=A0=E7=A7=8Dlo?= =?UTF-8?q?ss=E7=9A=84=E6=B5=8B=E8=AF=95=20*=20embed=5Floader=E9=87=87?= =?UTF-8?q?=E7=94=A8=E7=BB=B4=E5=BA=A6=E7=8B=AC=E7=AB=8B=E7=9A=84=E6=96=B9?= =?UTF-8?q?=E6=B3=95=E9=87=87=E6=A0=B7=20*=20=E5=AF=B9=E5=BA=94=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E4=BB=A3=E7=A0=81=E7=9A=84=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/losses.py | 31 +++++++++++++++------------- fastNLP/io/embed_loader.py | 6 +++--- test/core/test_loss.py | 40 +++++++++++++++--------------------- test/core/test_trainer.py | 2 +- test/io/test_embed_loader.py | 6 +++--- test/test_tutorial.py | 4 ++-- 6 files changed, 42 insertions(+), 47 deletions(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index dce568bd..64ad8e23 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -70,11 +70,11 @@ class LossBase(object): raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " f"positional argument.).") - def __call__(self, output_dict, target_dict, force_check=False): + def __call__(self, pred_dict, target_dict, check=False): """ - :param output_dict: A dict from forward function of the network. + :param pred_dict: A dict from forward function of the network. :param target_dict: A dict from DataSet.batch_y. - :param force_check: Boolean. Force to check the mapping functions when it is running. + :param check: Boolean. Force to check the mapping functions when it is running. :return: """ args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss) @@ -88,7 +88,8 @@ class LossBase(object): raise RuntimeError( f"There is not any param in function{get_func_signature(self.get_loss)}" ) - self._checked = self._checked and not force_check + + self._checked = self._checked and not check if not self._checked: for keys in args: if keys not in param_map: @@ -105,12 +106,12 @@ class LossBase(object): duplicated = [] missing = [] if not self._checked: - for keys, val in output_dict.items(): + for keys, val in pred_dict.items(): if keys in target_dict.keys(): duplicated.append(keys) param_val_dict = {} - for keys, val in output_dict.items(): + for keys, val in pred_dict.items(): param_val_dict.update({keys: val}) for keys, val in target_dict.items(): param_val_dict.update({keys: val}) @@ -158,29 +159,31 @@ class LossFunc(LossBase): class CrossEntropyLoss(LossBase): - def __init__(self, input=None, target=None): + def __init__(self, pred=None, target=None): super(CrossEntropyLoss, self).__init__() self.get_loss = F.cross_entropy - self._init_param_map(input=input, target=target) + self._init_param_map(input=pred, target=target) class L1Loss(LossBase): - def __init__(self): + def __init__(self, pred=None, target=None): super(L1Loss, self).__init__() self.get_loss = F.l1_loss + self._init_param_map(input=pred, target=target) class BCELoss(LossBase): - def __init__(self, input=None, target=None): + def __init__(self, pred=None, target=None): super(BCELoss, self).__init__() self.get_loss = F.binary_cross_entropy - self._init_param_map(input=input, target=target) + self._init_param_map(input=pred, target=target) class NLLLoss(LossBase): - def __init__(self): + def __init__(self, pred=None, target=None): super(NLLLoss, self).__init__() self.get_loss = F.nll_loss + self._init_param_map(input=pred, target=target) class LossInForward(LossBase): @@ -200,9 +203,9 @@ class LossInForward(LossBase): varargs=[]) raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) - def __call__(self, output_dict, predict_dict, force_check=False): + def __call__(self, pred_dict, target_dict, check=False): - loss = self.get_loss(**output_dict) + loss = self.get_loss(**pred_dict) if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): if not isinstance(loss, torch.Tensor): diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py index 6e557c2b..779b7fd0 100644 --- a/fastNLP/io/embed_loader.py +++ b/fastNLP/io/embed_loader.py @@ -105,9 +105,9 @@ class EmbedLoader(BaseLoader): if np.sum(hit_flags) < len(vocab): # some words from vocab are missing in pre-trained embedding - # we normally sample them + # we normally sample each dimension vocab_embed = embedding_matrix[np.where(hit_flags)] - mean, cov = vocab_embed.mean(axis=0), np.cov(vocab_embed.T) - sampled_vectors = np.random.multivariate_normal(mean, cov, size=(len(vocab) - np.sum(hit_flags),)) + sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0), + size=(len(vocab) - np.sum(hit_flags), emb_dim)) embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors return embedding_matrix diff --git a/test/core/test_loss.py b/test/core/test_loss.py index 1124860b..9b77d0a1 100644 --- a/test/core/test_loss.py +++ b/test/core/test_loss.py @@ -271,40 +271,32 @@ class TestLoss(unittest.TestCase): loss3 = get_loss_3({'predict': predict}, {'truth': truth}) assert loss1 == loss2 and loss1 == loss3 - """ - get_loss_4 = LossFunc(func4) - loss4 = get_loss_4({'a': 1, 'b': 3}, {}) - print(loss4) - assert loss4 == (1 + 3) * 2 - - get_loss_5 = LossFunc(func4) - loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4}) - print(loss5) - assert loss5 == (1 + 3) * 4 - - get_loss_6 = LossFunc(func6) - loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4}) - print(loss6) - assert loss6 == (1 + 3) * 4 - - get_loss_7 = LossFunc(func6, c='cc') - loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4}) - print(loss7) - assert loss7 == (1 + 3) * 4 - """ - class TestLoss_v2(unittest.TestCase): def test_CrossEntropyLoss(self): - ce = loss.CrossEntropyLoss(input="my_predict", target="my_truth") + ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth") a = torch.randn(3, 5, requires_grad=False) b = torch.empty(3, dtype=torch.long).random_(5) ans = ce({"my_predict": a}, {"my_truth": b}) self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) def test_BCELoss(self): - bce = loss.BCELoss(input="my_predict", target="my_truth") + bce = loss.BCELoss(pred="my_predict", target="my_truth") a = torch.sigmoid(torch.randn((3, 5), requires_grad=False)) b = torch.randn((3, 5), requires_grad=False) ans = bce({"my_predict": a}, {"my_truth": b}) self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b)) + + def test_L1Loss(self): + l1 = loss.L1Loss(pred="my_predict", target="my_truth") + a = torch.randn(3, 5, requires_grad=False) + b = torch.randn(3, 5) + ans = l1({"my_predict": a}, {"my_truth": b}) + self.assertEqual(ans, torch.nn.functional.l1_loss(a, b)) + + def test_NLLLoss(self): + l1 = loss.NLLLoss(pred="my_predict", target="my_truth") + a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) + b = torch.tensor([1, 0, 4]) + ans = l1({"my_predict": a}, {"my_truth": b}) + self.assertEqual(ans, torch.nn.functional.nll_loss(a, b)) diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py index ee4a5770..bc8df2d2 100644 --- a/test/core/test_trainer.py +++ b/test/core/test_trainer.py @@ -32,7 +32,7 @@ class TrainerTestGround(unittest.TestCase): model = NaiveClassifier(2, 1) trainer = Trainer(train_set, model, - losser=BCELoss(input="predict", target="y"), + losser=BCELoss(pred="predict", target="y"), metrics=AccuracyMetric(pred="predict", target="y"), n_epochs=10, batch_size=32, diff --git a/test/io/test_embed_loader.py b/test/io/test_embed_loader.py index fc1e7124..60e3710e 100644 --- a/test/io/test_embed_loader.py +++ b/test/io/test_embed_loader.py @@ -1,12 +1,12 @@ import unittest from fastNLP.core.vocabulary import Vocabulary +from fastNLP.io.embed_loader import EmbedLoader class TestEmbedLoader(unittest.TestCase): def test_case(self): vocab = Vocabulary() vocab.update(["the", "in", "I", "to", "of", "hahaha"]) - # TODO: np.cov在linux上segment fault,原因未知 - # embedding = EmbedLoader().fast_load_embedding(50, "../data_for_tests/glove.6B.50d_test.txt", vocab) - # self.assertEqual(tuple(embedding.shape), (len(vocab), 50)) + embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab) + self.assertEqual(tuple(embedding.shape), (len(vocab), 50)) diff --git a/test/test_tutorial.py b/test/test_tutorial.py index fe6a9d86..e7ee5cf6 100644 --- a/test/test_tutorial.py +++ b/test/test_tutorial.py @@ -72,7 +72,7 @@ class TestTutorial(unittest.TestCase): # 实例化Trainer,传入模型和数据,进行训练 copy_model = deepcopy(model) overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data, - losser=CrossEntropyLoss(input="output", target="label_seq"), + losser=CrossEntropyLoss(pred="output", target="label_seq"), metrics=AccuracyMetric(pred="predict", target="label_seq"), save_path="./save", batch_size=4, @@ -80,7 +80,7 @@ class TestTutorial(unittest.TestCase): overfit_trainer.train() trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, - losser=CrossEntropyLoss(input="output", target="label_seq"), + losser=CrossEntropyLoss(pred="output", target="label_seq"), metrics=AccuracyMetric(pred="predict", target="label_seq"), save_path="./save", batch_size=4, From 6f58ec34b4357e5df3c7cb467b9906a823a8ca26 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Mon, 3 Dec 2018 19:53:34 +0800 Subject: [PATCH 2/6] =?UTF-8?q?Updates:=20*=20DataSet=E4=BF=AE=E6=94=B9=5F?= =?UTF-8?q?=5Frepr=5F=5F=EF=BC=8C=E4=BC=98=E5=8C=96print(datset)=E7=9A=84?= =?UTF-8?q?=E8=BE=93=E5=87=BA=20*=20Instance=E4=BF=AE=E6=94=B9=5F=5Frepr?= =?UTF-8?q?=5F=5F=EF=BC=8C=E4=BC=98=E5=8C=96print=E7=9A=84=E8=BE=93?= =?UTF-8?q?=E5=87=BA=20*=20Optimizer=E4=BC=98=E5=8C=96=E4=BC=A0=E5=8F=82?= =?UTF-8?q?=E6=8F=90=E7=A4=BA=20*=20Trainer=E5=8E=BB=E9=99=A4kwargs?= =?UTF-8?q?=E5=8F=82=E6=95=B0=20*=20losses.py=E5=8A=A0=E4=B8=AA=E5=8F=82?= =?UTF-8?q?=E6=95=B0=20*=20=E5=AF=B9=E5=BA=94test=20code=E7=9A=84=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataset.py | 9 ++++++ fastNLP/core/instance.py | 5 ++- fastNLP/core/losses.py | 1 + fastNLP/core/optimizer.py | 54 ++------------------------------ fastNLP/core/trainer.py | 3 +- test/core/test_dataset.py | 61 +++++++++++++++++++++++++++++++++++++ test/core/test_instance.py | 6 ++++ test/core/test_optimizer.py | 8 ----- 8 files changed, 82 insertions(+), 65 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 749d3e74..40ea0aab 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -110,6 +110,15 @@ class DataSet(object): field = iter(self.field_arrays.values()).__next__() return len(field) + def __inner_repr__(self): + if len(self) < 20: + return ",\n".join([ins.__repr__() for ins in self]) + else: + return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__() + + def __repr__(self): + return "DataSet(" + self.__inner_repr__() + ")" + def append(self, ins): """Add an instance to the DataSet. If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet. diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py index 9dfe8fb8..dc65fa82 100644 --- a/fastNLP/core/instance.py +++ b/fastNLP/core/instance.py @@ -1,5 +1,3 @@ - - class Instance(object): """An Instance is an example of data. It is the collection of Fields. @@ -33,4 +31,5 @@ class Instance(object): return self.add_field(name, field) def __repr__(self): - return self.fields.__repr__() + return "{" + ",\n".join( + "\'" + field_name + "\': " + str(self.fields[field_name]) for field_name in self.fields) + "}" diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 64ad8e23..5f05eab1 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -202,6 +202,7 @@ class LossInForward(LossBase): all_needed=[], varargs=[]) raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) + return kwargs[self.loss_key] def __call__(self, pred_dict, target_dict, check=False): diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py index 5075fa02..692ff003 100644 --- a/fastNLP/core/optimizer.py +++ b/fastNLP/core/optimizer.py @@ -10,34 +10,7 @@ class Optimizer(object): class SGD(Optimizer): - def __init__(self, *args, **kwargs): - model_params, lr, momentum = None, 0.01, 0.9 - if len(args) == 0 and len(kwargs) == 0: - # SGD() - pass - elif len(args) == 1 and len(kwargs) == 0: - if isinstance(args[0], float) or isinstance(args[0], int): - # SGD(0.001) - lr = args[0] - elif hasattr(args[0], "__next__"): - # SGD(model.parameters()) args[0] is a generator - model_params = args[0] - else: - raise RuntimeError("Not supported type {}.".format(type(args[0]))) - elif 2 >= len(kwargs) > 0 and len(args) <= 1: - # SGD(lr=0.01), SGD(lr=0.01, momentum=0.9), SGD(model.parameters(), lr=0.1, momentum=0.9) - if len(args) == 1: - if hasattr(args[0], "__next__"): - model_params = args[0] - else: - raise RuntimeError("Not supported type {}.".format(type(args[0]))) - if not all(key in ("lr", "momentum") for key in kwargs): - raise RuntimeError("Invalid SGD arguments. Expect {}, got {}.".format(("lr", "momentum"), kwargs)) - lr = kwargs.get("lr", 0.01) - momentum = kwargs.get("momentum", 0.9) - else: - raise RuntimeError("SGD only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args)) - + def __init__(self, model_params=None, lr=0.01, momentum=0): super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) def construct_from_pytorch(self, model_params): @@ -49,30 +22,7 @@ class SGD(Optimizer): class Adam(Optimizer): - def __init__(self, *args, **kwargs): - model_params, lr, weight_decay = None, 0.01, 0.9 - if len(args) == 0 and len(kwargs) == 0: - pass - elif len(args) == 1 and len(kwargs) == 0: - if isinstance(args[0], float) or isinstance(args[0], int): - lr = args[0] - elif hasattr(args[0], "__next__"): - model_params = args[0] - else: - raise RuntimeError("Not supported type {}.".format(type(args[0]))) - elif 2 >= len(kwargs) > 0 and len(args) <= 1: - if len(args) == 1: - if hasattr(args[0], "__next__"): - model_params = args[0] - else: - raise RuntimeError("Not supported type {}.".format(type(args[0]))) - if not all(key in ("lr", "weight_decay") for key in kwargs): - raise RuntimeError("Invalid Adam arguments. Expect {}, got {}.".format(("lr", "weight_decay"), kwargs)) - lr = kwargs.get("lr", 0.01) - weight_decay = kwargs.get("weight_decay", 0.9) - else: - raise RuntimeError("Adam only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args)) - + def __init__(self, model_params=None, lr=0.01, weight_decay=0): super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) def construct_from_pytorch(self, model_params): diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index b24af193..5223bbab 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -32,8 +32,7 @@ class Trainer(object): validate_every=-1, dev_data=None, use_cuda=False, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, - metric_key=None, - **kwargs): + metric_key=None): """ :param DataSet train_data: the training data diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index 786e7248..fa3e1ea3 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -44,6 +44,9 @@ class TestDataSet(unittest.TestCase): self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) + with self.assertRaises(RuntimeError): + dd.add_field("??", [[1, 2]] * 40) + def test_delete_field(self): dd = DataSet() dd.add_field("x", [[1, 2, 3]] * 10) @@ -65,8 +68,66 @@ class TestDataSet(unittest.TestCase): self.assertTrue(isinstance(sub_ds, DataSet)) self.assertEqual(len(sub_ds), 10) + def test_get_item_error(self): + with self.assertRaises(RuntimeError): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + _ = ds[40:] + + with self.assertRaises(KeyError): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + _ = ds["kom"] + + def test_len_(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) + self.assertEqual(len(ds), 40) + + ds = DataSet() + self.assertEqual(len(ds), 0) + def test_apply(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx") self.assertTrue("rx" in ds.field_arrays) self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) + + def test_contains(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) + self.assertTrue("x" in ds) + self.assertTrue("y" in ds) + self.assertFalse("z" in ds) + + def test_rename_field(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + ds.rename_field("x", "xx") + self.assertTrue("xx" in ds) + self.assertFalse("x" in ds) + + with self.assertRaises(KeyError): + ds.rename_field("yyy", "oo") + + def test_input_target(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + ds.set_input("x") + ds.set_target("y") + self.assertTrue(ds.field_arrays["x"].is_input) + self.assertTrue(ds.field_arrays["y"].is_target) + + with self.assertRaises(KeyError): + ds.set_input("xxx") + with self.assertRaises(KeyError): + ds.set_input("yyy") + + def test_get_input_name(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + self.assertEqual(ds.get_input_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_input]) + + def test_get_target_name(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target]) + + +class TestDataSetIter(unittest.TestCase): + def test__repr__(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + for iter in ds: + self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4], 'y': [5, 6]}") diff --git a/test/core/test_instance.py b/test/core/test_instance.py index abe6b7f7..1342ba2c 100644 --- a/test/core/test_instance.py +++ b/test/core/test_instance.py @@ -27,3 +27,9 @@ class TestCase(unittest.TestCase): self.assertEqual(ins["x"], [1, 2, 3]) self.assertEqual(ins["y"], [4, 5, 6]) self.assertEqual(ins["z"], [1, 1, 1]) + + def test_repr(self): + fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} + ins = Instance(**fields) + # simple print, that is enough. + print(ins) diff --git a/test/core/test_optimizer.py b/test/core/test_optimizer.py index ab18b9be..7b29b826 100644 --- a/test/core/test_optimizer.py +++ b/test/core/test_optimizer.py @@ -11,9 +11,6 @@ class TestOptim(unittest.TestCase): self.assertTrue("lr" in optim.__dict__["settings"]) self.assertTrue("momentum" in optim.__dict__["settings"]) - optim = SGD(0.001) - self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) - optim = SGD(lr=0.001) self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) @@ -25,17 +22,12 @@ class TestOptim(unittest.TestCase): _ = SGD("???") with self.assertRaises(RuntimeError): _ = SGD(0.001, lr=0.002) - with self.assertRaises(RuntimeError): - _ = SGD(lr=0.009, shit=9000) def test_Adam(self): optim = Adam(torch.nn.Linear(10, 3).parameters()) self.assertTrue("lr" in optim.__dict__["settings"]) self.assertTrue("weight_decay" in optim.__dict__["settings"]) - optim = Adam(0.001) - self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) - optim = Adam(lr=0.001) self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) From 131e1ccd3b289388772ea4f1969558119789c33a Mon Sep 17 00:00:00 2001 From: yunfan Date: Mon, 3 Dec 2018 20:04:14 +0800 Subject: [PATCH 3/6] add _fast_param_map --- fastNLP/core/losses.py | 12 +++++++++++- fastNLP/core/metrics.py | 10 +++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 64ad8e23..c3459964 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -70,6 +70,12 @@ class LossBase(object): raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " f"positional argument.).") + def _fast_param_map(self, pred_dict, target_dict): + if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: + return pred_dict.values[0], target_dict.values[0] + return None + + def __call__(self, pred_dict, target_dict, check=False): """ :param pred_dict: A dict from forward function of the network. @@ -77,6 +83,11 @@ class LossBase(object): :param check: Boolean. Force to check the mapping functions when it is running. :return: """ + fast_param = self._fast_param_map(pred_dict, target_dict) + if fast_param is not None: + loss = self.get_loss(*fast_param) + return loss + args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss) if varargs is not None: raise RuntimeError( @@ -132,7 +143,6 @@ class LossBase(object): param_map_val = _map_args(reversed_param_map, **param_val_dict) param_value = _build_args(self.get_loss, **param_map_val) - loss = self.get_loss(**param_value) if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index b1fc110b..6216b16d 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -71,7 +71,7 @@ class MetricBase(object): def get_metric(self, reset=True): raise NotImplemented - def _fast_call_evaluate(self, pred_dict, target_dict): + def _fast_param_map(self, pred_dict, target_dict): """ Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. @@ -80,7 +80,9 @@ class MetricBase(object): :param target_dict: :return: boolean, whether to go on codes in self.__call__(). When False, don't go on. """ - return False + if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: + return pred_dict.values[0] and target_dict.values[0] + return None def __call__(self, pred_dict, target_dict, check=False): """ @@ -103,7 +105,9 @@ class MetricBase(object): raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") if not check: - if self._fast_call_evaluate(pred_dict=pred_dict, target_dict=target_dict): + fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict) + if fast_param is not None: + self.evaluate(*fast_param) return if not self._checked: From 513876d5db1f7df2c08ea6984802901383ac3404 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Mon, 3 Dec 2018 20:50:51 +0800 Subject: [PATCH 4/6] =?UTF-8?q?Updates:=20*=20fix=20losses=E7=9A=84=5Ffast?= =?UTF-8?q?=5Fparam=5Fmap=E7=9A=84bug=20*=20Trainer=E6=B7=BB=E5=8A=A0sampe?= =?UTF-8?q?lr=E5=88=9D=E5=A7=8B=E5=8C=96=E5=8F=82=E6=95=B0=EF=BC=8C?= =?UTF-8?q?=E5=B9=B6=E8=B0=83=E6=95=B4=E5=8F=82=E6=95=B0=E9=A1=BA=E5=BA=8F?= =?UTF-8?q?=20*=20refine=20codes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/losses.py | 3 +-- fastNLP/core/metrics.py | 57 +++++++++++++++++++-------------------- fastNLP/core/trainer.py | 17 ++++-------- fastNLP/core/utils.py | 38 +++++++++++++++----------- test/core/test_trainer.py | 14 +++------- test/test_tutorial.py | 16 +++++------ 6 files changed, 65 insertions(+), 80 deletions(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 82f47025..f2fb16d0 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -72,10 +72,9 @@ class LossBase(object): def _fast_param_map(self, pred_dict, target_dict): if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: - return pred_dict.values[0], target_dict.values[0] + return tuple(pred_dict.values())[0], tuple(target_dict.values())[0] return None - def __call__(self, pred_dict, target_dict, check=False): """ :param pred_dict: A dict from forward function of the network. diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 6216b16d..d83c4022 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -1,4 +1,3 @@ - import inspect import warnings from collections import defaultdict @@ -7,11 +6,12 @@ import numpy as np import torch from fastNLP.core.utils import CheckError +from fastNLP.core.utils import CheckRes from fastNLP.core.utils import _build_args from fastNLP.core.utils import _check_arg_dict_list from fastNLP.core.utils import get_func_signature from fastNLP.core.utils import seq_lens_to_masks -from fastNLP.core.utils import CheckRes + class MetricBase(object): def __init__(self): @@ -59,9 +59,10 @@ class MetricBase(object): func_args = [arg for arg in func_spect.args if arg != 'self'] for func_param, input_param in self.param_map.items(): if func_param not in func_args: - raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " - f"initialization parameters, or change the signature of" - f" {get_func_signature(self.evaluate)}.") + raise NameError( + f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " + f"initialization parameters, or change the signature of" + f" {get_func_signature(self.evaluate)}.") # evaluate should not have varargs. if func_spect.varargs: @@ -113,7 +114,7 @@ class MetricBase(object): if not self._checked: # 1. check consistence between signature and param_map func_spect = inspect.getfullargspec(self.evaluate) - func_args = set([arg for arg in func_spect.args if arg!='self']) + func_args = set([arg for arg in func_spect.args if arg != 'self']) for func_arg, input_arg in self.param_map.items(): if func_arg not in func_args: raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.") @@ -121,7 +122,7 @@ class MetricBase(object): # 2. only part of the param_map are passed, left are not for arg in func_args: if arg not in self.param_map: - self.param_map[arg] = arg #This param does not need mapping. + self.param_map[arg] = arg # This param does not need mapping. self._evaluate_args = func_args self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} @@ -153,14 +154,14 @@ class MetricBase(object): replaced_missing = list(missing) for idx, func_arg in enumerate(missing): replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ - f"in `{self.__class__.__name__}`)" + f"in `{self.__class__.__name__}`)" check_res = CheckRes(missing=replaced_missing, - unused=check_res.unused, - duplicated=duplicated, - required=check_res.required, - all_needed=check_res.all_needed, - varargs=check_res.varargs) + unused=check_res.unused, + duplicated=duplicated, + required=check_res.required, + all_needed=check_res.all_needed, + varargs=check_res.varargs) if check_res.missing or check_res.duplicated or check_res.varargs: raise CheckError(check_res=check_res, @@ -172,6 +173,7 @@ class MetricBase(object): return + class AccuracyMetric(MetricBase): def __init__(self, pred=None, target=None, masks=None, seq_lens=None): super().__init__() @@ -191,7 +193,7 @@ class AccuracyMetric(MetricBase): :param target_dict: :return: boolean, whether to go on codes in self.__call__(). When False, don't go on. """ - if len(pred_dict)==1 and len(target_dict)==1: + if len(pred_dict) == 1 and len(target_dict) == 1: pred = list(pred_dict.values())[0] target = list(target_dict.values())[0] self.evaluate(pred=pred, target=target) @@ -211,7 +213,7 @@ class AccuracyMetric(MetricBase): None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. :return: dict({'acc': float}) """ - #TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value + # TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value if not isinstance(pred, torch.Tensor): raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(pred)}.") @@ -224,14 +226,14 @@ class AccuracyMetric(MetricBase): f"got {type(masks)}.") elif seq_lens is not None and not isinstance(seq_lens, torch.Tensor): raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," - f"got {type(seq_lens)}.") + f"got {type(seq_lens)}.") if masks is None and seq_lens is not None: masks = seq_lens_to_masks(seq_lens=seq_lens, float=True) - if pred.size()==target.size(): + if pred.size() == target.size(): pass - elif len(pred.size())==len(target.size())+1: + elif len(pred.size()) == len(target.size()) + 1: pred = pred.argmax(dim=-1) else: raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " @@ -245,18 +247,17 @@ class AccuracyMetric(MetricBase): self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item() self.total += torch.sum(masks.float()).item() else: - self.acc_count += torch.sum(torch.eq(pred, target).float()).item() + self.acc_count += torch.sum(torch.eq(pred, target).float()).item() self.total += np.prod(list(pred.size())) def get_metric(self, reset=True): - evaluate_result = {'acc': round(self.acc_count/self.total, 6)} + evaluate_result = {'acc': round(self.acc_count / self.total, 6)} if reset: self.acc_count = 0 self.total = 0 return evaluate_result - def _prepare_metrics(metrics): """ @@ -278,7 +279,8 @@ def _prepare_metrics(metrics): raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.") _metrics.append(metric) else: - raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.") + raise TypeError( + f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.") elif isinstance(metrics, MetricBase): _metrics = [metrics] else: @@ -300,6 +302,7 @@ class Evaluator(object): """ raise NotImplementedError + class ClassifyEvaluator(Evaluator): def __init__(self): super(ClassifyEvaluator, self).__init__() @@ -335,6 +338,7 @@ class SeqLabelEvaluator(Evaluator): accuracy = total_correct / total_count return {"accuracy": float(accuracy)} + class SeqLabelEvaluator2(Evaluator): # 上面的evaluator应该是错误的 def __init__(self, seq_lens_field_name='word_seq_origin_len'): @@ -367,7 +371,7 @@ class SeqLabelEvaluator2(Evaluator): if x_i in self.end_tagidx_set: truth_count += 1 for j in range(start, idx_i + 1): - if y_[j]!=x_[j]: + if y_[j] != x_[j]: flag = False break if flag: @@ -380,8 +384,7 @@ class SeqLabelEvaluator2(Evaluator): R = corr_count / (float(truth_count) + 1e-6) F = 2 * P * R / (P + R + 1e-6) - return {"P": P, 'R':R, 'F': F} - + return {"P": P, 'R': R, 'F': F} class SNLIEvaluator(Evaluator): @@ -563,10 +566,6 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 -def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2): - raise NotImplementedError - - def accuracy_topk(y_true, y_prob, k=1): """Compute accuracy of y_true matching top-k probable labels in y_prob. diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 5223bbab..dd5862d3 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -28,11 +28,9 @@ class Trainer(object): """ - def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, - validate_every=-1, - dev_data=None, use_cuda=False, save_path=None, - optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, - metric_key=None): + def __init__(self, train_data, model, losser=None, metrics=None, optimizer=Adam(lr=0.01, weight_decay=0), + sampler=RandomSampler(), n_epochs=3, batch_size=32, print_every=50, validate_every=-1, dev_data=None, + use_cuda=False, metric_key=None, save_path=None, check_code_level=0): """ :param DataSet train_data: the training data @@ -54,7 +52,6 @@ class Trainer(object): :: metric_key="-PPL" # language model gets better as perplexity gets smaller - :param kwargs: """ super(Trainer, self).__init__() @@ -105,6 +102,7 @@ class Trainer(object): self.print_every = int(print_every) self.validate_every = int(validate_every) self.best_metric_indicator = None + self.sampler = sampler self._model_device = model.parameters().__next__().device @@ -120,14 +118,9 @@ class Trainer(object): batch_size=self.batch_size, use_cuda=self.use_cuda) - for k, v in kwargs.items(): - setattr(self, k, v) - self.step = 0 self.start_time = None # start timestamp - # print(self.__dict__) - def train(self): """Start Training. @@ -158,7 +151,7 @@ class Trainer(object): epoch = 1 while epoch <= self.n_epochs: - data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), + data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False) self._train_epoch(data_iterator, self.model, epoch, start) diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index bfbeb6e5..6c101890 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -10,6 +10,8 @@ import torch CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', 'varargs'], verbose=False) + + def save_pickle(obj, pickle_path, file_name): """Save an object into a pickle file. @@ -53,6 +55,7 @@ def pickle_exist(pickle_path, pickle_name): else: return False + def _build_args(func, **kwargs): spect = inspect.getfullargspec(func) if spect.varkw is not None: @@ -108,7 +111,7 @@ def _check_arg_dict_list(func, args): assert callable(func) and isinstance(arg_dict_list, (list, tuple)) assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict) spect = inspect.getfullargspec(func) - all_args = set([arg for arg in spect.args if arg!='self']) + all_args = set([arg for arg in spect.args if arg != 'self']) defaults = [] if spect.defaults is not None: defaults = [arg for arg in spect.defaults] @@ -130,6 +133,7 @@ def _check_arg_dict_list(func, args): all_needed=list(all_args), varargs=varargs) + def get_func_signature(func): """ @@ -153,7 +157,7 @@ def get_func_signature(func): class_name = func.__self__.__class__.__name__ signature = inspect.signature(func) signature_str = str(signature) - if len(signature_str)>2: + if len(signature_str) > 2: _self = '(self, ' else: _self = '(self' @@ -176,12 +180,13 @@ def _is_function_or_method(func): return False return True + def _check_function_or_method(func): if not _is_function_or_method(func): raise TypeError(f"{type(func)} is not a method or function.") -def _move_dict_value_to_device(*args, device:torch.device): +def _move_dict_value_to_device(*args, device: torch.device): """ move data to model's device, element in *args should be dict. This is a inplace change. @@ -206,7 +211,8 @@ class CheckError(Exception): CheckError. Used in losses.LossBase, metrics.MetricBase. """ - def __init__(self, check_res:CheckRes, func_signature:str): + + def __init__(self, check_res: CheckRes, func_signature: str): errs = [f'The following problems occurred when calling `{func_signature}`'] if check_res.varargs: @@ -228,8 +234,9 @@ IGNORE_CHECK_LEVEL = 0 WARNING_CHECK_LEVEL = 1 STRICT_CHECK_LEVEL = 2 -def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes, - pred_dict:dict, target_dict:dict, dataset, check_level=0): + +def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes, + pred_dict: dict, target_dict: dict, dataset, check_level=0): errs = [] unuseds = [] _unused_field = [] @@ -268,8 +275,8 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res: f"target is {list(target_dict.keys())}).") if _miss_out_dataset: _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now " - f"target is {list(target_dict.keys())}) or output it " - f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).") + f"target is {list(target_dict.keys())}) or output it " + f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).") if _unused_field: _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. " suggestions.append(_tmp) @@ -277,15 +284,15 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res: if check_res.duplicated: errs.append(f"\tduplicated param: {check_res.duplicated}.") suggestions.append(f"Delete {check_res.duplicated} in the output of " - f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") + f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") if check_level == STRICT_CHECK_LEVEL: errs.extend(unuseds) - if len(errs)>0: + if len(errs) > 0: errs.insert(0, f'The following problems occurred when calling {func_signature}') sugg_str = "" - if len(suggestions)>1: + if len(suggestions) > 1: for idx, sugg in enumerate(suggestions): sugg_str += f'({idx+1}). {sugg}' else: @@ -332,10 +339,10 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): if check_level == STRICT_CHECK_LEVEL: errs.extend(_unused) - if len(errs)>0: + if len(errs) > 0: errs.insert(0, f'The following problems occurred when calling {func_signature}') sugg_str = "" - if len(suggestions)>1: + if len(suggestions) > 1: for idx, sugg in enumerate(suggestions): sugg_str += f'({idx+1}). {sugg}' else: @@ -357,11 +364,11 @@ def seq_lens_to_masks(seq_lens, float=True): :return: list, np.ndarray or torch.Tensor, shape will be (B, max_length) """ if isinstance(seq_lens, np.ndarray): - assert len(np.shape(seq_lens))==1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}." + assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}." assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}." raise NotImplemented elif isinstance(seq_lens, torch.LongTensor): - assert len(seq_lens.size())==1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}." + assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}." batch_size = seq_lens.size(0) max_len = seq_lens.max() indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device) @@ -375,4 +382,3 @@ def seq_lens_to_masks(seq_lens, float=True): raise NotImplemented else: raise NotImplemented - diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py index bc8df2d2..0a59b3cd 100644 --- a/test/core/test_trainer.py +++ b/test/core/test_trainer.py @@ -31,15 +31,7 @@ class TrainerTestGround(unittest.TestCase): model = NaiveClassifier(2, 1) - trainer = Trainer(train_set, model, - losser=BCELoss(pred="predict", target="y"), - metrics=AccuracyMetric(pred="predict", target="y"), - n_epochs=10, - batch_size=32, - print_every=10, - validate_every=-1, - dev_data=dev_set, - optimizer=SGD(0.1), - check_code_level=2 - ) + trainer = Trainer(train_set, model, losser=BCELoss(pred="predict", target="y"), + metrics=AccuracyMetric(pred="predict", target="y"), optimizer=SGD(), n_epochs=10, + batch_size=32, print_every=10, validate_every=-1, dev_data=dev_set, check_code_level=2) trainer.train() diff --git a/test/test_tutorial.py b/test/test_tutorial.py index e7ee5cf6..f3648b4f 100644 --- a/test/test_tutorial.py +++ b/test/test_tutorial.py @@ -71,20 +71,16 @@ class TestTutorial(unittest.TestCase): # 实例化Trainer,传入模型和数据,进行训练 copy_model = deepcopy(model) - overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data, + overfit_trainer = Trainer(train_data=test_data, model=copy_model, losser=CrossEntropyLoss(pred="output", target="label_seq"), - metrics=AccuracyMetric(pred="predict", target="label_seq"), - save_path="./save", - batch_size=4, - n_epochs=10) + metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4, + dev_data=test_data, save_path="./save") overfit_trainer.train() - trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, + trainer = Trainer(train_data=train_data, model=model, losser=CrossEntropyLoss(pred="output", target="label_seq"), - metrics=AccuracyMetric(pred="predict", target="label_seq"), - save_path="./save", - batch_size=4, - n_epochs=10) + metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4, + dev_data=test_data, save_path="./save") trainer.train() print('Train finished!') From ad3c5b6ef02947bb718382538d22c3407625acf5 Mon Sep 17 00:00:00 2001 From: yunfan Date: Mon, 3 Dec 2018 21:54:22 +0800 Subject: [PATCH 5/6] add magic iter in dataset --- fastNLP/core/dataset.py | 44 ++++++++++++----------- fastNLP/core/utils.py | 16 +++++++++ fastNLP/modules/encoder/char_embedding.py | 2 +- test/core/test_dataset.py | 2 +- 4 files changed, 41 insertions(+), 23 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 40ea0aab..dea27174 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -26,24 +26,6 @@ class DataSet(object): However, it stores data in a different way: Field-first, Instance-second. """ - - class DataSetIter(object): - def __init__(self, data_set, idx=-1, **fields): - self.data_set = data_set - self.idx = idx - self.fields = fields - - def __next__(self): - self.idx += 1 - if self.idx >= len(self.data_set): - raise StopIteration - # this returns a copy - return self.data_set[self.idx] - - def __repr__(self): - return "\n".join(['{}: {}'.format(name, repr(self.data_set[name][self.idx])) for name - in self.data_set.get_fields().keys()]) - def __init__(self, data=None): """ @@ -72,7 +54,27 @@ class DataSet(object): return item in self.field_arrays def __iter__(self): - return self.DataSetIter(self) + def iter_func(): + for idx in range(len(self)): + yield self[idx] + return iter_func() + + def _inner_iter(self): + class Iter_ptr: + def __init__(self, dataset, idx): + self.dataset = dataset + self.idx = idx + def __getitem__(self, item): + assert self.idx < len(self.dataset), "index:{} out of range".format(self.idx) + assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx]) + return self.dataset.field_arrays[item][self.idx] + def __repr__(self): + return self.dataset[self.idx].__repr__() + + def inner_iter_func(): + for idx in range(len(self)): + yield Iter_ptr(self, idx) + return inner_iter_func() def __getitem__(self, idx): """Fetch Instance(s) at the `idx` position(s) in the dataset. @@ -232,7 +234,7 @@ class DataSet(object): :param str new_field_name: If not None, results of the function will be stored as a new field. :return results: if new_field_name is not passed, returned values of the function over all instances. """ - results = [func(ins) for ins in self] + results = [func(ins) for ins in self._inner_iter()] if new_field_name is not None: if new_field_name in self.field_arrays: # overwrite the field, keep same attributes @@ -248,7 +250,7 @@ class DataSet(object): return results def drop(self, func): - results = [ins for ins in self if not func(ins)] + results = [ins for ins in self._inner_iter() if not func(ins)] for name, old_field in self.field_arrays.items(): self.field_arrays[name].content = [ins[name] for ins in results] diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 6c101890..abe7889c 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -382,3 +382,19 @@ def seq_lens_to_masks(seq_lens, float=True): raise NotImplemented else: raise NotImplemented + + +def seq_mask(seq_len, max_len): + """Create sequence mask. + + :param seq_len: list or torch.Tensor, the lengths of sequences in a batch. + :param max_len: int, the maximum sequence length in a batch. + :return mask: torch.LongTensor, [batch_size, max_len] + + """ + if not isinstance(seq_len, torch.Tensor): + seq_len = torch.LongTensor(seq_len) + seq_len = seq_len.view(-1, 1).long() # [batch_size, 1] + seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len] + return torch.gt(seq_len, seq_range) # [batch_size, max_len] + diff --git a/fastNLP/modules/encoder/char_embedding.py b/fastNLP/modules/encoder/char_embedding.py index 1ca3b5ba..249a73ad 100644 --- a/fastNLP/modules/encoder/char_embedding.py +++ b/fastNLP/modules/encoder/char_embedding.py @@ -43,7 +43,7 @@ class ConvCharEmbedding(nn.Module): # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1] y = torch.squeeze(y, 2) # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1] - y = F.tanh(y) + y = torch.tanh(y) y, __ = torch.max(y, 2) # [batch_size*sent_length, feature_maps[i]] feats.append(y) diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index fa3e1ea3..8ca2ed86 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -130,4 +130,4 @@ class TestDataSetIter(unittest.TestCase): def test__repr__(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) for iter in ds: - self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4], 'y': [5, 6]}") + self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}") From 1421b7dfbabaec073e87717420b41c9c70f1539c Mon Sep 17 00:00:00 2001 From: yunfan Date: Mon, 3 Dec 2018 22:48:02 +0800 Subject: [PATCH 6/6] add this feature totally for yh --- fastNLP/core/dataset.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index dea27174..4925ac36 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -1,3 +1,4 @@ +import _pickle as pickle import numpy as np from fastNLP.core.fieldarray import FieldArray @@ -317,3 +318,12 @@ class DataSet(object): for header, content in zip(headers, contents): _dict[header].append(content) return cls(_dict) + + def save(self, path): + with open(path, 'wb') as f: + pickle.dump(self, f) + + @staticmethod + def load(self, path): + with open(path, 'rb') as f: + return pickle.load(f)