Browse Source

fix core

master
Eric Lai 4 years ago
parent
commit
c611cc57bd
2 changed files with 34 additions and 6 deletions
  1. +32
    -4
      tensorlayer/layers/core/core_paddle.py
  2. +2
    -2
      tensorlayer/package_info.py

+ 32
- 4
tensorlayer/layers/core/core_paddle.py View File

@@ -2,14 +2,14 @@
# -*- coding: utf-8 -*-

import copy, six
import tensorlayer as tl
from .common import str2act
from paddle.fluid import framework
from paddle.fluid.dygraph import Layer
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.dygraph.base import program_desc_tracing_guard, param_guard
from paddle.fluid.dygraph import parallel_helper


_global_layer_name_dict = {} # TODO: better implementation?
_global_layer_name_dict = {}


class Module(Layer):
@@ -54,7 +54,10 @@ class Module(Layer):
self.act = act

# Layer building state
# self._built = False
self._built = False

# paddl_built
self._paddle_built = False

# Layer nodes state
self._nodes = []
@@ -160,6 +163,31 @@ class Module(Layer):
def forward(self, *inputs, **kwargs):
raise Exception("The forward method must be implemented by inherited class")

def __call__(self, *inputs, **kwargs):
with param_guard(self._parameters), param_guard(self._buffers):
for forward_pre_hook in self._forward_pre_hooks.values():
hook_result = forward_pre_hook(self, inputs)
if hook_result is not None:
if not isinstance(hook_result, tuple):
hook_result = (hook_result, )
inputs = hook_result

if not self._paddle_built:
with program_desc_tracing_guard(False):
self._build_once(*inputs, **kwargs)
if parallel_helper._is_data_parallel_mode():
parallel_helper._broadcast_parameters(
self._parameters.values())
self._paddle_built = True

outputs = self.forward(*inputs, **kwargs)

for forward_post_hook in self._forward_post_hooks.values():
hook_result = forward_post_hook(self, inputs, outputs)
if hook_result is not None:
outputs = hook_result

return outputs

def _get_weights(self, var_name, shape, init=None, trainable=True):
if var_name in ["filters", "weights"]:


+ 2
- 2
tensorlayer/package_info.py View File

@@ -2,8 +2,8 @@
# -*- coding: utf-8 -*-
"""Deep learning and Reinforcement learning library for Researchers and Engineers."""

MAJOR = 2
MINOR = 2
MAJOR = 3
MINOR = 0
PATCH = 0
PRE_RELEASE = ''
# Use the following formatting: (major, minor, patch, prerelease)


Loading…
Cancel
Save