You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GpuLeakByCNN.cs 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow.Keras.Layers;
  5. using Tensorflow.NumPy;
  6. using Tensorflow.Keras;
  7. using static Tensorflow.Binding;
  8. using static Tensorflow.KerasApi;
  9. using BenchmarkDotNet.Attributes;
  10. namespace Tensorflow.Benchmark.Leak
  11. {
  12. public class GpuLeakByCNN
  13. {
  14. protected static LayersApi layers = new LayersApi();
  15. [Benchmark]
  16. public void Run()
  17. {
  18. // tf.debugging.set_log_device_placement(true);
  19. tf.Context.Config.GpuOptions.AllowGrowth = true;
  20. int num = 50, width = 64, height = 64;
  21. // if width = 128, height = 128, the exception occurs faster
  22. var bytes = new byte[num * width * height * 3];
  23. var inputImages = np.array(bytes) / 255.0f;
  24. // inputImages = inputImages.reshape((num, height, width, 3));
  25. bytes = new byte[num];
  26. var outLables = np.array(bytes);
  27. Console.WriteLine("Image.Shape={0}", inputImages.dims);
  28. Console.WriteLine("Label.Shape={0}", outLables.dims);
  29. tf.enable_eager_execution();
  30. var inputs = keras.Input((height, width, 3));
  31. var layer = layers.Conv2D(32, (3, 3), activation: keras.activations.Relu).Apply(inputs);
  32. layer = layers.MaxPooling2D((2, 2)).Apply(layer);
  33. layer = layers.Flatten().Apply(layer);
  34. var outputs = layers.Dense(10).Apply(layer);
  35. var model = keras.Model(inputs, outputs, "gpuleak");
  36. model.summary();
  37. model.compile(loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
  38. optimizer: keras.optimizers.RMSprop(),
  39. metrics: new[] { "accuracy" });
  40. model.fit(inputImages, outLables, batch_size: 32, epochs: 200);
  41. keras.backend.clear_session();
  42. }
  43. }
  44. }