From ed7f7b1cd95a50a28c3e9e3e4db9f0b48efbffe9 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sun, 1 Mar 2020 10:41:48 +0800 Subject: [PATCH] =?UTF-8?q?1.=E4=BC=98=E5=8C=96Trainer=E4=B8=AD=E5=AF=B9ex?= =?UTF-8?q?ception=E7=9A=84=E5=A4=84=E7=90=86=EF=BC=9B2.=E4=BF=AE=E6=94=B9?= =?UTF-8?q?static=5Fembedding=E8=B4=B4=E5=90=88seq2seq?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/trainer.py | 18 +++++++++--------- fastNLP/embeddings/static_embedding.py | 4 ++++ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index a5fea9bf..af68158c 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -617,6 +617,14 @@ class Trainer(object): elif on_exception == 'raise': raise e + if self.dev_data is not None and self.best_dev_perf is not None and load_best_model: + model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) + load_succeed = self._load_model(self.model, model_name) + if load_succeed: + self.logger.info("Reloaded the best model.") + else: + self.logger.info("Fail to reload best model.") + finally: if self.dev_data is not None and self.best_dev_perf is not None: self.logger.info( "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step)) @@ -624,15 +632,7 @@ class Trainer(object): results['best_eval'] = self.best_dev_perf results['best_epoch'] = self.best_dev_epoch results['best_step'] = self.best_dev_step - if load_best_model: - model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) - load_succeed = self._load_model(self.model, model_name) - if load_succeed: - self.logger.info("Reloaded the best model.") - else: - self.logger.info("Fail to reload best model.") - finally: - pass + results['seconds'] = round(time.time() - start_time, 2) return results diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py index fcfb3fec..42d30bac 100644 --- a/fastNLP/embeddings/static_embedding.py +++ b/fastNLP/embeddings/static_embedding.py @@ -76,6 +76,10 @@ class StaticEmbedding(TokenEmbedding): """ super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) if embedding_dim > 0: + if model_dir_or_name is not None: + warnings.warn(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with" + f" dimension {embedding_dim}. If you want to use pre-trained embedding, " + f"set `embedding_dim` to 0.") model_dir_or_name = None # 得到cache_path