You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CGAN.py 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. import jittor as jt
  2. from jittor import nn, transform
  3. import numpy as np
  4. from PIL import Image
  5. import os
  6. import time
  7. import sys
  8. jt.flags.use_cuda = 1
  9. # hyper params
  10. epochs = 100
  11. batch_size = 64
  12. lr = 0.0001
  13. latent_dim = 100
  14. # ref https://github.com/arturml/mnist-cgan/blob/master/mnist-cgan.ipynb
  15. # generator
  16. class Generator(nn.Module):
  17. def __init__(self):
  18. super(Generator, self).__init__()
  19. self.label_emb = nn.Embedding(10, 10)
  20. self.model = nn.Sequential(
  21. nn.Linear(110, 256),
  22. nn.LeakyReLU(0.2),
  23. nn.Linear(256, 512),
  24. nn.LeakyReLU(0.2),
  25. nn.Linear(512, 1024),
  26. nn.LeakyReLU(0.2),
  27. nn.Linear(1024, 784),
  28. nn.Tanh()
  29. )
  30. def execute(self, noise, labels):
  31. labels = self.label_emb(labels)
  32. x = jt.concat([noise, labels], 1)
  33. out = self.model(x)
  34. return out.view(out.shape[0], 1, 28, 28)
  35. class Discriminator(nn.Module):
  36. def __init__(self):
  37. super(Discriminator, self).__init__()
  38. self.label_emb = nn.Embedding(10, 10)
  39. self.model = nn.Sequential(
  40. nn.Linear(794, 1024),
  41. nn.LeakyReLU(0.2),
  42. nn.Linear(1024, 512),
  43. nn.LeakyReLU(0.2),
  44. nn.Linear(512, 256),
  45. nn.LeakyReLU(0.2),
  46. nn.Linear(256, 1),
  47. nn.Sigmoid()
  48. )
  49. def execute(self, images, labels):
  50. images = images.view(images.shape[0], -1)
  51. labels = self.label_emb(labels)
  52. x = jt.concat([images, labels], 1)
  53. out = self.model(x)
  54. return out
  55. generator = Generator()
  56. discriminator = Discriminator()
  57. trans = transform.Compose([
  58. transform.Gray(),
  59. transform.ImageNormalize([0.5], [0.5])
  60. ])
  61. data = jt.dataset.MNIST(train=True, transform=trans).set_attrs(batch_size=batch_size, shuffle=True)
  62. loss = nn.BCELoss()
  63. generator_opt = jt.optim.Adam(generator.parameters(), lr)
  64. discriminator_opt = jt.optim.Adam(discriminator.parameters(), lr)
  65. saved_path = 'saved'
  66. if not os.path.exists(saved_path):
  67. os.mkdir(saved_path)
  68. def save_image(filename):
  69. noise = jt.array(np.random.randn(10, 100)).float32().stop_grad()
  70. labels = jt.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
  71. images = generator(noise, labels).numpy()
  72. image_all = []
  73. for i in range(10):
  74. image_all.append(images[i])
  75. image_all = np.concatenate(image_all, 1)
  76. image_all = (image_all + 1) / 2 * 255
  77. image_all = image_all.transpose(1, 2, 0)
  78. image_all = image_all[:, :, 0]
  79. Image.fromarray(np.uint8(image_all)).save(filename)
  80. def train():
  81. generator.train()
  82. discriminator.train()
  83. for epoch in range(epochs):
  84. start_time = time.time()
  85. for i, (images, labels) in enumerate(data):
  86. batch_size = images.shape[0]
  87. valid = jt.ones([batch_size, 1]).float32().stop_grad()
  88. fake = jt.zeros([batch_size, 1]).float32().stop_grad()
  89. # train generator
  90. noise = jt.array(np.random.randn(batch_size, 100))
  91. fake_labels = jt.array(np.random.randint(0, 10, batch_size))
  92. fake_images = generator(noise, fake_labels)
  93. validity = discriminator(fake_images, fake_labels)
  94. generator_loss = loss(validity, valid)
  95. generator_loss.sync()
  96. generator_opt.step(generator_loss)
  97. # train discriminator
  98. # real images
  99. real_images = jt.array(images)
  100. labels = jt.array(labels)
  101. real_validity = discriminator(real_images, labels)
  102. real_loss = loss(real_validity, valid)
  103. # fake images
  104. noise = jt.array(np.random.randn(batch_size, 100))
  105. fake_validity = discriminator(fake_images.stop_grad(), fake_labels)
  106. fake_loss = loss(fake_validity, fake)
  107. discriminator_loss = (real_loss + fake_loss) / 2
  108. discriminator_loss.sync()
  109. discriminator_opt.step(discriminator_loss)
  110. if i % 50 == 0:
  111. print('Epoch %d Batch %d Loss g %f d %f' % (epoch, i, generator_loss.data, discriminator_loss.data))
  112. print('Epoch %d done Time %s' % (epoch, time.time() - start_time))
  113. generator.save('saved/generator_%d.pkl' % epoch)
  114. discriminator.save('saved/discriminator_%d.pkl' % epoch)
  115. save_image('saved/gen_%d.png' % epoch)
  116. def generate(numbers, epoch, filename):
  117. generator_g = Generator()
  118. generator_g.eval()
  119. generator_g.load('saved/generator_%d.pkl' % epoch)
  120. num_labels = len(numbers)
  121. noise = jt.array(np.random.randn(num_labels, 100)).float32().stop_grad()
  122. labels = jt.array(numbers).float32().stop_grad()
  123. images = generator_g(noise, labels).numpy()
  124. image_all = []
  125. for i in range(num_labels):
  126. image_all.append(images[i])
  127. image_all = np.concatenate(image_all, 2)
  128. image_all = (image_all + 1) / 2 * 255
  129. image_all = image_all.transpose(1, 2, 0)
  130. image_all = image_all[:, :, 0]
  131. Image.fromarray(np.uint8(image_all)).save(filename)
  132. if __name__ == '__main__':
  133. if len(sys.argv) < 2:
  134. print('usage: \npython CGAN.py train\npython CGAN.py eval [output file] [epoch]')
  135. exit(1)
  136. if sys.argv[1] == 'train':
  137. train()
  138. elif sys.argv[1] == 'eval':
  139. generate([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], int(sys.argv[3]), sys.argv[2])
  140. else:
  141. print('usage: \npython CGAN.py train\npython CGAN.py eval [output file] [epoch]')
  142. exit(1)

第二届计图人工智能挑战赛 华中科技大学“武汉大学”队 热身赛项目,使用了Jittor框架实现的Conditional GAN完成手写数字生成题目

Contributors (1)