#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Unit tests for the ``Stack`` and ``UnStack`` layers.

Each test class builds its layer two ways and checks the resulting shapes:

* ``n1`` -- layers called directly on an ``Input`` tensor;
* ``n2`` -- the same layers wrapped in a ``tl.layers.Module`` subclass.
"""

import os
import unittest

# Must be set before TensorFlow (pulled in by tensorlayer) is imported,
# otherwise the log-level setting has no effect on its C++ logging.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorlayer as tl
from tensorlayer.layers import *

from tests.utils import CustomTestCase


class Layer_Stack_Test(CustomTestCase):
    """Checks that ``Stack(axis=1)`` joins two (batch, 5) outputs into (batch, 2, 5)."""

    @classmethod
    def setUpClass(cls):
        """Build the stacked outputs ``n1`` (direct calls) and ``n2`` (Module)."""
        print("-" * 20, "Layer_Stack_Test", "-" * 20)
        cls.batch_size = 4
        cls.inputs_shape = [cls.batch_size, 10]

        cls.ni = Input(cls.inputs_shape, name='input_layer')

        class model(tl.layers.Module):
            """Two parallel Dense(5) branches whose outputs are stacked on axis 1."""

            def __init__(self):
                super(model, self).__init__()
                self.a = Dense(n_units=5)
                self.b = Dense(n_units=5)
                self.stack = Stack(axis=1)

            def forward(self, inputs):
                output1 = self.a(inputs)
                output2 = self.b(inputs)
                output = self.stack([output1, output2])
                return output

        # n1: the same two-branch graph, built by calling layers directly.
        a = Dense(n_units=5)(cls.ni)
        b = Dense(n_units=5)(cls.ni)
        cls.layer1 = Stack(axis=1)
        cls.n1 = cls.layer1([a, b])

        # n2: the Module-based equivalent.
        net = model()
        net.set_train()
        cls.inputs = Input(cls.inputs_shape)
        cls.n2 = net(cls.inputs)

    @classmethod
    def tearDownClass(cls):
        pass

    def test_layer_n1(self):
        """Direct-call stack output has shape (batch, 2, 5)."""
        self.assertEqual(self.n1.shape, (4, 2, 5))

    def test_layer_n2(self):
        """Module stack output has shape (batch, 2, 5)."""
        self.assertEqual(self.n2.shape, (4, 2, 5))


class Layer_UnStack_Test(CustomTestCase):
    """Checks that ``UnStack(axis=1)`` splits a (batch, 5) tensor into 5 (batch,) tensors."""

    @classmethod
    def setUpClass(cls):
        """Build the unstacked outputs ``n1`` (direct calls) and ``n2`` (Module)."""
        print("-" * 20, "Layer_UnStack_Test", "-" * 20)
        cls.batch_size = 4
        cls.inputs_shape = [cls.batch_size, 10]

        # n1: Dense(5) followed by UnStack, built by calling layers directly.
        cls.ni = Input(cls.inputs_shape, name='input_layer')
        a = Dense(n_units=5)(cls.ni)
        cls.layer1 = UnStack(axis=1)
        cls.n1 = cls.layer1(a)

        class model(tl.layers.Module):
            """Dense(5) followed by UnStack on axis 1."""

            def __init__(self):
                super(model, self).__init__()
                self.a = Dense(n_units=5)
                self.unstack = UnStack(axis=1)

            def forward(self, inputs):
                output1 = self.a(inputs)
                output = self.unstack(output1)
                return output

        # n2: the Module-based equivalent.
        cls.inputs = Input(cls.inputs_shape)
        net = model()
        net.set_train()
        cls.n2 = net(cls.inputs)

        print(cls.layer1)

    @classmethod
    def tearDownClass(cls):
        pass

    def test_layer_n1(self):
        """Direct-call unstack yields 5 tensors, each of shape (batch,)."""
        self.assertEqual(len(self.n1), 5)
        for out in self.n1:
            self.assertEqual(out.shape, (self.batch_size, ))

    def test_layer_n2(self):
        """Module unstack yields 5 tensors, each of shape (batch,)."""
        self.assertEqual(len(self.n2), 5)
        # Previously this asserted on self.n1 (copy-paste error), so the
        # Module-built output's element shapes were never verified.
        for out in self.n2:
            self.assertEqual(out.shape, (self.batch_size, ))


if __name__ == '__main__':

    tl.logging.set_verbosity(tl.logging.DEBUG)

    unittest.main()