|
|
@@ -99,6 +99,20 @@ class LeNet(nn.Cell): # nn.Cell, 定义神经网络必须继承的模块, Mind |
|
|
|
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
|
|
|
|
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
|
|
|
|
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
|
|
|
|
def construct(self, x):
|
|
|
|
x = self.conv1(x)
|
|
|
|
x = self.relu(x)
|
|
|
|
x = self.max_pool2d(x)
|
|
|
|
x = self.conv2(x)
|
|
|
|
x = self.relu(x)
|
|
|
|
x = self.max_pool2d(x)
|
|
|
|
if not self.include_top:
|
|
|
|
return x
|
|
|
|
x = self.flatten(x)
|
|
|
|
x = self.relu(self.fc1(x))
|
|
|
|
x = self.relu(self.fc2(x))
|
|
|
|
x = self.fc3(x)
|
|
|
|
return x
|
|
|
|
|
|
|
|
print("-----猫子说一切正常1-----")
|
|
|
|
# 设计前向传播算法
|
|
|
|