Browse Source

Update densenet

master
mbnacwlh3 4 years ago
parent
commit
9683231fb6
1 changed files with 13 additions and 8 deletions
  1. +13
    -8
      densenet

+ 13
- 8
densenet View File

@@ -34,6 +34,7 @@ class BottleNeck(tl.layers.Module):
y = tf.keras.layers.concatenate([x, y], axis=-1)
return y
# 构建密集块
class DenseBlock(tl.layers.Module):
def __init__(self, num_layers, growth_rate, drop_rate=0.5):
super(DenseBlock, self).__init__()
@@ -49,7 +50,7 @@ class DenseBlock(tl.layers.Module):
x = layer(x)
return x
# 构建过渡层
class TransitionLayer(tl.layers.Module):
def __init__(self, out_channels):
super(TransitionLayer, self).__init__()
@@ -69,9 +70,10 @@ class TransitionLayer(tl.layers.Module):
x = self.pool(x)
return x
class DenseNet_121(tl.layers.Module):
# DenseNet-121,169,201,264模型
class DenseNet(tl.layers.Module):
def __init__(self, num_init_features, growth_rate, block_layers, compression_rate, drop_rate):
super(DenseNet_121, self).__init__()
super(DenseNet, self).__init__()
self.conv = tl.layers.Conv2d(n_filter=num_init_features,
filter_size=(7, 7),
strides=(2,2),
@@ -97,6 +99,7 @@ class DenseNet_121(tl.layers.Module):
self.avgpool = tl.layers.GlobalMeanPool2d()
self.fc = tl.layers.Dense(n_units=10,act=tl.softmax(logits=()))
def forward(self, inputs):
x = self.conv(inputs)
x = self.bn(x)
@@ -116,7 +119,7 @@ class DenseNet_121(tl.layers.Module):
return x
# DenseNet-100模型
class DenseNet_100(tl.layers.Module):
def __init__(self, num_init_features, growth_rate, block_layers, compression_rate, drop_rate):
super(DenseNet_100, self).__init__()
@@ -145,6 +148,7 @@ class DenseNet_100(tl.layers.Module):
self.avgpool = tl.layers.GlobalMeanPool2d()
self.fc = tl.layers.Dense(n_units=10,act=tl.softmax(logits=()))
def forward(self, inputs):
x = self.conv(inputs)
x = self.bn(x)
@@ -167,17 +171,18 @@ class DenseNet_100(tl.layers.Module):
def densenet(x):
if x == 'densenet-121':
return DenseNet_121(num_init_features=64, growth_rate=32, block_layers=[6, 12, 24, 16], compression_rate=0.5,
return DenseNet(num_init_features=64, growth_rate=32, block_layers=[6, 12, 24, 16], compression_rate=0.5,
drop_rate=0.5)
elif x == 'densenet-169':
return DenseNet_121(num_init_features=64, growth_rate=32, block_layers=[6 , 12, 32, 32], compression_rate=0.5,
return DenseNet(num_init_features=64, growth_rate=32, block_layers=[6 , 12, 32, 32], compression_rate=0.5,
drop_rate=0.5)
elif x == 'densenet-201':
return DenseNet_121(num_init_features=64, growth_rate=32, block_layers=[6, 12, 48, 32], compression_rate=0.5,
return DenseNet(num_init_features=64, growth_rate=32, block_layers=[6, 12, 48, 32], compression_rate=0.5,
drop_rate=0.5)
elif x == 'densenet-264':
return DenseNet_121(num_init_features=64, growth_rate=32, block_layers=[6, 12, 64, 48], compression_rate=0.5,
return DenseNet(num_init_features=64, growth_rate=32, block_layers=[6, 12, 64, 48], compression_rate=0.5,
drop_rate=0.5)
elif x=='densenet-100':
return DenseNet_100(num_init_features=64, growth_rate=12, block_layers=[16, 16, 16], compression_rate=0.5,
drop_rate=0.5)

Loading…
Cancel
Save