Browse Source

解决在DataParallel模型场景下无法进行参数匹配的问题 (Fix: parameter matching failed when the model is wrapped in nn.DataParallel)

tags/v0.4.10
yh_cc 6 years ago
parent
commit
e9137349c3
3 changed files with 16 additions and 7 deletions
  1. +4
    -1
      fastNLP/core/tester.py
  2. +11
    -5
      fastNLP/core/trainer.py
  3. +1
    -1
      fastNLP/modules/encoder/_elmo.py

+ 4
- 1
fastNLP/core/tester.py View File

@@ -120,7 +120,10 @@ class Tester(object):
                 raise TypeError(f"`{_model_name}.predict` must be callable to be used "
                                 f"for evaluation, not `{type(self._predict_func)}`.")
             else:
-                self._predict_func = self._model.forward
+                if isinstance(model, nn.DataParallel):
+                    self._predict_func = self._model.module.forward
+                else:
+                    self._predict_func = self._model.forward
     def test(self):
         """开始进行验证,并返回验证结果。


+ 11
- 5
fastNLP/core/trainer.py View File

@@ -578,7 +578,10 @@ class Trainer(object):
         self.step = 0
         self.epoch = 0
         start = time.time()
+        if isinstance(self.model, nn.DataParallel):
+            self._forward_func = self.model.module.forward
+        else:
+            self._forward_func = self.model.forward
         with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             self.pbar = pbar
             avg_loss = 0
@@ -682,11 +685,11 @@
             self.optimizer.step()
     def _data_forward(self, network, x):
-        x = _build_args(network.forward, **x)
+        x = _build_args(self._forward_func, **x)
         y = network(**x)
         if not isinstance(y, dict):
             raise TypeError(
-                f"The return value of {_get_func_signature(network.forward)} should be dict, got {type(y)}.")
+                f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
         return y
     def _grad_backward(self, loss):
@@ -845,8 +848,11 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
     print(info_str)
     _check_forward_error(forward_func=model.forward, dataset=dataset,
                          batch_x=batch_x, check_level=check_level)
-    refined_batch_x = _build_args(model.forward, **batch_x)
+    if isinstance(model, nn.DataParallel):
+        forward_func = model.module.forward
+    else:
+        forward_func = model.forward
+    refined_batch_x = _build_args(forward_func, **batch_x)
     pred_dict = model(**refined_batch_x)
     func_signature = _get_func_signature(model.forward)
     if not isinstance(pred_dict, dict):


+ 1
- 1
fastNLP/modules/encoder/_elmo.py View File

@@ -709,7 +709,7 @@ class _ElmoModel(nn.Module):
                                              config, word_emb_layer, char_emb_layer)
         self.token_embedder.load_state_dict(token_embedder_states, strict=False)
         if config['token_embedder']['word_dim'] > 0 and vocab._no_create_word_length > 0:  # 需要映射,使得来自于dev, test的idx指向unk
-            words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
+            words_to_words = nn.Parameter(torch.arange(len(vocab)+2).long(), requires_grad=False)
             for word, idx in vocab:
                 if vocab._is_word_no_create_entry(word):
                     words_to_words[idx] = vocab.unknown_idx


Loading…
Cancel
Save