Browse Source

fix core

master
Eric Lai 4 years ago
parent
commit
c611cc57bd
2 changed files with 34 additions and 6 deletions
  1. +32
    -4
      tensorlayer/layers/core/core_paddle.py
  2. +2
    -2
      tensorlayer/package_info.py

+ 32
- 4
tensorlayer/layers/core/core_paddle.py View File

@@ -2,14 +2,14 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-


import copy, six import copy, six
import tensorlayer as tl
from .common import str2act from .common import str2act
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid.dygraph import Layer from paddle.fluid.dygraph import Layer
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.dygraph.base import program_desc_tracing_guard, param_guard
from paddle.fluid.dygraph import parallel_helper



_global_layer_name_dict = {} # TODO: better implementation?
_global_layer_name_dict = {}




class Module(Layer): class Module(Layer):
@@ -54,7 +54,10 @@ class Module(Layer):
self.act = act self.act = act


# Layer building state # Layer building state
# self._built = False
self._built = False

# paddl_built
self._paddle_built = False


# Layer nodes state # Layer nodes state
self._nodes = [] self._nodes = []
@@ -160,6 +163,31 @@ class Module(Layer):
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
raise Exception("The forward method must be implemented by inherited class") raise Exception("The forward method must be implemented by inherited class")


def __call__(self, *inputs, **kwargs):
with param_guard(self._parameters), param_guard(self._buffers):
for forward_pre_hook in self._forward_pre_hooks.values():
hook_result = forward_pre_hook(self, inputs)
if hook_result is not None:
if not isinstance(hook_result, tuple):
hook_result = (hook_result, )
inputs = hook_result

if not self._paddle_built:
with program_desc_tracing_guard(False):
self._build_once(*inputs, **kwargs)
if parallel_helper._is_data_parallel_mode():
parallel_helper._broadcast_parameters(
self._parameters.values())
self._paddle_built = True

outputs = self.forward(*inputs, **kwargs)

for forward_post_hook in self._forward_post_hooks.values():
hook_result = forward_post_hook(self, inputs, outputs)
if hook_result is not None:
outputs = hook_result

return outputs


def _get_weights(self, var_name, shape, init=None, trainable=True): def _get_weights(self, var_name, shape, init=None, trainable=True):
if var_name in ["filters", "weights"]: if var_name in ["filters", "weights"]:


+ 2
- 2
tensorlayer/package_info.py View File

@@ -2,8 +2,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""Deep learning and Reinforcement learning library for Researchers and Engineers.""" """Deep learning and Reinforcement learning library for Researchers and Engineers."""


MAJOR = 2
MINOR = 2
MAJOR = 3
MINOR = 0
PATCH = 0 PATCH = 0
PRE_RELEASE = '' PRE_RELEASE = ''
# Use the following formatting: (major, minor, patch, prerelease) # Use the following formatting: (major, minor, patch, prerelease)


Loading…
Cancel
Save