From 9c1b4914d8f4fda018f449cf5374941b1fa03c9d Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sun, 30 Jun 2019 09:52:01 +0800 Subject: [PATCH] =?UTF-8?q?1.=E4=BF=AE=E5=A4=8Dtrainer=E4=B8=AD=E6=BD=9C?= =?UTF-8?q?=E5=9C=A8=E5=A4=9A=E6=AD=A5=E6=9B=B4=E6=96=B0bug;=202.=20LSTM?= =?UTF-8?q?=E7=9A=84=E6=95=B0=E6=8D=AE=E5=B9=B6=E8=A1=8C=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=EF=BC=9B3.=20embed=5Floader=E4=B8=ADbug=E4=BF=AE=E5=A4=8D,=20?= =?UTF-8?q?=E4=B8=94=E5=85=81=E8=AE=B8=E6=89=8B=E5=8A=A8=E5=88=9D=E5=A7=8B?= =?UTF-8?q?=E5=8C=96=EF=BC=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callback.py | 2 +- fastNLP/core/dataset.py | 6 ++++-- fastNLP/core/optimizer.py | 17 +++++++++++++++++ fastNLP/core/trainer.py | 6 +++--- fastNLP/io/embed_loader.py | 33 +++++++++++++++++++-------------- fastNLP/modules/encoder/lstm.py | 11 +---------- fastNLP/modules/utils.py | 2 ++ setup.py | 2 +- 8 files changed, 48 insertions(+), 31 deletions(-) diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 483f6dc1..5dfd889b 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -548,7 +548,7 @@ class LRScheduler(Callback): else: raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.") - def on_epoch_begin(self): + def on_epoch_end(self): self.scheduler.step(self.epoch) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 4cd1ad9c..b7df9dec 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -801,17 +801,19 @@ class DataSet(object): else: return DataSet() - def split(self, ratio): + def split(self, ratio, shuffle=True): """ 将DataSet按照ratio的比例拆分,返回两个DataSet :param float ratio: 0 [N,L,C] - output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first) + output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) if self.batch_first: output = output[unsort_idx] else: output = output[:, unsort_idx] - # 解决LSTM无法在DataParallel下使用的问题问题https://github.com/pytorch/pytorch/issues/1591 - if self.batch_first: - if output.size(1) < max_len: - dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1)) - output = torch.cat([output, dummy_tensor], 1) - else: - if output.size(0) < max_len: - dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1)) - output = torch.cat([output, dummy_tensor], 0) else: output, hx = self.lstm(x, hx) return output, hx diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index c87f3a68..3c6a3d27 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -82,6 +82,8 @@ def get_embeddings(init_embed): if isinstance(init_embed, tuple): res = nn.Embedding( num_embeddings=init_embed[0], embedding_dim=init_embed[1]) + nn.init.uniform_(res.weight.data, a=-np.sqrt(3/res.weight.data.size(1)), + b=np.sqrt(3/res.weight.data.size(1))) elif isinstance(init_embed, nn.Module): res = init_embed elif isinstance(init_embed, torch.Tensor): diff --git a/setup.py b/setup.py index 49646761..0dbef455 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ with open('requirements.txt', encoding='utf-8') as f: setup( name='FastNLP', - version='0.4.0', + version='dev0.5.0', description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', long_description=readme, long_description_content_type='text/markdown',