You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

zipf_distribution.h 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. // Copyright 2017 The Abseil Authors.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // https://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef ABSL_RANDOM_ZIPF_DISTRIBUTION_H_
  15. #define ABSL_RANDOM_ZIPF_DISTRIBUTION_H_
  16. #include <cassert>
  17. #include <cmath>
  18. #include <istream>
  19. #include <limits>
  20. #include <ostream>
  21. #include <type_traits>
  22. #include "absl/random/internal/iostream_state_saver.h"
  23. #include "absl/random/internal/traits.h"
  24. #include "absl/random/uniform_real_distribution.h"
  25. namespace absl
  26. {
  27. ABSL_NAMESPACE_BEGIN
  28. // absl::zipf_distribution produces random integer-values in the range [0, k],
  29. // distributed according to the unnormalized discrete probability function:
  30. //
  31. // P(x) = (v + x) ^ -q
  32. //
  33. // The parameter `v` must be greater than 0 and the parameter `q` must be
  34. // greater than 1. If either of these parameters take invalid values then the
  35. // behavior is undefined.
  36. //
  37. // IntType is the result_type generated by the generator. It must be of integral
  38. // type; a static_assert ensures this is the case.
  39. //
  40. // The implementation is based on W.Hormann, G.Derflinger:
  41. //
  42. // "Rejection-Inversion to Generate Variates from Monotone Discrete
  43. // Distributions"
  44. //
  45. // http://eeyore.wu-wien.ac.at/papers/96-04-04.wh-der.ps.gz
  46. //
  47. template<typename IntType = int>
  48. class zipf_distribution
  49. {
  50. public:
  51. using result_type = IntType;
  52. class param_type
  53. {
  54. public:
  55. using distribution_type = zipf_distribution;
  56. // Preconditions: k > 0, v > 0, q > 1
  57. // The precondidtions are validated when NDEBUG is not defined via
  58. // a pair of assert() directives.
  59. // If NDEBUG is defined and either or both of these parameters take invalid
  60. // values, the behavior of the class is undefined.
  61. explicit param_type(result_type k = (std::numeric_limits<IntType>::max)(), double q = 2.0, double v = 1.0);
  62. result_type k() const
  63. {
  64. return k_;
  65. }
  66. double q() const
  67. {
  68. return q_;
  69. }
  70. double v() const
  71. {
  72. return v_;
  73. }
  74. friend bool operator==(const param_type& a, const param_type& b)
  75. {
  76. return a.k_ == b.k_ && a.q_ == b.q_ && a.v_ == b.v_;
  77. }
  78. friend bool operator!=(const param_type& a, const param_type& b)
  79. {
  80. return !(a == b);
  81. }
  82. private:
  83. friend class zipf_distribution;
  84. inline double h(double x) const;
  85. inline double hinv(double x) const;
  86. inline double compute_s() const;
  87. inline double pow_negative_q(double x) const;
  88. // Parameters here are exactly the same as the parameters of Algorithm ZRI
  89. // in the paper.
  90. IntType k_;
  91. double q_;
  92. double v_;
  93. double one_minus_q_; // 1-q
  94. double s_;
  95. double one_minus_q_inv_; // 1 / 1-q
  96. double hxm_; // h(k + 0.5)
  97. double hx0_minus_hxm_; // h(x0) - h(k + 0.5)
  98. static_assert(random_internal::IsIntegral<IntType>::value, "Class-template absl::zipf_distribution<> must be "
  99. "parameterized using an integral type.");
  100. };
  101. zipf_distribution() :
  102. zipf_distribution((std::numeric_limits<IntType>::max)())
  103. {
  104. }
  105. explicit zipf_distribution(result_type k, double q = 2.0, double v = 1.0) :
  106. param_(k, q, v)
  107. {
  108. }
  109. explicit zipf_distribution(const param_type& p) :
  110. param_(p)
  111. {
  112. }
  113. void reset()
  114. {
  115. }
  116. template<typename URBG>
  117. result_type operator()(URBG& g)
  118. { // NOLINT(runtime/references)
  119. return (*this)(g, param_);
  120. }
  121. template<typename URBG>
  122. result_type operator()(URBG& g, // NOLINT(runtime/references)
  123. const param_type& p);
  124. result_type k() const
  125. {
  126. return param_.k();
  127. }
  128. double q() const
  129. {
  130. return param_.q();
  131. }
  132. double v() const
  133. {
  134. return param_.v();
  135. }
  136. param_type param() const
  137. {
  138. return param_;
  139. }
  140. void param(const param_type& p)
  141. {
  142. param_ = p;
  143. }
  144. result_type(min)() const
  145. {
  146. return 0;
  147. }
  148. result_type(max)() const
  149. {
  150. return k();
  151. }
  152. friend bool operator==(const zipf_distribution& a, const zipf_distribution& b)
  153. {
  154. return a.param_ == b.param_;
  155. }
  156. friend bool operator!=(const zipf_distribution& a, const zipf_distribution& b)
  157. {
  158. return a.param_ != b.param_;
  159. }
  160. private:
  161. param_type param_;
  162. };
  163. // --------------------------------------------------------------------------
  164. // Implementation details follow
  165. // --------------------------------------------------------------------------
  166. template<typename IntType>
  167. zipf_distribution<IntType>::param_type::param_type(
  168. typename zipf_distribution<IntType>::result_type k, double q, double v
  169. ) :
  170. k_(k),
  171. q_(q),
  172. v_(v),
  173. one_minus_q_(1 - q)
  174. {
  175. assert(q > 1);
  176. assert(v > 0);
  177. assert(k > 0);
  178. one_minus_q_inv_ = 1 / one_minus_q_;
  179. // Setup for the ZRI algorithm (pg 17 of the paper).
  180. // Compute: h(i max) => h(k + 0.5)
  181. constexpr double kMax = 18446744073709549568.0;
  182. double kd = static_cast<double>(k);
  183. // TODO(absl-team): Determine if this check is needed, and if so, add a test
  184. // that fails for k > kMax
  185. if (kd > kMax)
  186. {
  187. // Ensure that our maximum value is capped to a value which will
  188. // round-trip back through double.
  189. kd = kMax;
  190. }
  191. hxm_ = h(kd + 0.5);
  192. // Compute: h(0)
  193. const bool use_precomputed = (v == 1.0 && q == 2.0);
  194. const double h0x5 = use_precomputed ? (-1.0 / 1.5) // exp(-log(1.5))
  195. :
  196. h(0.5);
  197. const double elogv_q = (v_ == 1.0) ? 1 : pow_negative_q(v_);
  198. // h(0) = h(0.5) - exp(log(v) * -q)
  199. hx0_minus_hxm_ = (h0x5 - elogv_q) - hxm_;
  200. // And s
  201. s_ = use_precomputed ? 0.46153846153846123 : compute_s();
  202. }
  203. template<typename IntType>
  204. double zipf_distribution<IntType>::param_type::h(double x) const
  205. {
  206. // std::exp(one_minus_q_ * std::log(v_ + x)) * one_minus_q_inv_;
  207. x += v_;
  208. return (one_minus_q_ == -1.0) ? (-1.0 / x) // -exp(-log(x))
  209. :
  210. (std::exp(std::log(x) * one_minus_q_) * one_minus_q_inv_);
  211. }
  212. template<typename IntType>
  213. double zipf_distribution<IntType>::param_type::hinv(double x) const
  214. {
  215. // std::exp(one_minus_q_inv_ * std::log(one_minus_q_ * x)) - v_;
  216. return -v_ + ((one_minus_q_ == -1.0) ? (-1.0 / x) // exp(-log(-x))
  217. :
  218. std::exp(one_minus_q_inv_ * std::log(one_minus_q_ * x)));
  219. }
  220. template<typename IntType>
  221. double zipf_distribution<IntType>::param_type::compute_s() const
  222. {
  223. // 1 - hinv(h(1.5) - std::exp(std::log(v_ + 1) * -q_));
  224. return 1.0 - hinv(h(1.5) - pow_negative_q(v_ + 1.0));
  225. }
  226. template<typename IntType>
  227. double zipf_distribution<IntType>::param_type::pow_negative_q(double x) const
  228. {
  229. // std::exp(std::log(x) * -q_);
  230. return q_ == 2.0 ? (1.0 / (x * x)) : std::exp(std::log(x) * -q_);
  231. }
  232. template<typename IntType>
  233. template<typename URBG>
  234. typename zipf_distribution<IntType>::result_type
  235. zipf_distribution<IntType>::operator()(
  236. URBG& g, const param_type& p
  237. )
  238. { // NOLINT(runtime/references)
  239. absl::uniform_real_distribution<double> uniform_double;
  240. double k;
  241. for (;;)
  242. {
  243. const double v = uniform_double(g);
  244. const double u = p.hxm_ + v * p.hx0_minus_hxm_;
  245. const double x = p.hinv(u);
  246. k = rint(x); // std::floor(x + 0.5);
  247. if (k > static_cast<double>(p.k()))
  248. continue; // reject k > max_k
  249. if (k - x <= p.s_)
  250. break;
  251. const double h = p.h(k + 0.5);
  252. const double r = p.pow_negative_q(p.v_ + k);
  253. if (u >= h - r)
  254. break;
  255. }
  256. IntType ki = static_cast<IntType>(k);
  257. assert(ki <= p.k_);
  258. return ki;
  259. }
  260. template<typename CharT, typename Traits, typename IntType>
  261. std::basic_ostream<CharT, Traits>& operator<<(
  262. std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references)
  263. const zipf_distribution<IntType>& x
  264. )
  265. {
  266. using stream_type =
  267. typename random_internal::stream_format_type<IntType>::type;
  268. auto saver = random_internal::make_ostream_state_saver(os);
  269. os.precision(random_internal::stream_precision_helper<double>::kPrecision);
  270. os << static_cast<stream_type>(x.k()) << os.fill() << x.q() << os.fill()
  271. << x.v();
  272. return os;
  273. }
  274. template<typename CharT, typename Traits, typename IntType>
  275. std::basic_istream<CharT, Traits>& operator>>(
  276. std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references)
  277. zipf_distribution<IntType>& x
  278. )
  279. { // NOLINT(runtime/references)
  280. using result_type = typename zipf_distribution<IntType>::result_type;
  281. using param_type = typename zipf_distribution<IntType>::param_type;
  282. using stream_type =
  283. typename random_internal::stream_format_type<IntType>::type;
  284. stream_type k;
  285. double q;
  286. double v;
  287. auto saver = random_internal::make_istream_state_saver(is);
  288. is >> k >> q >> v;
  289. if (!is.fail())
  290. {
  291. x.param(param_type(static_cast<result_type>(k), q, v));
  292. }
  293. return is;
  294. }
  295. ABSL_NAMESPACE_END
  296. } // namespace absl
  297. #endif // ABSL_RANDOM_ZIPF_DISTRIBUTION_H_