You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

distributions.h 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. // Copyright 2017 The Abseil Authors.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // https://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // -----------------------------------------------------------------------------
  16. // File: distributions.h
  17. // -----------------------------------------------------------------------------
  18. //
  19. // This header defines functions representing distributions, which you use in
  20. // combination with an Abseil random bit generator to produce random values
  21. // according to the rules of that distribution.
  22. //
  23. // The Abseil random library defines the following distributions within this
  24. // file:
  25. //
  26. // * `absl::Uniform` for uniform (constant) distributions having constant
  27. // probability
  28. // * `absl::Bernoulli` for discrete distributions having exactly two outcomes
  29. // * `absl::Beta` for continuous distributions parameterized through two
  30. // free parameters
// * `absl::Exponential` for continuous distributions of the time between
// events occurring continuously and independently at a constant average rate
  33. // * `absl::Gaussian` (also known as "normal distributions") for continuous
  34. // distributions using an associated quadratic function
  35. // * `absl::LogUniform` for continuous uniform distributions where the log
  36. // to the given base of all values is uniform
  37. // * `absl::Poisson` for discrete probability distributions that express the
  38. // probability of a given number of events occurring within a fixed interval
  39. // * `absl::Zipf` for discrete probability distributions commonly used for
  40. // modelling of rare events
  41. //
  42. // Prefer use of these distribution function classes over manual construction of
  43. // your own distribution classes, as it allows library maintainers greater
  44. // flexibility to change the underlying implementation in the future.
  45. #ifndef ABSL_RANDOM_DISTRIBUTIONS_H_
  46. #define ABSL_RANDOM_DISTRIBUTIONS_H_
  47. #include <algorithm>
  48. #include <cmath>
  49. #include <limits>
  50. #include <random>
  51. #include <type_traits>
  52. #include "absl/base/internal/inline_variable.h"
  53. #include "absl/random/bernoulli_distribution.h"
  54. #include "absl/random/beta_distribution.h"
  55. #include "absl/random/exponential_distribution.h"
  56. #include "absl/random/gaussian_distribution.h"
  57. #include "absl/random/internal/distribution_caller.h" // IWYU pragma: export
  58. #include "absl/random/internal/uniform_helper.h" // IWYU pragma: export
  59. #include "absl/random/log_uniform_int_distribution.h"
  60. #include "absl/random/poisson_distribution.h"
  61. #include "absl/random/uniform_int_distribution.h"
  62. #include "absl/random/uniform_real_distribution.h"
  63. #include "absl/random/zipf_distribution.h"
  64. namespace absl
  65. {
  66. ABSL_NAMESPACE_BEGIN
  67. ABSL_INTERNAL_INLINE_CONSTEXPR(IntervalClosedClosedTag, IntervalClosedClosed, {});
  68. ABSL_INTERNAL_INLINE_CONSTEXPR(IntervalClosedClosedTag, IntervalClosed, {});
  69. ABSL_INTERNAL_INLINE_CONSTEXPR(IntervalClosedOpenTag, IntervalClosedOpen, {});
  70. ABSL_INTERNAL_INLINE_CONSTEXPR(IntervalOpenOpenTag, IntervalOpenOpen, {});
  71. ABSL_INTERNAL_INLINE_CONSTEXPR(IntervalOpenOpenTag, IntervalOpen, {});
  72. ABSL_INTERNAL_INLINE_CONSTEXPR(IntervalOpenClosedTag, IntervalOpenClosed, {});
  73. // -----------------------------------------------------------------------------
  74. // absl::Uniform<T>(tag, bitgen, lo, hi)
  75. // -----------------------------------------------------------------------------
  76. //
  77. // `absl::Uniform()` produces random values of type `T` uniformly distributed in
  78. // a defined interval {lo, hi}. The interval `tag` defines the type of interval
  79. // which should be one of the following possible values:
  80. //
  81. // * `absl::IntervalOpenOpen`
  82. // * `absl::IntervalOpenClosed`
  83. // * `absl::IntervalClosedOpen`
  84. // * `absl::IntervalClosedClosed`
  85. //
  86. // where "open" refers to an exclusive value (excluded) from the output, while
  87. // "closed" refers to an inclusive value (included) from the output.
  88. //
  89. // In the absence of an explicit return type `T`, `absl::Uniform()` will deduce
  90. // the return type based on the provided endpoint arguments {A lo, B hi}.
  91. // Given these endpoints, one of {A, B} will be chosen as the return type, if
  92. // a type can be implicitly converted into the other in a lossless way. The
  93. // lack of any such implicit conversion between {A, B} will produce a
  94. // compile-time error
  95. //
  96. // See https://en.wikipedia.org/wiki/Uniform_distribution_(continuous)
  97. //
  98. // Example:
  99. //
  100. // absl::BitGen bitgen;
  101. //
  102. // // Produce a random float value between 0.0 and 1.0, inclusive
  103. // auto x = absl::Uniform(absl::IntervalClosedClosed, bitgen, 0.0f, 1.0f);
  104. //
  105. // // The most common interval of `absl::IntervalClosedOpen` is available by
  106. // // default:
  107. //
  108. // auto x = absl::Uniform(bitgen, 0.0f, 1.0f);
  109. //
  110. // // Return-types are typically inferred from the arguments, however callers
  111. // // can optionally provide an explicit return-type to the template.
  112. //
  113. // auto x = absl::Uniform<float>(bitgen, 0, 1);
  114. //
  115. template<typename R = void, typename TagType, typename URBG>
  116. typename absl::enable_if_t<!std::is_same<R, void>::value, R> //
  117. Uniform(TagType tag,
  118. URBG&& urbg, // NOLINT(runtime/references)
  119. R lo,
  120. R hi)
  121. {
  122. using gen_t = absl::decay_t<URBG>;
  123. using distribution_t = random_internal::UniformDistributionWrapper<R>;
  124. auto a = random_internal::uniform_lower_bound(tag, lo, hi);
  125. auto b = random_internal::uniform_upper_bound(tag, lo, hi);
  126. if (!random_internal::is_uniform_range_valid(a, b))
  127. return lo;
  128. return random_internal::DistributionCaller<gen_t>::template Call<
  129. distribution_t>(&urbg, tag, lo, hi);
  130. }
  131. // absl::Uniform<T>(bitgen, lo, hi)
  132. //
  133. // Overload of `Uniform()` using the default closed-open interval of [lo, hi),
  134. // and returning values of type `T`
  135. template<typename R = void, typename URBG>
  136. typename absl::enable_if_t<!std::is_same<R, void>::value, R> //
  137. Uniform(URBG&& urbg, // NOLINT(runtime/references)
  138. R lo,
  139. R hi)
  140. {
  141. using gen_t = absl::decay_t<URBG>;
  142. using distribution_t = random_internal::UniformDistributionWrapper<R>;
  143. constexpr auto tag = absl::IntervalClosedOpen;
  144. auto a = random_internal::uniform_lower_bound(tag, lo, hi);
  145. auto b = random_internal::uniform_upper_bound(tag, lo, hi);
  146. if (!random_internal::is_uniform_range_valid(a, b))
  147. return lo;
  148. return random_internal::DistributionCaller<gen_t>::template Call<
  149. distribution_t>(&urbg, lo, hi);
  150. }
  151. // absl::Uniform(tag, bitgen, lo, hi)
  152. //
  153. // Overload of `Uniform()` using different (but compatible) lo, hi types. Note
  154. // that a compile-error will result if the return type cannot be deduced
  155. // correctly from the passed types.
  156. template<typename R = void, typename TagType, typename URBG, typename A, typename B>
  157. typename absl::enable_if_t<std::is_same<R, void>::value, random_internal::uniform_inferred_return_t<A, B>>
  158. Uniform(TagType tag,
  159. URBG&& urbg, // NOLINT(runtime/references)
  160. A lo,
  161. B hi)
  162. {
  163. using gen_t = absl::decay_t<URBG>;
  164. using return_t = typename random_internal::uniform_inferred_return_t<A, B>;
  165. using distribution_t = random_internal::UniformDistributionWrapper<return_t>;
  166. auto a = random_internal::uniform_lower_bound<return_t>(tag, lo, hi);
  167. auto b = random_internal::uniform_upper_bound<return_t>(tag, lo, hi);
  168. if (!random_internal::is_uniform_range_valid(a, b))
  169. return lo;
  170. return random_internal::DistributionCaller<gen_t>::template Call<
  171. distribution_t>(&urbg, tag, static_cast<return_t>(lo), static_cast<return_t>(hi));
  172. }
  173. // absl::Uniform(bitgen, lo, hi)
  174. //
  175. // Overload of `Uniform()` using different (but compatible) lo, hi types and the
  176. // default closed-open interval of [lo, hi). Note that a compile-error will
  177. // result if the return type cannot be deduced correctly from the passed types.
  178. template<typename R = void, typename URBG, typename A, typename B>
  179. typename absl::enable_if_t<std::is_same<R, void>::value, random_internal::uniform_inferred_return_t<A, B>>
  180. Uniform(URBG&& urbg, // NOLINT(runtime/references)
  181. A lo,
  182. B hi)
  183. {
  184. using gen_t = absl::decay_t<URBG>;
  185. using return_t = typename random_internal::uniform_inferred_return_t<A, B>;
  186. using distribution_t = random_internal::UniformDistributionWrapper<return_t>;
  187. constexpr auto tag = absl::IntervalClosedOpen;
  188. auto a = random_internal::uniform_lower_bound<return_t>(tag, lo, hi);
  189. auto b = random_internal::uniform_upper_bound<return_t>(tag, lo, hi);
  190. if (!random_internal::is_uniform_range_valid(a, b))
  191. return lo;
  192. return random_internal::DistributionCaller<gen_t>::template Call<
  193. distribution_t>(&urbg, static_cast<return_t>(lo), static_cast<return_t>(hi));
  194. }
  195. // absl::Uniform<unsigned T>(bitgen)
  196. //
  197. // Overload of Uniform() using the minimum and maximum values of a given type
  198. // `T` (which must be unsigned), returning a value of type `unsigned T`
  199. template<typename R, typename URBG>
  200. typename absl::enable_if_t<!std::is_signed<R>::value, R> //
  201. Uniform(URBG&& urbg)
  202. { // NOLINT(runtime/references)
  203. using gen_t = absl::decay_t<URBG>;
  204. using distribution_t = random_internal::UniformDistributionWrapper<R>;
  205. return random_internal::DistributionCaller<gen_t>::template Call<
  206. distribution_t>(&urbg);
  207. }
  208. // -----------------------------------------------------------------------------
  209. // absl::Bernoulli(bitgen, p)
  210. // -----------------------------------------------------------------------------
  211. //
  212. // `absl::Bernoulli` produces a random boolean value, with probability `p`
  213. // (where 0.0 <= p <= 1.0) equaling `true`.
  214. //
  215. // Prefer `absl::Bernoulli` to produce boolean values over other alternatives
  216. // such as comparing an `absl::Uniform()` value to a specific output.
  217. //
  218. // See https://en.wikipedia.org/wiki/Bernoulli_distribution
  219. //
  220. // Example:
  221. //
  222. // absl::BitGen bitgen;
  223. // ...
  224. // if (absl::Bernoulli(bitgen, 1.0/3721.0)) {
  225. // std::cout << "Asteroid field navigation successful.";
  226. // }
  227. //
  228. template<typename URBG>
  229. bool Bernoulli(URBG&& urbg, // NOLINT(runtime/references)
  230. double p)
  231. {
  232. using gen_t = absl::decay_t<URBG>;
  233. using distribution_t = absl::bernoulli_distribution;
  234. return random_internal::DistributionCaller<gen_t>::template Call<
  235. distribution_t>(&urbg, p);
  236. }
  237. // -----------------------------------------------------------------------------
  238. // absl::Beta<T>(bitgen, alpha, beta)
  239. // -----------------------------------------------------------------------------
  240. //
  241. // `absl::Beta` produces a floating point number distributed in the closed
  242. // interval [0,1] and parameterized by two values `alpha` and `beta` as per a
  243. // Beta distribution. `T` must be a floating point type, but may be inferred
  244. // from the types of `alpha` and `beta`.
  245. //
  246. // See https://en.wikipedia.org/wiki/Beta_distribution.
  247. //
  248. // Example:
  249. //
  250. // absl::BitGen bitgen;
  251. // ...
  252. // double sample = absl::Beta(bitgen, 3.0, 2.0);
  253. //
  254. template<typename RealType, typename URBG>
  255. RealType Beta(URBG&& urbg, // NOLINT(runtime/references)
  256. RealType alpha,
  257. RealType beta)
  258. {
  259. static_assert(
  260. std::is_floating_point<RealType>::value,
  261. "Template-argument 'RealType' must be a floating-point type, in "
  262. "absl::Beta<RealType, URBG>(...)"
  263. );
  264. using gen_t = absl::decay_t<URBG>;
  265. using distribution_t = typename absl::beta_distribution<RealType>;
  266. return random_internal::DistributionCaller<gen_t>::template Call<
  267. distribution_t>(&urbg, alpha, beta);
  268. }
  269. // -----------------------------------------------------------------------------
  270. // absl::Exponential<T>(bitgen, lambda = 1)
  271. // -----------------------------------------------------------------------------
  272. //
  273. // `absl::Exponential` produces a floating point number representing the
  274. // distance (time) between two consecutive events in a point process of events
  275. // occurring continuously and independently at a constant average rate. `T` must
  276. // be a floating point type, but may be inferred from the type of `lambda`.
  277. //
  278. // See https://en.wikipedia.org/wiki/Exponential_distribution.
  279. //
  280. // Example:
  281. //
  282. // absl::BitGen bitgen;
  283. // ...
  284. // double call_length = absl::Exponential(bitgen, 7.0);
  285. //
  286. template<typename RealType, typename URBG>
  287. RealType Exponential(URBG&& urbg, // NOLINT(runtime/references)
  288. RealType lambda = 1)
  289. {
  290. static_assert(
  291. std::is_floating_point<RealType>::value,
  292. "Template-argument 'RealType' must be a floating-point type, in "
  293. "absl::Exponential<RealType, URBG>(...)"
  294. );
  295. using gen_t = absl::decay_t<URBG>;
  296. using distribution_t = typename absl::exponential_distribution<RealType>;
  297. return random_internal::DistributionCaller<gen_t>::template Call<
  298. distribution_t>(&urbg, lambda);
  299. }
  300. // -----------------------------------------------------------------------------
  301. // absl::Gaussian<T>(bitgen, mean = 0, stddev = 1)
  302. // -----------------------------------------------------------------------------
  303. //
  304. // `absl::Gaussian` produces a floating point number selected from the Gaussian
  305. // (ie. "Normal") distribution. `T` must be a floating point type, but may be
  306. // inferred from the types of `mean` and `stddev`.
  307. //
  308. // See https://en.wikipedia.org/wiki/Normal_distribution
  309. //
  310. // Example:
  311. //
  312. // absl::BitGen bitgen;
  313. // ...
  314. // double giraffe_height = absl::Gaussian(bitgen, 16.3, 3.3);
  315. //
  316. template<typename RealType, typename URBG>
  317. RealType Gaussian(URBG&& urbg, // NOLINT(runtime/references)
  318. RealType mean = 0,
  319. RealType stddev = 1)
  320. {
  321. static_assert(
  322. std::is_floating_point<RealType>::value,
  323. "Template-argument 'RealType' must be a floating-point type, in "
  324. "absl::Gaussian<RealType, URBG>(...)"
  325. );
  326. using gen_t = absl::decay_t<URBG>;
  327. using distribution_t = typename absl::gaussian_distribution<RealType>;
  328. return random_internal::DistributionCaller<gen_t>::template Call<
  329. distribution_t>(&urbg, mean, stddev);
  330. }
  331. // -----------------------------------------------------------------------------
  332. // absl::LogUniform<T>(bitgen, lo, hi, base = 2)
  333. // -----------------------------------------------------------------------------
  334. //
  335. // `absl::LogUniform` produces random values distributed where the log to a
  336. // given base of all values is uniform in a closed interval [lo, hi]. `T` must
  337. // be an integral type, but may be inferred from the types of `lo` and `hi`.
  338. //
  339. // I.e., `LogUniform(0, n, b)` is uniformly distributed across buckets
  340. // [0], [1, b-1], [b, b^2-1] .. [b^(k-1), (b^k)-1] .. [b^floor(log(n, b)), n]
  341. // and is uniformly distributed within each bucket.
  342. //
  343. // The resulting probability density is inversely related to bucket size, though
  344. // values in the final bucket may be more likely than previous values. (In the
  345. // extreme case where n = b^i the final value will be tied with zero as the most
  346. // probable result.
  347. //
  348. // If `lo` is nonzero then this distribution is shifted to the desired interval,
  349. // so LogUniform(lo, hi, b) is equivalent to LogUniform(0, hi-lo, b)+lo.
  350. //
  351. // See http://ecolego.facilia.se/ecolego/show/Log-Uniform%20Distribution
  352. //
  353. // Example:
  354. //
  355. // absl::BitGen bitgen;
  356. // ...
  357. // int v = absl::LogUniform(bitgen, 0, 1000);
  358. //
  359. template<typename IntType, typename URBG>
  360. IntType LogUniform(URBG&& urbg, // NOLINT(runtime/references)
  361. IntType lo,
  362. IntType hi,
  363. IntType base = 2)
  364. {
  365. static_assert(random_internal::IsIntegral<IntType>::value, "Template-argument 'IntType' must be an integral type, in "
  366. "absl::LogUniform<IntType, URBG>(...)");
  367. using gen_t = absl::decay_t<URBG>;
  368. using distribution_t = typename absl::log_uniform_int_distribution<IntType>;
  369. return random_internal::DistributionCaller<gen_t>::template Call<
  370. distribution_t>(&urbg, lo, hi, base);
  371. }
  372. // -----------------------------------------------------------------------------
  373. // absl::Poisson<T>(bitgen, mean = 1)
  374. // -----------------------------------------------------------------------------
  375. //
  376. // `absl::Poisson` produces discrete probabilities for a given number of events
  377. // occurring within a fixed interval within the closed interval [0, max]. `T`
  378. // must be an integral type.
  379. //
  380. // See https://en.wikipedia.org/wiki/Poisson_distribution
  381. //
  382. // Example:
  383. //
  384. // absl::BitGen bitgen;
  385. // ...
  386. // int requests_per_minute = absl::Poisson<int>(bitgen, 3.2);
  387. //
  388. template<typename IntType, typename URBG>
  389. IntType Poisson(URBG&& urbg, // NOLINT(runtime/references)
  390. double mean = 1.0)
  391. {
  392. static_assert(random_internal::IsIntegral<IntType>::value, "Template-argument 'IntType' must be an integral type, in "
  393. "absl::Poisson<IntType, URBG>(...)");
  394. using gen_t = absl::decay_t<URBG>;
  395. using distribution_t = typename absl::poisson_distribution<IntType>;
  396. return random_internal::DistributionCaller<gen_t>::template Call<
  397. distribution_t>(&urbg, mean);
  398. }
  399. // -----------------------------------------------------------------------------
  400. // absl::Zipf<T>(bitgen, hi = max, q = 2, v = 1)
  401. // -----------------------------------------------------------------------------
  402. //
  403. // `absl::Zipf` produces discrete probabilities commonly used for modelling of
  404. // rare events over the closed interval [0, hi]. The parameters `v` and `q`
  405. // determine the skew of the distribution. `T` must be an integral type, but
  406. // may be inferred from the type of `hi`.
  407. //
  408. // See http://mathworld.wolfram.com/ZipfDistribution.html
  409. //
  410. // Example:
  411. //
  412. // absl::BitGen bitgen;
  413. // ...
  414. // int term_rank = absl::Zipf<int>(bitgen);
  415. //
  416. template<typename IntType, typename URBG>
  417. IntType Zipf(URBG&& urbg, // NOLINT(runtime/references)
  418. IntType hi = (std::numeric_limits<IntType>::max)(),
  419. double q = 2.0,
  420. double v = 1.0)
  421. {
  422. static_assert(random_internal::IsIntegral<IntType>::value, "Template-argument 'IntType' must be an integral type, in "
  423. "absl::Zipf<IntType, URBG>(...)");
  424. using gen_t = absl::decay_t<URBG>;
  425. using distribution_t = typename absl::zipf_distribution<IntType>;
  426. return random_internal::DistributionCaller<gen_t>::template Call<
  427. distribution_t>(&urbg, hi, q, v);
  428. }
  429. ABSL_NAMESPACE_END
  430. } // namespace absl
  431. #endif // ABSL_RANDOM_DISTRIBUTIONS_H_