You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ZeroFractionTest.cs 2.9 kB

(Line-number gutter: lines 1–87.)
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Microsoft.VisualStudio.TestTools.UnitTesting;
  6. using NumSharp;
  7. using Tensorflow;
  8. namespace TensorFlowNET.UnitTest.nn_test
  9. {
  /// <summary>
  /// Tests for <c>nn_impl.zero_fraction</c>, written in the Python-test style of
  /// <see cref="PythonTest"/> (presumably ported from TensorFlow's Python nn_test — confirm).
  /// Every test is currently ignored until <c>nn_impl.zero_fraction</c> is implemented.
  /// </summary>
  [TestClass]
  public class ZeroFractionTest : PythonTest
  {
      /// <summary>
      /// Element count used by the large-tensor tests: int(2**27 * 1.01), i.e. slightly
      /// more than 2^27. The exponentiation must happen before the multiplication;
      /// Math.Pow(2, 27 * 1.01) would instead produce roughly 161 million elements.
      /// </summary>
      private static readonly int LargeTensorSize = (int)(Math.Pow(2, 27) * 1.01);

      /// <summary>
      /// Reference (host-side) implementation of zero fraction: the share of elements of
      /// <paramref name="x"/> whose absolute value is at most 1e-8.
      /// </summary>
      /// <param name="x">Input array; converted to float64 before counting.</param>
      /// <returns>1 - (number of elements with |value| &gt; 1e-8) / (total elements).</returns>
      protected double _ZeroFraction(NDArray x)
      {
          // Mirrors Python's `assert x.shape`; presumably a PythonTest helper — confirm its semantics.
          assert(x.shape);
          int total_elements = np.prod(x.shape);
          const double eps = 1e-8;
          // Convert to float64 first so Data<double>() matches the storage dtype even for
          // int or float32 inputs (testZeroFraction passes float32).
          var nonzeros = x.astype(np.float64).Data<double>().Count(d => Math.Abs(d) > eps);
          return 1.0 - nonzeros / (double)total_elements;
      }

      [Ignore("TODO implement nn_impl.zero_fraction")]
      [TestMethod]
      public void testZeroFraction()
      {
          var x_shape = new Shape(5, 17);
          // astype returns a new array, so its result must be kept; otherwise the test
          // silently runs on the integer array produced by randint.
          var x_np = new NumPyRandom().randint(0, 2, x_shape).astype(np.float32);
          var y_np = this._ZeroFraction(x_np);
          var x_tf = constant_op.constant(x_np);
          x_tf.setShape(x_shape);
          var y_tf = nn_impl.zero_fraction(x_tf);
          var y_tf_np = self.evaluate<NDArray>(y_tf);
          var eps = 1e-8;
          self.assertAllClose(y_tf_np, y_np, eps);
      }

      [Ignore("TODO implement nn_impl.zero_fraction")]
      [TestMethod]
      public void testZeroFractionEmpty()
      {
          // The zero fraction of an empty tensor is expected to be NaN (0 / 0).
          var x = np.zeros(0);
          var y = self.evaluate<NDArray>(nn_impl.zero_fraction(new Tensor(x)));
          self.assertTrue(np.isnan(y));
      }

      [Ignore("TODO implement nn_impl.zero_fraction")]
      [TestMethod]
      public void testZeroFraction2_27Zeros()
      {
          var sparsity = nn_impl.zero_fraction(
              array_ops.zeros(new Shape(LargeTensorSize), dtypes.int8));
          self.assertAllClose(1.0, self.evaluate<NDArray>(sparsity));
      }

      [Ignore("TODO implement nn_impl.zero_fraction")]
      [TestMethod]
      public void testZeroFraction2_27Ones()
      {
          var sparsity = nn_impl.zero_fraction(
              array_ops.ones(new Shape(LargeTensorSize), dtypes.int8));
          self.assertAllClose(0.0, self.evaluate<NDArray>(sparsity));
      }

      [Ignore("TODO implement nn_impl.zero_fraction")]
      [TestMethod]
      public void testUnknownSize()
      {
          // Placeholder without a static shape: zero_fraction must work when the
          // element count is only known at run time.
          var value = array_ops.placeholder(dtype: dtypes.float32);
          var sparsity = nn_impl.zero_fraction(value);
          with<Session>(self.cached_session(), sess => {
              // TODO: port the feed-dict run below; it does not compile yet.
              // Python reference:
              //   self.assertAllClose(
              //       0.25,
              //       sess.run(sparsity, {value: [[0., 1.], [0.3, 2.]]}))
          });
      }
  }
  74. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。