From c611cc57bd5676522277f0ffebcc7614aba2c4f2 Mon Sep 17 00:00:00 2001
From: Eric Lai
Date: Thu, 13 May 2021 11:50:05 +0800
Subject: [PATCH] fix core

---
 tensorlayer/layers/core/core_paddle.py | 36 +++++++++++++++++++++++---
 tensorlayer/package_info.py            |  4 +--
 2 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/tensorlayer/layers/core/core_paddle.py b/tensorlayer/layers/core/core_paddle.py
index 19b56ee..769053f 100644
--- a/tensorlayer/layers/core/core_paddle.py
+++ b/tensorlayer/layers/core/core_paddle.py
@@ -2,14 +2,14 @@
 # -*- coding: utf-8 -*-
 
 import copy, six
-import tensorlayer as tl
 from .common import str2act
 from paddle.fluid import framework
 from paddle.fluid.dygraph import Layer
 from paddle.fluid.framework import in_dygraph_mode
+from paddle.fluid.dygraph.base import program_desc_tracing_guard, param_guard
+from paddle.fluid.dygraph import parallel_helper
 
-
-_global_layer_name_dict = {}  # TODO: better implementation?
+_global_layer_name_dict = {}
 
 
 class Module(Layer):
@@ -54,7 +54,10 @@ class Module(Layer):
         self.act = act
 
         # Layer building state
-        # self._built = False
+        self._built = False
+
+        # Paddle building state
+        self._paddle_built = False
 
         # Layer nodes state
         self._nodes = []
@@ -160,6 +163,31 @@ class Module(Layer):
     def forward(self, *inputs, **kwargs):
         raise Exception("The forward method must be implemented by inherited class")
 
+    def __call__(self, *inputs, **kwargs):
+        with param_guard(self._parameters), param_guard(self._buffers):
+            for forward_pre_hook in self._forward_pre_hooks.values():
+                hook_result = forward_pre_hook(self, inputs)
+                if hook_result is not None:
+                    if not isinstance(hook_result, tuple):
+                        hook_result = (hook_result, )
+                    inputs = hook_result
+
+            if not self._paddle_built:
+                with program_desc_tracing_guard(False):
+                    self._build_once(*inputs, **kwargs)
+                    if parallel_helper._is_data_parallel_mode():
+                        parallel_helper._broadcast_parameters(
+                            self._parameters.values())
+                self._paddle_built = True
+
+            outputs = self.forward(*inputs, **kwargs)
+
+            for forward_post_hook in self._forward_post_hooks.values():
+                hook_result = forward_post_hook(self, inputs, outputs)
+                if hook_result is not None:
+                    outputs = hook_result
+
+        return outputs
 
     def _get_weights(self, var_name, shape, init=None, trainable=True):
         if var_name in ["filters", "weights"]:
diff --git a/tensorlayer/package_info.py b/tensorlayer/package_info.py
index de5a884..1efbae6 100644
--- a/tensorlayer/package_info.py
+++ b/tensorlayer/package_info.py
@@ -2,8 +2,8 @@
 # -*- coding: utf-8 -*-
 """Deep learning and Reinforcement learning library for Researchers and Engineers."""
 
-MAJOR = 2
-MINOR = 2
+MAJOR = 3
+MINOR = 0
 PATCH = 0
 PRE_RELEASE = ''
 # Use the following formatting: (major, minor, patch, prerelease)
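
Note (editor's sketch, not part of the patch): the new Module.__call__ above wires Paddle's forward-pre hooks, a one-time _build_once() guarded by _paddle_built, the user-defined forward(), and the forward-post hooks into a single call path. The snippet below is a minimal usage sketch under the assumptions that Module is exported as tl.layers.Module, that tl.layers.Dense keeps its usual (n_units, in_channels) constructor, and that Paddle is the active backend; the MyDense class, layer sizes, and input shape are illustrative only and do not come from the commit.

import paddle
import tensorlayer as tl

class MyDense(tl.layers.Module):

    def __init__(self):
        super(MyDense, self).__init__()
        # Hypothetical sub-layer; the exact Dense signature is an assumption.
        self.dense = tl.layers.Dense(n_units=10, in_channels=20)

    def forward(self, x):
        return self.dense(x)

net = MyDense()
x = paddle.randn([4, 20], dtype='float32')
# The first call runs any forward-pre hooks, builds the layer exactly once
# (program_desc_tracing_guard + _build_once, as in the diff), then calls
# forward() and any forward-post hooks; later calls skip the build step.
y = net(x)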