
Fix bug

tags/v1.0.0alpha
yhcc committed 2 years ago
commit cef9dca103
2 changed files with 5 additions and 5 deletions
  1. fastNLP/core/utils/tqdm_progress.py (+2, -3)
  2. fastNLP/embeddings/torch/char_embedding.py (+3, -2)

fastNLP/core/utils/tqdm_progress.py (+2, -3)

@@ -47,7 +47,7 @@ class TqdmProgress(metaclass=Singleton):
     def __init__(self):
         self.bars = {}
 
-    def add_task(self, iterable=None, description=None, total=None, leave=False,
+    def add_task(self, description=None, total=None, leave=False,
                  ncols=None, mininterval=0.1, maxinterval=10.0, miniters=None,
                  ascii=None, visible=True, unit='it', unit_scale=False,
                  dynamic_ncols=False, smoothing=0.3, bar_format=None, initial=0,
@@ -57,7 +57,6 @@ class TqdmProgress(metaclass=Singleton):
         This mainly mimics the creation of a tqdm bar; to keep the interface as consistent as possible
         with FRichProgress, desc is renamed to description and disable is replaced by visible.
 
-        :param iterable:
         :param description:
         :param total:
         :param leave:
@@ -96,7 +95,7 @@ class TqdmProgress(metaclass=Singleton):
         else:
             file = sys.stdout
 
-        bar = tqdm(iterable=iterable, desc=description, total=total, leave=leave, file=file,
+        bar = tqdm(iterable=None, desc=description, total=total, leave=leave, file=file,
                    ncols=ncols, mininterval=mininterval, maxinterval=maxinterval, miniters=miniters,
                    ascii=ascii, disable=not visible, unit=unit, unit_scale=unit_scale,
                    dynamic_ncols=dynamic_ncols, smoothing=smoothing, bar_format=bar_format, initial=initial,
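With this change, add_task no longer accepts an iterable: the underlying tqdm bar is always created with iterable=None, so callers advance the bar themselves through the progress API. A minimal sketch of that manual-update pattern with plain tqdm (illustrative only, not part of this commit; the description, step count, and loop body are made up):

    from tqdm import tqdm

    # iterable=None means tqdm iterates nothing on its own; the caller
    # advances the bar explicitly, matching an add_task/update style API.
    bar = tqdm(iterable=None, desc='training', total=100, leave=False)
    for step in range(100):
        pass             # one unit of work would go here
        bar.update(1)    # advance the bar by one step
    bar.close()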


fastNLP/embeddings/torch/char_embedding.py (+3, -2)

@@ -17,6 +17,8 @@ if _NEED_IMPORT_TORCH:
     import torch.nn as nn
     import torch.nn.functional as F
     from torch.nn import LSTM
+    import torch.nn.utils.rnn as rnn
+
 
 from .embedding import TokenEmbedding
 from .static_embedding import StaticEmbedding
@@ -270,8 +272,7 @@ class LSTMCharEmbedding(TokenEmbedding):
         chars = self.char_embedding(chars)  # batch_size x max_len x max_word_len x embed_size
         chars = self.dropout(chars)
         reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
-        char_seq_len = chars_masks.eq(False).sum(dim=-1).reshape(batch_size * max_len)
-        lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1)
+        lstm_chars = self.lstm(reshaped_chars, None)[0].reshape(batch_size, max_len, max_word_len, -1)
         # B x M x M x H
         lstm_chars = self.activation(lstm_chars)
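For context on the second change: if self.lstm is a plain torch.nn.LSTM (which the from torch.nn import LSTM at the top of the file suggests), its second positional argument is the initial hidden state, not the per-sequence lengths, so self.lstm(reshaped_chars, char_seq_len) misuses the API; passing None falls back to a zero initial state. A rough sketch of the difference, with made-up shapes and a lengths-aware variant via the newly imported torch.nn.utils.rnn utilities (an assumption about intent, not code from this commit):

    import torch
    import torch.nn as nn
    import torch.nn.utils.rnn as rnn

    # Hypothetical shapes: 6 "words" of up to 5 characters, 32-dim char embeddings.
    lstm = nn.LSTM(input_size=32, hidden_size=16, batch_first=True)
    reshaped_chars = torch.randn(6, 5, 32)           # (batch*max_len, max_word_len, embed)
    char_seq_len = torch.tensor([5, 3, 4, 1, 2, 5])  # real character count per word

    # What the fix does: the second argument is hx (initial hidden state), left as None.
    out, _ = lstm(reshaped_chars, None)

    # A lengths-aware alternative packs the padded batch instead of passing lengths:
    packed = rnn.pack_padded_sequence(reshaped_chars, char_seq_len,
                                      batch_first=True, enforce_sorted=False)
    packed_out, _ = lstm(packed)
    out_masked, _ = rnn.pad_packed_sequence(packed_out, batch_first=True,
                                            total_length=reshaped_chars.size(1))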

