diff --git a/fastNLP/core/utils/tqdm_progress.py b/fastNLP/core/utils/tqdm_progress.py index d40cb81f..d6e0f9fb 100644 --- a/fastNLP/core/utils/tqdm_progress.py +++ b/fastNLP/core/utils/tqdm_progress.py @@ -47,7 +47,7 @@ class TqdmProgress(metaclass=Singleton): def __init__(self): self.bars = {} - def add_task(self, iterable=None, description=None, total=None, leave=False, + def add_task(self, description=None, total=None, leave=False, ncols=None, mininterval=0.1, maxinterval=10.0, miniters=None, ascii=None, visible=True, unit='it', unit_scale=False, dynamic_ncols=False, smoothing=0.3, bar_format=None, initial=0, @@ -57,7 +57,6 @@ class TqdmProgress(metaclass=Singleton): 主要就模仿了 tqdm bar 的创建,为了和 FRichProgress 的接口尽量统一,将 desc 重名为了 description,以及 disable 专为了 visible 。 - :param iterable: :param description: :param total: :param leave: @@ -96,7 +95,7 @@ class TqdmProgress(metaclass=Singleton): else: file = sys.stdout - bar = tqdm(iterable=iterable, desc=description, total=total, leave=leave, file=file, + bar = tqdm(iterable=None, desc=description, total=total, leave=leave, file=file, ncols=ncols, mininterval=mininterval, maxinterval=maxinterval, miniters=miniters, ascii=ascii, disable=not visible, unit=unit, unit_scale=unit_scale, dynamic_ncols=dynamic_ncols, smoothing=smoothing, bar_format=bar_format, initial=initial, diff --git a/fastNLP/embeddings/torch/char_embedding.py b/fastNLP/embeddings/torch/char_embedding.py index 69706281..73269e99 100644 --- a/fastNLP/embeddings/torch/char_embedding.py +++ b/fastNLP/embeddings/torch/char_embedding.py @@ -17,6 +17,8 @@ if _NEED_IMPORT_TORCH: import torch.nn as nn import torch.nn.functional as F from torch.nn import LSTM + import torch.nn.utils.rnn as rnn + from .embedding import TokenEmbedding from .static_embedding import StaticEmbedding @@ -270,8 +272,7 @@ class LSTMCharEmbedding(TokenEmbedding): chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size chars = self.dropout(chars) reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1) - char_seq_len = chars_masks.eq(False).sum(dim=-1).reshape(batch_size * max_len) - lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1) + lstm_chars = self.lstm(reshaped_chars, None)[0].reshape(batch_size, max_len, max_word_len, -1) # B x M x M x H lstm_chars = self.activation(lstm_chars)