diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 483f6dc1..5dfd889b 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -548,7 +548,7 @@ class LRScheduler(Callback):
         else:
             raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")
 
-    def on_epoch_begin(self):
+    def on_epoch_end(self):
         self.scheduler.step(self.epoch)
 
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 4cd1ad9c..b7df9dec 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -801,17 +801,19 @@ class DataSet(object):
         else:
             return DataSet()
 
-    def split(self, ratio):
+    def split(self, ratio, shuffle=True):
         """
         Split the DataSet by the given ratio and return two DataSets.
 
         :param float ratio: 0
diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py
--- a/fastNLP/modules/encoder/lstm.py
+++ b/fastNLP/modules/encoder/lstm.py
             output, hx = self.lstm(x, hx)  # -> [N,L,C]
-            output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first)
+            output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len)
             _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
             if self.batch_first:
                 output = output[unsort_idx]
             else:
                 output = output[:, unsort_idx]
-            # Workaround for LSTM not working under DataParallel: https://github.com/pytorch/pytorch/issues/1591
-            if self.batch_first:
-                if output.size(1) < max_len:
-                    dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
-                    output = torch.cat([output, dummy_tensor], 1)
-            else:
-                if output.size(0) < max_len:
-                    dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1))
-                    output = torch.cat([output, dummy_tensor], 0)
         else:
             output, hx = self.lstm(x, hx)
         return output, hx
 
diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py
index c87f3a68..3c6a3d27 100644
--- a/fastNLP/modules/utils.py
+++ b/fastNLP/modules/utils.py
@@ -82,6 +82,8 @@ def get_embeddings(init_embed):
     if isinstance(init_embed, tuple):
         res = nn.Embedding(
             num_embeddings=init_embed[0], embedding_dim=init_embed[1])
+        nn.init.uniform_(res.weight.data, a=-np.sqrt(3/res.weight.data.size(1)),
+                         b=np.sqrt(3/res.weight.data.size(1)))
     elif isinstance(init_embed, nn.Module):
         res = init_embed
     elif isinstance(init_embed, torch.Tensor):
 
diff --git a/setup.py b/setup.py
index 49646761..0dbef455 100644
--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,7 @@ with open('requirements.txt', encoding='utf-8') as f:
 setup(
     name='FastNLP',
-    version='0.4.0',
+    version='dev0.5.0',
     description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team',
     long_description=readme,
     long_description_content_type='text/markdown',
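
Note on the callback.py hunk: calling scheduler.step(self.epoch) in on_epoch_begin decays the learning rate before that epoch has trained, so the first epoch never runs at the initial rate; stepping in on_epoch_end applies each decay only after an epoch completes. A minimal sketch of that ordering in plain PyTorch (not fastNLP's Trainer loop):

# Minimal sketch of the epoch-end ordering; StepLR here halves the
# learning rate after every completed epoch.
import torch
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

for epoch in range(3):
    # ... one epoch of training at the current rate ...
    print(epoch, optimizer.param_groups[0]["lr"])  # 0.1, 0.05, 0.025
    scheduler.step()  # step after the epoch, as on_epoch_end now does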
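
The dataset.py hunk is truncated after the signature change, so the new method body is not shown above. Purely as an illustration of what a shuffle-aware ratio split does, here is a hypothetical sketch (not fastNLP's implementation; the list-based `split` below is illustration only):

# Hypothetical sketch of a ratio split with optional shuffling.
import random

def split(items, ratio, shuffle=True):
    assert 0 < ratio < 1, "ratio must lie in (0, 1)"
    indices = list(range(len(items)))
    if shuffle:
        random.shuffle(indices)  # shuffle=False keeps the original order
    cut = int(len(items) * ratio)
    first = [items[i] for i in indices[:cut]]
    second = [items[i] for i in indices[cut:]]
    return first, second

dev, train = split(list(range(100)), ratio=0.1)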
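
The lstm.py hunk replaces the manual zero-padding workaround with pad_packed_sequence's total_length argument, which addresses the same DataParallel problem (pytorch/pytorch#1591) directly: each replica sees only its shard of the batch, so without a fixed total length the unpacked outputs can disagree in the time dimension and fail to gather. A small self-contained demonstration, assuming batch_first=True:

# The longest sequence in this shard is 7, but the padded batch is 10 wide,
# so unpacking must be pinned to max_len to keep replicas shape-compatible.
import torch
from torch.nn.utils import rnn

max_len = 10
x = torch.randn(4, max_len, 8)         # [N, L, C], e.g. one replica's shard
seq_len = torch.tensor([7, 5, 3, 2])   # real lengths, sorted descending

packed = rnn.pack_padded_sequence(x, seq_len, batch_first=True)
output, _ = torch.nn.LSTM(8, 16, batch_first=True)(packed)

trimmed, _ = rnn.pad_packed_sequence(output, batch_first=True)
full, _ = rnn.pad_packed_sequence(output, batch_first=True, total_length=max_len)
print(trimmed.shape)  # torch.Size([4, 7, 16])  -- trimmed to longest sequence
print(full.shape)     # torch.Size([4, 10, 16]) -- fixed, gatherable width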
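
On the utils.py hunk: a uniform distribution on [-a, a] has variance a^2/3, so the bound a = sqrt(3/embedding_dim) gives each component variance 1/dim, and each embedding vector an expected squared norm of about 1. A quick numerical check of that arithmetic:

# Verifies the variance behind the new init: with a = sqrt(3/dim),
# per-component variance is 1/dim and E[||vector||^2] is about 1.
import numpy as np
import torch

dim = 256
weight = torch.empty(50000, dim)
a = np.sqrt(3 / dim)
torch.nn.init.uniform_(weight, a=-a, b=a)

print(weight.var().item())                     # ~ 1/256 ≈ 0.0039
print(weight.pow(2).sum(dim=1).mean().item())  # ~ 1.0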