You can not select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

CGAN.py 8.7 kB

(line-number gutter from the repository code viewer, lines 1-213, omitted)
# CGAN on MNIST, implemented with the Jittor deep-learning framework.
import jittor as jt
from jittor import init
import argparse
import os
import numpy as np
import math
from jittor import nn

# Enable CUDA execution when a GPU is available.
if jt.has_cuda:
    jt.flags.use_cuda = 1
# Command-line configuration for CGAN training.
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--n_classes', type=int, default=10, help='number of classes for dataset')
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
parser.add_argument('--sample_interval', type=int, default=1000, help='interval between image sampling')
opt = parser.parse_args()
print(opt)

# Shape of one image tensor: (channels, height, width).
img_shape = (opt.channels, opt.img_size, opt.img_size)
  25. class Generator(nn.Module):
  26. def __init__(self):
  27. super(Generator, self).__init__()
  28. self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
  29. # nn.Linear(in_dim, out_dim)表示全连接层
  30. # in_dim:输入向量维度
  31. # out_dim:输出向量维度
  32. def block(in_feat, out_feat, normalize=True):
  33. layers = [nn.Linear(in_feat, out_feat)]
  34. if normalize:
  35. layers.append(nn.BatchNorm1d(out_feat, 0.8))
  36. layers.append(nn.LeakyReLU(0.2))
  37. return layers
  38. self.model = nn.Sequential(*block((opt.latent_dim + opt.n_classes), 128, normalize=False),
  39. *block(128, 256),
  40. *block(256, 512),
  41. *block(512, 1024),
  42. nn.Linear(1024, int(np.prod(img_shape))),
  43. nn.Tanh())
  44. def execute(self, noise, labels):
  45. gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
  46. img = self.model(gen_input)
  47. # 将img从1024维向量变为32*32矩阵
  48. img = img.view((img.shape[0], *img_shape))
  49. return img
  50. class Discriminator(nn.Module):
  51. def __init__(self):
  52. super(Discriminator, self).__init__()
  53. self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
  54. self.model = nn.Sequential(nn.Linear((opt.n_classes + int(np.prod(img_shape))), 512),
  55. nn.LeakyReLU(0.2),
  56. nn.Linear(512, 512),
  57. nn.Dropout(0.4),
  58. nn.LeakyReLU(0.2),
  59. nn.Linear(512, 512),
  60. nn.Dropout(0.4),
  61. nn.LeakyReLU(0.2),
  62. # TODO: 添加最后一个线性层,最终输出为一个实数
  63. )
  64. def execute(self, img, labels):
  65. d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)
  66. validity = self.model(d_in)
  67. return validity
  68. # TODO: 将d_in输入到模型中并返回计算结果
  69. # 损失函数:平方误差
  70. # 调用方法:adversarial_loss(网络输出A, 分类标签B)
  71. # 计算结果:(A-B)^2
  72. adversarial_loss = nn.MSELoss()
  73. generator = Generator()
  74. discriminator = Discriminator()
  75. # 导入MNIST数据集
  76. from jittor.dataset.mnist import MNIST
  77. import jittor.transform as transform
  78. transform = transform.Compose([
  79. transform.Resize(opt.img_size),
  80. transform.Gray(),
  81. transform.ImageNormalize(mean=[0.5], std=[0.5]),
  82. ])
  83. dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)
  84. optimizer_G = nn.Adan(generator.parameters(), lr=opt.lr, betas=(0.98, 0.99, 0.99))
  85. optimizer_D = nn.AdamW(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
  86. from PIL import Image
  87. def save_image(img, path, nrow=10, padding=5):
  88. N,C,W,H = img.shape
  89. if (N%nrow!=0):
  90. print("N%nrow!=0")
  91. return
  92. ncol=int(N/nrow)
  93. img_all = []
  94. for i in range(ncol):
  95. img_ = []
  96. for j in range(nrow):
  97. img_.append(img[i*nrow+j])
  98. img_.append(np.zeros((C,W,padding)))
  99. img_all.append(np.concatenate(img_, 2))
  100. img_all.append(np.zeros((C,padding,img_all[0].shape[2])))
  101. img = np.concatenate(img_all, 1)
  102. img = np.concatenate([np.zeros((C,padding,img.shape[2])), img], 1)
  103. img = np.concatenate([np.zeros((C,img.shape[1],padding)), img], 2)
  104. min_=img.min()
  105. max_=img.max()
  106. img=(img-min_)/(max_-min_)*255
  107. img=img.transpose((1,2,0))
  108. if C==3:
  109. img = img[:,:,::-1]
  110. elif C==1:
  111. img = img[:,:,0]
  112. Image.fromarray(np.uint8(img)).save(path)
def sample_image(n_row, batches_done):
    """Sample random inputs and save an n_row x n_row grid of generated digits.

    Each row cycles through the labels 0..n_row-1, so every column shows
    one digit class. The grid is written to '<batches_done>.png'.
    """
    z = jt.array(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))).float32().stop_grad()
    labels = jt.array(np.array([num for _ in range(n_row) for num in range(n_row)])).float32().stop_grad()
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.numpy(), "%d.png" % batches_done, nrow=n_row)
# ----------
#  Training
# ----------
for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        batch_size = imgs.shape[0]
        # Target labels: valid=1 marks real images, fake=0 generated ones.
        valid = jt.ones([batch_size, 1]).float32().stop_grad()
        fake = jt.zeros([batch_size, 1]).float32().stop_grad()
        # Real images and their class labels from the current batch.
        real_imgs = jt.array(imgs)
        labels = jt.array(labels)

        # -----------------
        #  Train Generator
        # -----------------
        # Sample random noise and random digit classes as generator input.
        z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32()
        gen_labels = jt.array(np.random.randint(0, opt.n_classes, batch_size)).float32()
        # Generate a batch of images.
        gen_imgs = generator(z, gen_labels)
        # Generator loss measures how well it fools the discriminator,
        # i.e. how close the score on fakes is to the "valid" target.
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)
        g_loss.sync()
        optimizer_G.step(g_loss)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # Loss on real images against the "valid" target.
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)
        # stop_grad keeps the generator out of the discriminator update.
        validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)
        # Total discriminator loss: mean of the real and fake terms.
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.sync()
        optimizer_D.step(d_loss)

        if i % 50 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data)
            )
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)

    # Periodic checkpoint of both networks.
    if epoch % 10 == 0:
        generator.save("generator_last.pkl")
        discriminator.save("discriminator_last.pkl")
# Inference: reload the last checkpoints and render a fixed digit string.
generator.eval()
discriminator.eval()
generator.load('generator_last.pkl')
discriminator.load('discriminator_last.pkl')

# The digit sequence to render (a string of decimal digits, one image each).
number = '20145792009834'
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, opt.latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z,labels)
# Lay the generated digits out side by side: (N, C, H, W) -> (H, N*W).
img_array = gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1))
# Rescale to [0, 255] for saving as an 8-bit image.
min_=img_array.min()
max_=img_array.max()
img_array=(img_array-min_)/(max_-min_)*255
Image.fromarray(np.uint8(img_array)).save("result.png")

将在数字图片数据集 MNIST 上训练 Conditional GAN(Conditional generative adversarial nets)模型,通过输入一个随机向量 z 和额外的辅助信息 y (如类别标签),生成特定数字的图像。

Contributors (1)