Browse Source

1.优化Trainer中对exception的处理;2.修改static_embedding贴合seq2seq

tags/v0.5.5
yh_cc 5 years ago
parent
commit
ed7f7b1cd9
2 changed files with 13 additions and 9 deletions
  1. +9
    -9
      fastNLP/core/trainer.py
  2. +4
    -0
      fastNLP/embeddings/static_embedding.py

+ 9
- 9
fastNLP/core/trainer.py View File

@@ -617,6 +617,14 @@ class Trainer(object):
elif on_exception == 'raise':
raise e

if self.dev_data is not None and self.best_dev_perf is not None and load_best_model:
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
load_succeed = self._load_model(self.model, model_name)
if load_succeed:
self.logger.info("Reloaded the best model.")
else:
self.logger.info("Fail to reload best model.")
finally:
if self.dev_data is not None and self.best_dev_perf is not None:
self.logger.info(
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step))
@@ -624,15 +632,7 @@ class Trainer(object):
results['best_eval'] = self.best_dev_perf
results['best_epoch'] = self.best_dev_epoch
results['best_step'] = self.best_dev_step
if load_best_model:
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
load_succeed = self._load_model(self.model, model_name)
if load_succeed:
self.logger.info("Reloaded the best model.")
else:
self.logger.info("Fail to reload best model.")
finally:
pass

results['seconds'] = round(time.time() - start_time, 2)

return results


+ 4
- 0
fastNLP/embeddings/static_embedding.py View File

@@ -76,6 +76,10 @@ class StaticEmbedding(TokenEmbedding):
"""
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
if embedding_dim > 0:
if model_dir_or_name is not None:
warnings.warn(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with"
f" dimension {embedding_dim}. If you want to use pre-trained embedding, "
f"set `embedding_dim` to 0.")
model_dir_or_name = None
# 得到cache_path


Loading…
Cancel
Save