|
|
|
@@ -185,19 +185,32 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { |
|
|
|
return node_adjoint; |
|
|
|
} |
|
|
|
|
|
|
|
bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) { |
|
|
|
// Do not care about non-CNode |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
// Do not care about kPrimReturn |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimReturn)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto &users = primal_graph_->manager()->node_users()[node]; |
|
|
|
// Do not care about isolated morphisms |
|
|
|
if (users.empty()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
// Not free if it's used by some node in primal_graph |
|
|
|
bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) { |
|
|
|
auto &user = kv.first; |
|
|
|
return user->func_graph() == primal_graph_; |
|
|
|
}); |
|
|
|
return !nonfree; |
|
|
|
} |
|
|
|
|
|
|
|
void DFunctor::MapFreeMorphism() { |
|
|
|
// Handle cnode not attached to output, that might be refered in other functions. |
|
|
|
for (auto &node : primal_graph_->nodes()) { |
|
|
|
auto adjoint = FindAdjoint(node); |
|
|
|
if (adjoint != nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
MS_LOG(DEBUG) << "MapFreeMorphism noncnode not mapped after MapMorphism " << node->ToString() << " " |
|
|
|
<< node->type_name() << "."; |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimReturn)) { |
|
|
|
if (!IsFreeMorphism(node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << "."; |
|
|
|
@@ -256,9 +269,10 @@ void DFunctor::MapMorphism() { |
|
|
|
// Set stop_gradient before MapMorphism. |
|
|
|
BroadCastStopFlag(); |
|
|
|
|
|
|
|
// Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent |
|
|
|
MapFreeMorphism(); |
|
|
|
// Handle morphism from output. |
|
|
|
(void)MapMorphism(primal_graph_->output()); |
|
|
|
MapFreeMorphism(); |
|
|
|
|
|
|
|
// Construct K for primal_graph_ |
|
|
|
auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output()); |
|
|
|
@@ -298,9 +312,10 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { |
|
|
|
const size_t param_diff = 1; |
|
|
|
if (bprop_graph->output()->isa<CNode>() && |
|
|
|
bprop_graph->output()->cast<CNodePtr>()->size() + param_diff != bprop_graph->parameters().size()) { |
|
|
|
MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope " |
|
|
|
<< primal->output()->scope()->name() |
|
|
|
<< " output must be a tuple and output number should be the same with inputs."; |
|
|
|
// It does not matter with the final tangents, just a tip for debugging |
|
|
|
MS_LOG(DEBUG) << "User defined Cell bprop " << primal->ToString() << " in scope " |
|
|
|
<< primal->output()->scope()->name() |
|
|
|
<< " output must be a tuple and output number should be the same with inputs."; |
|
|
|
} |
|
|
|
resources_->manager()->AddFuncGraph(bprop_graph); |
|
|
|
|
|
|
|
|