From 9683231fb6f4574eede38ce5568cd68b32a03638 Mon Sep 17 00:00:00 2001 From: mbnacwlh3 Date: Wed, 13 Oct 2021 22:16:13 +0800 Subject: [PATCH] Update densenet --- densenet | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/densenet b/densenet index 80a6648..e0d63ec 100644 --- a/densenet +++ b/densenet @@ -34,6 +34,7 @@ class BottleNeck(tl.layers.Module): y = tf.keras.layers.concatenate([x, y], axis=-1) return y +# 构建密集块 class DenseBlock(tl.layers.Module): def __init__(self, num_layers, growth_rate, drop_rate=0.5): super(DenseBlock, self).__init__() @@ -49,7 +50,7 @@ class DenseBlock(tl.layers.Module): x = layer(x) return x - +# 构建过渡层 class TransitionLayer(tl.layers.Module): def __init__(self, out_channels): super(TransitionLayer, self).__init__() @@ -69,9 +70,10 @@ class TransitionLayer(tl.layers.Module): x = self.pool(x) return x -class DenseNet_121(tl.layers.Module): +# DenseNet-121,169,201,264模型 +class DenseNet(tl.layers.Module): def __init__(self, num_init_features, growth_rate, block_layers, compression_rate, drop_rate): - super(DenseNet_121, self).__init__() + super(DenseNet, self).__init__() self.conv = tl.layers.Conv2d(n_filter=num_init_features, filter_size=(7, 7), strides=(2,2), @@ -97,6 +99,7 @@ class DenseNet_121(tl.layers.Module): self.avgpool = tl.layers.GlobalMeanPool2d() self.fc = tl.layers.Dense(n_units=10,act=tl.softmax(logits=())) + def forward(self, inputs): x = self.conv(inputs) x = self.bn(x) @@ -116,7 +119,7 @@ class DenseNet_121(tl.layers.Module): return x - +# DenseNet-100模型 class DenseNet_100(tl.layers.Module): def __init__(self, num_init_features, growth_rate, block_layers, compression_rate, drop_rate): super(DenseNet_100, self).__init__() @@ -145,6 +148,7 @@ class DenseNet_100(tl.layers.Module): self.avgpool = tl.layers.GlobalMeanPool2d() self.fc = tl.layers.Dense(n_units=10,act=tl.softmax(logits=())) + def forward(self, inputs): x = self.conv(inputs) x = self.bn(x) @@ -167,17 +171,18 @@ class DenseNet_100(tl.layers.Module): def densenet(x): if x == 'densenet-121': - return DenseNet_121(num_init_features=64, growth_rate=32, block_layers=[6, 12, 24, 16], compression_rate=0.5, + return DenseNet(num_init_features=64, growth_rate=32, block_layers=[6, 12, 24, 16], compression_rate=0.5, drop_rate=0.5) elif x == 'densenet-169': - return DenseNet_121(num_init_features=64, growth_rate=32, block_layers=[6 , 12, 32, 32], compression_rate=0.5, + return DenseNet(num_init_features=64, growth_rate=32, block_layers=[6 , 12, 32, 32], compression_rate=0.5, drop_rate=0.5) elif x == 'densenet-201': - return DenseNet_121(num_init_features=64, growth_rate=32, block_layers=[6, 12, 48, 32], compression_rate=0.5, + return DenseNet(num_init_features=64, growth_rate=32, block_layers=[6, 12, 48, 32], compression_rate=0.5, drop_rate=0.5) elif x == 'densenet-264': - return DenseNet_121(num_init_features=64, growth_rate=32, block_layers=[6, 12, 64, 48], compression_rate=0.5, + return DenseNet(num_init_features=64, growth_rate=32, block_layers=[6, 12, 64, 48], compression_rate=0.5, drop_rate=0.5) elif x=='densenet-100': return DenseNet_100(num_init_features=64, growth_rate=12, block_layers=[16, 16, 16], compression_rate=0.5, drop_rate=0.5) +