|
@@ -2,14 +2,14 @@ |
|
|
# -*- coding: utf-8 -*-

import copy, six
import tensorlayer as tl

from .common import str2act
from paddle.fluid import framework
from paddle.fluid.dygraph import Layer
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.dygraph.base import program_desc_tracing_guard, param_guard
from paddle.fluid.dygraph import parallel_helper

_global_layer_name_dict = {}  # TODO: better implementation?


class Module(Layer):
|
@@ -54,7 +54,10 @@ class Module(Layer): |
|
|
        self.act = act

        # Layer building state
        # self._built = False
        self._built = False
        # paddle_built
        self._paddle_built = False
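        # _paddle_built gates the one-time _build_once() call in __call__
        # below and is kept separate from TensorLayer's _built flag above.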

        # Layer nodes state
        self._nodes = []
|
@@ -160,6 +163,31 @@ class Module(Layer): |
|
|
    def forward(self, *inputs, **kwargs):
        raise Exception("The forward method must be implemented by inherited class")

    def __call__(self, *inputs, **kwargs):
        with param_guard(self._parameters), param_guard(self._buffers):
            # Run any registered forward pre-hooks; a hook may rewrite the inputs.
            for forward_pre_hook in self._forward_pre_hooks.values():
                hook_result = forward_pre_hook(self, inputs)
                if hook_result is not None:
                    if not isinstance(hook_result, tuple):
                        hook_result = (hook_result, )
                    inputs = hook_result

            # Build the layer once, on the first call, without tracing the build
            # ops; in data-parallel mode, broadcast the freshly created parameters.
            if not self._paddle_built:
                with program_desc_tracing_guard(False):
                    self._build_once(*inputs, **kwargs)
                    if parallel_helper._is_data_parallel_mode():
                        parallel_helper._broadcast_parameters(
                            self._parameters.values())
                self._paddle_built = True

            # The actual computation, implemented by the subclass.
            outputs = self.forward(*inputs, **kwargs)

            # Run any registered forward post-hooks; a hook may replace the outputs.
            for forward_post_hook in self._forward_post_hooks.values():
                hook_result = forward_post_hook(self, inputs, outputs)
                if hook_result is not None:
                    outputs = hook_result

            return outputs
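    # Minimal usage sketch (illustrative only, not part of the diffed file):
    # calling a Module instance goes through __call__ above -- pre-hooks,
    # one-time _build_once(), forward(), post-hooks. `Double` is a made-up
    # subclass used purely to show the dispatch.
    #
    #   class Double(Module):
    #       def forward(self, x):
    #           return x * 2
    #
    #   layer = Double()
    #   out = layer(3)   # __call__ -> forward(3) -> 6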

    def _get_weights(self, var_name, shape, init=None, trainable=True):
        if var_name in ["filters", "weights"]:
|
|