1use crate::HistogramVal;
2use paste::paste;
3use std::{
4 borrow::Borrow,
5 num::NonZeroUsize,
6 ops::{
7 RangeInclusive,
8 },
10};
11
12use super::{
13 from_u, to_u, Bin, Binning, GenericHist, HasUnsignedVersion, HistErrors, Histogram,
14 HistogramCombine, HistogramPartition,
15};
16use num_bigint::BigUint;
17
18#[cfg(feature = "serde_support")]
19use serde::{Deserialize, Serialize};
20
/// Binning for integer types in which every bin has a width of exactly 1.
///
/// Every value in `start..=end_inclusive` forms a bin of its own, so the bin
/// index of a value is its offset from `start`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct FastSingleIntBinning<T> {
    /// Left border of the binning; the smallest value that is inside.
    start: T,
    /// Right border of the binning (inclusive); the largest value that is inside.
    end_inclusive: T,
}
32
macro_rules! impl_binning {
    (
        $t:ty
    ) => {
        paste::item! {
            #[doc = "# Checked multiply divide\n\n\
            Computes `a * b / denominator` (rounded down) while guarding against an overflow \
            of the intermediate product `a * b`.\n\n\
            As long as the mathematical result of `a * b / denominator` is representable as \
            `<" $t " as HasUnsignedVersion>::Unsigned`, that result is returned. \
            Otherwise, `None` is returned.\n\n\
            # Panics\n\n\
            Panics if `denominator` is 0."]
            pub fn [< checked_mul_div_ $t >](
                a: <$t as HasUnsignedVersion>::Unsigned,
                b: <$t as HasUnsignedVersion>::Unsigned,
                denominator: <$t as HasUnsignedVersion>::Unsigned,
            ) -> Option<<$t as HasUnsignedVersion>::Unsigned> {
                // Fast path: the product fits, so plain integer arithmetic is exact.
                if let Some(val) = a.checked_mul(b) {
                    return Some(val / denominator);
                }

                /// Outcome of the split computation in `mul_div`.
                enum Answer {
                    /// The exact result, or the knowledge that it overflows, is known.
                    Known(Option<<$t as HasUnsignedVersion>::Unsigned>),
                    /// An intermediate product overflowed; arbitrary precision is needed.
                    Unknown,
                }

                // With a = q * denominator + r:
                // floor(a * b / denominator) = q * b + floor(r * b / denominator),
                // which avoids forming a * b directly.
                #[inline(always)]
                fn mul_div(
                    mut a: <$t as HasUnsignedVersion>::Unsigned,
                    mut b: <$t as HasUnsignedVersion>::Unsigned,
                    denominator: <$t as HasUnsignedVersion>::Unsigned,
                ) -> Answer {
                    // Divide the larger factor so the remainder term stays small.
                    if a < b {
                        std::mem::swap(&mut a, &mut b);
                    }
                    // q * b <= a * b / denominator, so if q * b overflows, so does the result.
                    let left = match (a / denominator).checked_mul(b) {
                        None => return Answer::Known(None),
                        Some(val) => val,
                    };
                    let right_mul = match (a % denominator).checked_mul(b) {
                        None => return Answer::Unknown,
                        Some(val) => val,
                    };
                    Answer::Known(left.checked_add(right_mul / denominator))
                }

                match mul_div(a, b, denominator) {
                    Answer::Known(res) => res,
                    Answer::Unknown => {
                        let a: BigUint = a.into();
                        let b: BigUint = b.into();
                        let denominator: BigUint = denominator.into();
                        (a * b / denominator).try_into().ok()
                    }
                }
            }
        }

        paste! {
            #[doc = "Efficient binning for `" $t "` with bins of width 1"]
            pub type [<FastBinning $t:upper>] = FastSingleIntBinning<$t>;
        }

        impl paste!{[<FastBinning $t:upper>]} {
            /// Create a binning covering every value in `start..=end_inclusive`,
            /// one bin per value.
            ///
            /// # Panics
            /// Panics if `start > end_inclusive`.
            #[inline(always)]
            pub const fn new_inclusive(start: $t, end_inclusive: $t) -> Self {
                assert!(start <= end_inclusive, "Start needs to be <= end_inclusive!");
                Self { start, end_inclusive }
            }

            /// Left border of the binning; the smallest value that is inside.
            #[inline(always)]
            pub const fn left(&self) -> $t {
                self.start
            }

            /// Right border of the binning (inclusive); the largest value that is inside.
            #[inline(always)]
            pub const fn right(&self) -> $t {
                self.end_inclusive
            }

            /// The inclusive range of all values covered by this binning.
            #[inline(always)]
            pub const fn range_inclusive(&self) -> RangeInclusive<$t> {
                self.start..=self.end_inclusive
            }

            paste! {
                #[doc = "# Iterator over all the bins\
                \nSince the bins have width 1, a bin can be defined by its corresponding value \
                which we can iterate over.\n\
                # Example\n\
                ```\n\
                use sampling::histogram::" [<FastBinning $t:upper>] ";\n\
                let binning = " [<FastBinning $t:upper>] "::new_inclusive(2,5);\n\
                let vec: Vec<_> = binning.native_bin_iter().collect();\n\
                assert_eq!(&vec, &[2, 3, 4, 5]);\n\
                ```"]
                pub fn native_bin_iter(&self) -> impl Iterator<Item = $t> {
                    self.range_inclusive()
                }
            }

            /// Number of bins minus one, computed as `to_u(right) - to_u(left)`.
            ///
            /// Unlike the number of bins itself, this value is always representable in
            /// the unsigned type, even when the binning spans the whole range of the type.
            pub fn bins_m1(&self) -> <$t as HasUnsignedVersion>::Unsigned {
                let left = to_u(self.start);
                let right = to_u(self.end_inclusive);
                right - left
            }
        }

        impl paste!{[<FastBinning $t:upper>]} {
            /// Index of the bin that contains `val`, expressed in the unsigned type.
            ///
            /// Returns `None` if `val` is outside of the binning.
            #[inline(always)]
            pub fn get_bin_index_native<V: Borrow<$t>>(
                &self,
                val: V,
            ) -> Option<<$t as HasUnsignedVersion>::Unsigned> {
                let val = *val.borrow();
                if self.is_inside(val) {
                    Some(to_u(val) - to_u(self.start))
                } else {
                    None
                }
            }
        }

        impl GenericHist<paste!{[<FastBinning $t:upper>]}, $t> {
            /// Iterate over `(value, hits)` pairs, one pair per bin, in ascending order.
            pub fn bin_hits_iter(&'_ self) -> impl Iterator<Item = ($t, usize)> + '_ {
                self.binning()
                    .native_bin_iter()
                    .zip(self.hist().iter().copied())
            }
        }

        impl Binning<$t> for paste!{[<FastBinning $t:upper>]} {
            /// Number of bins, saturating at `usize::MAX`.
            #[inline(always)]
            fn get_bin_len(&self) -> usize {
                // For wide types (e.g. 128 bit) `bins_m1` may not fit into `usize`;
                // saturate instead of silently truncating with `as`.
                usize::try_from(self.bins_m1())
                    .unwrap_or(usize::MAX)
                    .saturating_add(1)
            }

            /// Index of the bin that contains `val`.
            ///
            /// Returns `None` if `val` is outside, or if its index is not
            /// representable as `usize`. Such an index cannot address a real bin,
            /// and truncating it would alias a different bin.
            #[inline(always)]
            fn get_bin_index<V: Borrow<$t>>(&self, val: V) -> Option<usize> {
                self.get_bin_index_native(val)
                    .and_then(|index| usize::try_from(index).ok())
            }

            /// Whether `val` lies within `left..=right`.
            #[inline(always)]
            fn is_inside<V: Borrow<$t>>(&self, val: V) -> bool {
                self.range_inclusive().contains(val.borrow())
            }

            /// Negation of `is_inside`.
            #[inline(always)]
            fn not_inside<V: Borrow<$t>>(&self, val: V) -> bool {
                !self.is_inside(val)
            }

            /// Left border (inclusive).
            #[inline(always)]
            fn first_border(&self) -> $t {
                self.start
            }

            /// Right border (inclusive, see `last_border_is_inclusive`).
            #[inline(always)]
            fn last_border(&self) -> $t {
                self.end_inclusive
            }

            /// Always `true`: the right border belongs to the last bin.
            #[inline(always)]
            fn last_border_is_inclusive(&self) -> bool {
                true
            }

            /// Distance of `v` to the closest border; `0.0` if `v` is inside.
            #[inline(always)]
            fn distance<V: Borrow<$t>>(&self, v: V) -> f64 {
                let val = v.borrow();
                if self.is_inside(val) {
                    0.0
                } else {
                    let dist = if *val < self.start {
                        to_u(self.start) - to_u(*val)
                    } else {
                        to_u(*val) - to_u(self.end_inclusive)
                    };
                    dist as f64
                }
            }

            /// Iterate over all bins; each bin holds exactly one value.
            fn bin_iter(&self) -> Box<dyn Iterator<Item = Bin<$t>>> {
                Box::new(self.range_inclusive().map(Bin::SingleValued))
            }
        }

        impl HistogramPartition for paste!{[<FastBinning $t:upper>]} {
            paste! {
                #[doc = "# partition the interval\
                \n* returns Vector of `n` Binnings.\
                \n* returns `Err(HistErrors::Overflow)` if `n + overlap` is not representable as \
                `<" $t " as HasUnsignedVersion>::Unsigned`.\n\
                ## parameter \n\
                * `n` number of resulting intervals. \n\
                * `overlap` How much overlap should there be? \n\
                ## To understand overlap, we have to look at the formula for the i_th interval in the result vector: \n\
                let ``left`` be the left border of ``self`` and ``right`` be the right border of self \n\
                * left border of interval i = left + i * (right - left) / (n + overlap) \n\
                * right border of interval i = left + (i + overlap + 1) * (right - left) / (n + overlap) \n\
                \nBoth divisions round down, and every resulting interval is inclusive on both ends. \n\
                ## What is it for? \
                \n * This is intended to create multiple overlapping intervals, e.g., for a Wang-Landau simulation"]
                fn overlapping_partition(
                    &self,
                    n: NonZeroUsize,
                    overlap: usize,
                ) -> Result<Vec<Self>, HistErrors> {
                    let right_minus_left = self.bins_m1();
                    // Convert with `try_into` instead of `as`: a plain cast would wrap
                    // `n`/`overlap` modulo the range of the type and silently yield the
                    // wrong number of intervals (or none at all).
                    let n_native: <$t as HasUnsignedVersion>::Unsigned = n
                        .get()
                        .try_into()
                        .map_err(|_| HistErrors::Overflow)?;
                    let overlap_native: <$t as HasUnsignedVersion>::Unsigned = overlap
                        .try_into()
                        .map_err(|_| HistErrors::Overflow)?;
                    let denominator = n_native
                        .checked_add(overlap_native)
                        .ok_or(HistErrors::Overflow)?;

                    let mul_div = paste::item! { [< checked_mul_div_ $t >] };
                    let start_native: <$t as HasUnsignedVersion>::Unsigned = to_u(self.start);
                    let mut result = Vec::with_capacity(n.get());

                    for c in 0..n_native {
                        let left_distance = mul_div(c, right_minus_left, denominator)
                            .ok_or(HistErrors::Overflow)?;

                        let right_factor = c
                            .saturating_add(overlap_native)
                            .checked_add(1)
                            .ok_or(HistErrors::Overflow)?;
                        let right_distance = mul_div(right_factor, right_minus_left, denominator)
                            .ok_or(HistErrors::Overflow)?;

                        // Both distances are <= right_minus_left, so these sums cannot overflow.
                        let left = from_u(start_native + left_distance);
                        let right = from_u(start_native + right_distance);

                        result.push(Self::new_inclusive(left, right));
                    }
                    debug_assert_eq!(self.start, result[0].start, "eq1");
                    debug_assert_eq!(
                        self.end_inclusive,
                        result.last().unwrap().end_inclusive,
                        "eq2"
                    );
                    Ok(result)
                }
            }
        }

        impl HistogramCombine for GenericHist<paste!{[<FastBinning $t:upper>]}, $t> {
            /// Bin index of `self` that contains the left border of `right`.
            fn align<S>(&self, right: S) -> Result<usize, HistErrors>
            where
                S: Borrow<Self>,
            {
                self.get_bin_index(right.borrow().first_border())
            }

            /// Create an empty histogram whose binning spans all binnings in `hists`.
            ///
            /// Returns `Err(HistErrors::EmptySlice)` if `hists` is empty.
            fn encapsulating_hist<S>(hists: &[S]) -> Result<Self, HistErrors>
            where
                S: Borrow<Self>,
            {
                let (first, rest) = hists.split_first().ok_or(HistErrors::EmptySlice)?;
                let first_binning = first.borrow().binning();
                let (left, right) = rest.iter().fold(
                    (first_binning.first_border(), first_binning.last_border()),
                    |(left, right), other| {
                        let binning = other.borrow().binning();
                        (
                            left.min(binning.first_border()),
                            right.max(binning.last_border()),
                        )
                    },
                );
                let outer_binning =
                    <paste!{[<FastBinning $t:upper>]}>::new_inclusive(left, right);
                Ok(GenericHist::new(outer_binning))
            }
        }
    };
    (
        $($t:ty),* $(,)?
    ) => {
        $(
            impl_binning!($t);
        )*
    }
}
376
377impl_binning!(u8, i8, u16, i16, u32, i32, u64, i64, u128, i128, usize, isize);
378
#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use super::*;
    use crate::histogram::*;
    use crate::GenericHist;
    use num_traits::{AsPrimitive, PrimInt};

    /// Every value in `left..=right` must be inside, map to consecutive bin
    /// indices and have distance 0.
    fn hist_test_generic_all_inside<T>(left: T, right: T)
    where
        FastSingleIntBinning<T>: Binning<T>,
        GenericHist<FastSingleIntBinning<T>, T>: Histogram,
        T: PrimInt,
        std::ops::RangeInclusive<T>: Iterator<Item = T>,
    {
        let binning = FastSingleIntBinning::<T> {
            start: left,
            end_inclusive: right,
        };
        let mut hist = GenericHist::<FastSingleIntBinning<T>, T>::new(binning);

        for (id, i) in (left..=right).enumerate() {
            assert!(hist.is_inside(i));
            assert_eq!(hist.is_inside(i), !hist.not_inside(i));
            assert!(hist.get_bin_index(i).unwrap() == id);
            assert_eq!(hist.distance(i), 0.0);
            hist.count_val(i).unwrap();
        }
        assert_eq!(hist.bin_enum_iter().count(), hist.bin_count());
    }

    #[test]
    fn hist_inside() {
        hist_test_generic_all_inside(-23i16, 31);
        hist_test_generic_all_inside(1u8, 3u8);
        hist_test_generic_all_inside(u8::MIN, u8::MAX);
        hist_test_generic_all_inside(i8::MIN, i8::MAX);
        hist_test_generic_all_inside(-100i8, 100i8);
    }

    /// Every value of `T` outside `left..=right` must be reported as outside
    /// by the histogram.
    fn hist_test_generic_all_outside_extensive<T>(left: T, right: T)
    where
        FastSingleIntBinning<T>: Binning<T>,
        GenericHist<FastSingleIntBinning<T>, T>: Histogram,
        T: PrimInt,
        std::ops::Range<T>: Iterator<Item = T>,
        std::ops::RangeInclusive<T>: Iterator<Item = T>,
    {
        let binning = FastSingleIntBinning::<T> {
            start: left,
            end_inclusive: right,
        };
        let hist = GenericHist::<FastSingleIntBinning<T>, T>::new(binning);

        for i in T::min_value()..left {
            assert!(hist.not_inside(i));
            assert_eq!(hist.is_inside(i), !hist.not_inside(i));
            assert!(matches!(
                hist.get_bin_index(i),
                Err(HistErrors::OutsideHist)
            ));
            assert!(hist.distance(i) > 0.0);
        }
        for i in right + T::one()..=T::max_value() {
            assert!(hist.not_inside(i));
            assert_eq!(hist.is_inside(i), !hist.not_inside(i));
            assert!(matches!(
                hist.get_bin_index(i),
                Err(HistErrors::OutsideHist)
            ));
            assert!(hist.distance(i) > 0.0);
        }
        assert_eq!(hist.bin_enum_iter().count(), hist.bin_count());
    }

    /// Every value of `T` outside `left..=right` must be reported as outside
    /// by the binning, and its distance must grow by exactly 1 per step away
    /// from the border.
    fn binning_all_outside_extensive<T>(left: T, right: T)
    where
        FastSingleIntBinning<T>: Binning<T>,
        T: PrimInt + Display,
        std::ops::Range<T>: Iterator<Item = T>,
        std::ops::RangeInclusive<T>: Iterator<Item = T>,
        std::ops::RangeFrom<T>: Iterator<Item = T>,
    {
        let binning = FastSingleIntBinning::<T> {
            start: left,
            end_inclusive: right,
        };

        let mut last_dist = None;
        for i in T::min_value()..left {
            assert!(binning.not_inside(i), "{i} should be outside");
            assert_eq!(binning.is_inside(i), !binning.not_inside(i));
            assert!(binning.get_bin_index(i).is_none(), "{i} should have no bin");
            let dist = binning.distance(i);
            assert!(dist > 0.0, "distance of {i} should be positive");
            match last_dist {
                None => last_dist = Some(dist),
                Some(d) => {
                    assert!(d > dist, "distance should shrink towards the border at {i}");
                    assert_eq!(d - 1.0, dist, "distance should shrink by 1 at {i}");
                    last_dist = Some(dist);
                }
            }
        }
        if let Some(d) = last_dist {
            assert_eq!(d, 1.0);
        }

        last_dist = None;
        for (i, dist_counter) in (right + T::one()..=T::max_value()).zip(1_u64..) {
            assert!(binning.not_inside(i), "{i} should be outside");
            assert_eq!(binning.is_inside(i), !binning.not_inside(i));
            assert!(binning.get_bin_index(i).is_none(), "{i} should have no bin");
            let dist = binning.distance(i);
            assert!(dist > 0.0, "distance of {i} should be positive");
            let dist_counter_float: f64 = dist_counter.as_();
            assert_eq!(dist, dist_counter_float, "wrong distance for {i}");
            match last_dist {
                None => last_dist = Some(dist),
                Some(d) => {
                    assert!(d < dist, "distance should grow away from the border at {i}");
                    last_dist = Some(dist);
                }
            }
        }

        let binning = FastSingleIntBinning::<T> {
            start: left,
            end_inclusive: left,
        };
        assert_eq!(binning.get_bin_len(), 1);
        assert_eq!(binning.get_bin_index(left), Some(0));
    }

    #[test]
    fn hist_outside() {
        hist_test_generic_all_outside_extensive(10u8, 20_u8);
        hist_test_generic_all_outside_extensive(-100, 100_i8);
        hist_test_generic_all_outside_extensive(-100, 100_i16);
        hist_test_generic_all_outside_extensive(123, 299u16);
    }

    #[test]
    fn binning_outside() {
        binning_all_outside_extensive(0u8, 0_u8);
        binning_all_outside_extensive(10u8, 20_u8);
        binning_all_outside_extensive(-100, 100_i8);
        binning_all_outside_extensive(-100, 100_i16);
        binning_all_outside_extensive(123, 299u16);
    }

    /// Exhaustive comparison against a wide reference computation for `u8`,
    /// including the boundary value `u8::MAX`.
    #[test]
    fn check_mul_div() {
        fn check(a: u8, b: u8, denominator: u8) -> Option<u8> {
            (a as u128 * b as u128 / denominator as u128)
                .try_into()
                .ok()
        }

        for i in 0..=u8::MAX {
            for j in 0..=u8::MAX {
                for k in 1..=u8::MAX {
                    assert_eq!(
                        check(i, j, k),
                        checked_mul_div_u8(i, j, k),
                        "Error in {i} {j} {k}"
                    );
                }
            }
        }
    }

    /// Randomized comparison of `checked_mul_div_*` against a computation in a
    /// wider type.
    #[test]
    fn mul_testing() {
        use rand::distr::Uniform;
        use rand::prelude::*;
        use rand::SeedableRng;
        use rand_pcg::Pcg64Mcg;
        macro_rules! mul_t {
            (
                $t:ty, $o:ty
            ) => {
                paste::item! { fn [< mul_tests_ $t >]()
                {
                    let mut rng = Pcg64Mcg::seed_from_u64(314668);
                    let uni_one = Uniform::new_inclusive(1, <$t>::MAX).unwrap();
                    let uni_all = Uniform::new_inclusive(0, <$t>::MAX).unwrap();
                    let max = <$t as HasUnsignedVersion>::Unsigned::MAX.into();
                    for _ in 0..100 {
                        let a = uni_all.sample(&mut rng);
                        let b = uni_all.sample(&mut rng);
                        let c = uni_one.sample(&mut rng);
                        let result: $o = a as $o * b as $o / c as $o;
                        let mul = paste::item! { [< checked_mul_div_ $t >]}(
                            a as <$t as HasUnsignedVersion>::Unsigned,
                            b as <$t as HasUnsignedVersion>::Unsigned,
                            c as <$t as HasUnsignedVersion>::Unsigned
                        );
                        if result <= max {
                            assert_eq!(
                                mul,
                                Some(result as <$t as HasUnsignedVersion>::Unsigned)
                            )
                        } else {
                            assert!(mul.is_none());
                        }
                    }
                }
                }
            };
        }
        mul_t!(u8, u16);
        mul_tests_u8();
        mul_t!(u16, u64);
        mul_tests_u16();
        mul_t!(u32, u128);
        mul_tests_u32();
        mul_t!(i8, i16);
        mul_tests_i8();
        mul_t!(i32, i128);
        mul_tests_i32();
    }

    /// The partition must start at the left border and end at the right border.
    #[test]
    fn partition_test() {
        let n = NonZeroUsize::new(2).unwrap();
        let h = FastBinningU8::new_inclusive(0, u8::MAX);
        for overlap in 0..10 {
            let h_part = h.overlapping_partition(n, overlap).unwrap();
            assert_eq!(h.first_border(), h_part[0].first_border());
            assert_eq!(h.last_border(), h_part.last().unwrap().last_border());
        }

        let h = FastBinningI8::new_inclusive(i8::MIN, i8::MAX);
        let h_part = h.overlapping_partition(n, 0).unwrap();
        assert_eq!(h.first_border(), h_part[0].first_border());
        assert_eq!(h.last_border(), h_part.last().unwrap().last_border());

        let h = FastBinningI16::new_inclusive(i16::MIN, i16::MAX);
        let h_part = h.overlapping_partition(n, 2).unwrap();
        assert_eq!(h.first_border(), h_part[0].first_border());
        assert_eq!(h.last_border(), h_part.last().unwrap().last_border());

        let _ = h
            .overlapping_partition(NonZeroUsize::new(2000).unwrap(), 0)
            .unwrap();
    }

    /// Randomized partition checks: borders are preserved and the intervals
    /// are monotonically ordered.
    #[test]
    fn overlapping_partition_test2() {
        use rand::distr::Uniform;
        use rand::prelude::*;
        use rand_pcg::Pcg64Mcg;
        let mut rng = Pcg64Mcg::seed_from_u64(2314668);
        let uni = Uniform::new_inclusive(-100, 100).unwrap();
        for overlap in 0..=3 {
            for i in 0..100 {
                let (left, right) = loop {
                    let mut num_1 = uni.sample(&mut rng);
                    let mut num_2 = uni.sample(&mut rng);

                    if num_1 != num_2 {
                        if num_2 < num_1 {
                            std::mem::swap(&mut num_1, &mut num_2);
                        }
                        if (num_2 as isize - num_1 as isize) < (overlap as isize + 1) {
                            continue;
                        }
                        break (num_1, num_2);
                    }
                };
                let hist_fast = FastBinningI8::new_inclusive(left, right);
                let overlapping = hist_fast
                    .overlapping_partition(NonZeroUsize::new(3).unwrap(), overlap)
                    .unwrap();

                assert_eq!(
                    overlapping.last().unwrap().last_border(),
                    hist_fast.last_border(),
                    "overlapping_partition_test2 - last border check (iteration {i})"
                );

                assert_eq!(
                    overlapping.first().unwrap().first_border(),
                    hist_fast.first_border(),
                    "overlapping_partition_test2 - first border check (iteration {i})"
                );

                for slice in overlapping.windows(2) {
                    assert!(slice[0].first_border() <= slice[1].first_border());
                    assert!(slice[0].last_border() <= slice[1].last_border());
                }
            }
        }
    }

    #[test]
    fn hist_combine() {
        let binning_left = FastBinningI8::new_inclusive(-5, 0);
        let binning_right = FastBinningI8::new_inclusive(-1, 2);
        let left = GenericHist::new(binning_left);
        let right = GenericHist::new(binning_right);

        let encapsulating = GenericHist::encapsulating_hist(&[&left, &right]).unwrap();
        let enc_binning = encapsulating.binning();
        assert_eq!(enc_binning.first_border(), binning_left.first_border());
        assert_eq!(enc_binning.last_border(), binning_right.last_border());
        assert_eq!(encapsulating.bin_count(), 8);

        let align = left.align(right).unwrap();

        assert_eq!(align, 4);

        let left = FastBinningI8::new_inclusive(i8::MIN, 0).to_generic_hist();
        let right = FastBinningI8::new_inclusive(0, i8::MAX).to_generic_hist();

        let en = GenericHist::encapsulating_hist(&[&left, &right]).unwrap();

        assert_eq!(en.bin_count(), 256);

        let align = left.align(right).unwrap();

        assert_eq!(128, align);

        let left = FastBinningI8::new_inclusive(i8::MIN, i8::MAX).to_generic_hist();
        let small = FastBinningI8::new_inclusive(127, 127).to_generic_hist();

        let align = left.align(&small).unwrap();

        assert_eq!(255, align);

        let en = GenericHist::encapsulating_hist(&[&left]).unwrap();
        assert_eq!(en.bin_count(), 256);
        let slice = [&left];
        let en = GenericHist::encapsulating_hist(&slice[1..]);
        assert_eq!(en.err(), Some(HistErrors::EmptySlice));
        let en = GenericHist::encapsulating_hist(&[small, left]).unwrap();

        assert_eq!(en.bin_count(), 256);
    }

    #[test]
    fn unit_test_distance() {
        let binning = FastBinningI8::new_inclusive(-50, 50);

        let mut dist = binning.distance(i8::MIN);
        for i in i8::MIN + 1..-50 {
            let new_dist = binning.distance(i);
            assert!(dist > new_dist);
            dist = new_dist;
        }
        for i in -50..=50 {
            assert_eq!(binning.distance(i), 0.0);
        }
        dist = 0.0;
        for i in 51..=i8::MAX {
            let new_dist = binning.distance(i);
            assert!(dist < new_dist);
            dist = new_dist;
        }
    }
}