Browse Source

!1740 domask

From: @dimitri_rose
Reviewed-by: @ji_chen,@wangxiaotian22
Signed-off-by: @ji_chen
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
5b4d07a822
3 changed files with 5 additions and 1 deletions
  1. +1
    -0
      ge/common/types.cc
  2. +3
    -1
      ge/graph/passes/link_gen_mask_nodes_pass.cc
  3. +1
    -0
      inc/framework/common/types.h

+ 1
- 0
ge/common/types.cc View File

@@ -92,6 +92,7 @@ REGISTER_OPTYPE_DEFINE(DROPOUTGENMASK, "DropOutGenMask");
REGISTER_OPTYPE_DEFINE(DROPOUTDOMASK, "DropOutDoMask");
REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3, "DropOutDoMaskV3");
REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3D, "DropOutDoMaskV3D");
REGISTER_OPTYPE_DEFINE(SOFTMAXV2WITHDROPOUTDOMASKV3D, "SoftmaxV2WithDropOutDoMaskV3D");
REGISTER_OPTYPE_DEFINE(CONCAT, "Concat");
REGISTER_OPTYPE_DEFINE(ROIPOOLING, "ROIPooling");
REGISTER_OPTYPE_DEFINE(PROPOSAL, "Proposal");


+ 3
- 1
ge/graph/passes/link_gen_mask_nodes_pass.cc View File

@@ -96,7 +96,9 @@ bool LinkGenMaskNodesPass::AreAllInputsConst(const NodePtr &node) const {
void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<NodePtr> &gen_mask_nodes) const {
set<NodePtr> nodes_set;
for (const NodePtr &node : graph->GetDirectNode()) {
if (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && node->GetType() != DROPOUTDOMASKV3D) {
bool not_domask = node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 &&
node->GetType() != DROPOUTDOMASKV3D && node->GetType() != SOFTMAXV2WITHDROPOUTDOMASKV3D;
if (not_domask) {
continue;
}



+ 1
- 0
inc/framework/common/types.h View File

@@ -132,6 +132,7 @@ REGISTER_OPTYPE_DECLARE(DROPOUT, "Dropout");
REGISTER_OPTYPE_DECLARE(DROPOUTDOMASK, "DropOutDoMask");
REGISTER_OPTYPE_DECLARE(DROPOUTDOMASKV3, "DropOutDoMaskV3");
REGISTER_OPTYPE_DECLARE(DROPOUTDOMASKV3D, "DropOutDoMaskV3D");
REGISTER_OPTYPE_DECLARE(SOFTMAXV2WITHDROPOUTDOMASKV3D, "SoftmaxV2WithDropOutDoMaskV3D");
REGISTER_OPTYPE_DECLARE(DROPOUTGENMASK, "DropOutGenMask");
REGISTER_OPTYPE_DECLARE(CONCAT, "Concat");
REGISTER_OPTYPE_DECLARE(ROIPOOLING, "ROIPooling");


Loading…
Cancel
Save