| @@ -34,6 +34,7 @@ class BottleNeck(tl.layers.Module): | |||
| y = tf.keras.layers.concatenate([x, y], axis=-1) | |||
| return y | |||
| # 构建密集块 | |||
| class DenseBlock(tl.layers.Module): | |||
| def __init__(self, num_layers, growth_rate, drop_rate=0.5): | |||
| super(DenseBlock, self).__init__() | |||
| @@ -49,7 +50,7 @@ class DenseBlock(tl.layers.Module): | |||
| x = layer(x) | |||
| return x | |||
| # 构建过渡层 | |||
| class TransitionLayer(tl.layers.Module): | |||
| def __init__(self, out_channels): | |||
| super(TransitionLayer, self).__init__() | |||
| @@ -69,9 +70,10 @@ class TransitionLayer(tl.layers.Module): | |||
| x = self.pool(x) | |||
| return x | |||
| class DenseNet_121(tl.layers.Module): | |||
| # DenseNet-121,169,201,264模型 | |||
| class DenseNet(tl.layers.Module): | |||
| def __init__(self, num_init_features, growth_rate, block_layers, compression_rate, drop_rate): | |||
| super(DenseNet_121, self).__init__() | |||
| super(DenseNet, self).__init__() | |||
| self.conv = tl.layers.Conv2d(n_filter=num_init_features, | |||
| filter_size=(7, 7), | |||
| strides=(2,2), | |||
| @@ -97,6 +99,7 @@ class DenseNet_121(tl.layers.Module): | |||
| self.avgpool = tl.layers.GlobalMeanPool2d() | |||
| self.fc = tl.layers.Dense(n_units=10,act=tl.softmax(logits=())) | |||
| def forward(self, inputs): | |||
| x = self.conv(inputs) | |||
| x = self.bn(x) | |||
| @@ -116,7 +119,7 @@ class DenseNet_121(tl.layers.Module): | |||
| return x | |||
| # DenseNet-100模型 | |||
| class DenseNet_100(tl.layers.Module): | |||
| def __init__(self, num_init_features, growth_rate, block_layers, compression_rate, drop_rate): | |||
| super(DenseNet_100, self).__init__() | |||
| @@ -145,6 +148,7 @@ class DenseNet_100(tl.layers.Module): | |||
| self.avgpool = tl.layers.GlobalMeanPool2d() | |||
| self.fc = tl.layers.Dense(n_units=10,act=tl.softmax(logits=())) | |||
| def forward(self, inputs): | |||
| x = self.conv(inputs) | |||
| x = self.bn(x) | |||
| @@ -167,17 +171,18 @@ class DenseNet_100(tl.layers.Module): | |||
| def densenet(x): | |||
| if x == 'densenet-121': | |||
| return DenseNet_121(num_init_features=64, growth_rate=32, block_layers=[6, 12, 24, 16], compression_rate=0.5, | |||
| return DenseNet(num_init_features=64, growth_rate=32, block_layers=[6, 12, 24, 16], compression_rate=0.5, | |||
| drop_rate=0.5) | |||
| elif x == 'densenet-169': | |||
| return DenseNet_121(num_init_features=64, growth_rate=32, block_layers=[6 , 12, 32, 32], compression_rate=0.5, | |||
| return DenseNet(num_init_features=64, growth_rate=32, block_layers=[6 , 12, 32, 32], compression_rate=0.5, | |||
| drop_rate=0.5) | |||
| elif x == 'densenet-201': | |||
| return DenseNet_121(num_init_features=64, growth_rate=32, block_layers=[6, 12, 48, 32], compression_rate=0.5, | |||
| return DenseNet(num_init_features=64, growth_rate=32, block_layers=[6, 12, 48, 32], compression_rate=0.5, | |||
| drop_rate=0.5) | |||
| elif x == 'densenet-264': | |||
| return DenseNet_121(num_init_features=64, growth_rate=32, block_layers=[6, 12, 64, 48], compression_rate=0.5, | |||
| return DenseNet(num_init_features=64, growth_rate=32, block_layers=[6, 12, 64, 48], compression_rate=0.5, | |||
| drop_rate=0.5) | |||
| elif x=='densenet-100': | |||
| return DenseNet_100(num_init_features=64, growth_rate=12, block_layers=[16, 16, 16], compression_rate=0.5, | |||
| drop_rate=0.5) | |||