1use std::{
2 ops::{
3 RangeInclusive,
4 },
6 borrow::Borrow,
7 num::NonZeroUsize
8};
9use paste::paste;
10use crate::HistogramVal;
11
12use super::{
13 Binning,
14 HasUnsignedVersion,
15 to_u,
16 from_u,
17 Bin,
18 HistogramPartition,
19 HistErrors,
20 HistogramCombine,
21 GenericHist,
22 Histogram
23};
24use num_bigint::BigUint;
25
26#[cfg(feature = "serde_support")]
27use serde::{Serialize, Deserialize};
28
/// Binning over a closed integer interval in which every bin has width 1.
///
/// Only the two inclusive borders are stored. The per-type functionality is
/// implemented on the `FastBinning*` type aliases of this struct.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct FastSingleIntBinning<T> {
    /// Left border of the interval (inclusive).
    start: T,
    /// Right border of the interval (inclusive).
    end_inclusive: T,
}
40
41
42macro_rules! impl_binning {
43 (
44 $t:ty
45 ) => {
46
47
48 paste::item! {
49 #[doc = "# Checked multiply divide\n\
50 The operation is: a * b / denominator.\n\n \
51 However this function guards against an overflow of a * b. \n\n As long as the mathematical result of a * b / denominator \
52 is representable as unsigned version of `<" $t " as HasUnsignedVersion>::Unsigned` then the mathematical answer is returned. Otherwise, None is returned\n\n ## Note: \n\n `denominator` is not allowed to be 0"]
53 pub fn [< checked_mul_div_ $t >] (
54 a: <$t as HasUnsignedVersion>::Unsigned,
55 b: <$t as HasUnsignedVersion>::Unsigned,
56 denominator: <$t as HasUnsignedVersion>::Unsigned
57 ) -> Option<<$t as HasUnsignedVersion>::Unsigned>
58 {
59
60 if let Some(val) = a.checked_mul(b){
61 return Some(val / denominator);
62 }
63
64 enum Answer{
65 Known(Option<<$t as HasUnsignedVersion>::Unsigned>),
66 Unknown
67 }
68
69 #[inline(always)]
70 fn mul_div(
71 mut a: <$t as HasUnsignedVersion>::Unsigned,
72 mut b: <$t as HasUnsignedVersion>::Unsigned,
73 denominator: <$t as HasUnsignedVersion>::Unsigned
74 ) -> Answer
75 {
76 if a < b {
77 std::mem::swap(&mut a, &mut b);
78 }
79 let left = match (a / denominator).checked_mul(b){
86 None => return Answer::Known(None),
87 Some(val) => val
88 };
89 let right_mul = match (a % denominator)
90 .checked_mul(b){
91 None => return Answer::Unknown,
92 Some(v) => v
93 };
94
95
96 let result = left.checked_add(right_mul / denominator);
97 Answer::Known(result)
98 }
99
100 match mul_div(a, b, denominator){
101 Answer::Known(res) => return res,
102 Answer::Unknown => {
103 let a: BigUint = a.into();
104 let b: BigUint = b.into();
105 let denominator: BigUint = denominator.into();
106 let res = a * b / denominator;
107 res.try_into().ok()
108 }
109 }
110
111 }
112 }
113
114 paste!{
115 #[doc = "Efficient binning for `" $t "` with bins of width 1"]
116 pub type [<FastBinning $t:upper>] = FastSingleIntBinning<$t>;
117 }
118
119 impl paste!{[<FastBinning $t:upper>]}{
120 #[inline(always)]
126 pub const fn new_inclusive(start: $t, end_inclusive: $t) -> Self{
127 assert!(start <= end_inclusive, "Start needs to be <= end_inclusive!");
128 Self {start, end_inclusive }
129 }
130
131 #[inline(always)]
133 pub const fn left(&self) -> $t {
134 self.start
135 }
136
137 #[inline(always)]
139 pub const fn right(&self) -> $t
140 {
141 self.end_inclusive
142 }
143
144 #[inline(always)]
146 pub const fn range_inclusive(&self) -> RangeInclusive<$t>
147 {
148 self.start..=self.end_inclusive
149 }
150
151 paste!{
152 #[doc = "# Iterator over all the bins\
153 \nSince the bins have width 1, a bin can be defined by its corresponding value \
154 which we can iterate over.\n\
155 # Example\n\
156 ```\n\
157 use sampling::histogram::" [<FastBinning $t:upper>] ";\n\
158 let binning = " [<FastBinning $t:upper>] "::new_inclusive(2,5);\n\
159 let vec: Vec<_> = binning.native_bin_iter().collect();\n\
160 assert_eq!(&vec, &[2, 3, 4, 5]);\n\
161 ```"]
162 pub fn native_bin_iter(&self) -> impl Iterator<Item=$t>
163 {
164 self.range_inclusive()
165 }
166 }
167
168 pub fn bins_m1(&self) -> <$t as HasUnsignedVersion>::Unsigned{
177 let left = to_u(self.start);
178 let right = to_u(self.end_inclusive);
179
180 right - left
181 }
182 }
183
184 impl paste!{[<FastBinning $t:upper>]}
185 {
186 #[inline(always)]
189 pub fn get_bin_index_native<V: Borrow<$t>>(&self, val: V) -> Option<<$t as HasUnsignedVersion>::Unsigned>{
190 let val = *val.borrow();
191 if self.is_inside(val)
192 {
193 Some(to_u(val) - to_u(self.start))
194 } else{
195 None
196 }
197 }
198 }
199
200 impl GenericHist<paste!{[<FastBinning $t:upper>]}, $t>{
201 pub fn bin_hits_iter(&'_ self) -> impl Iterator<Item=($t, usize)> + '_
206 {
207 self.binning()
208 .native_bin_iter()
209 .zip(self.hist().iter().copied())
210 }
211 }
212
213
214 impl Binning<$t> for paste!{[<FastBinning $t:upper>]} {
215 #[inline(always)]
216 fn get_bin_len(&self) -> usize
217 {
218 (self.bins_m1() as usize).saturating_add(1)
219 }
220
221 #[inline(always)]
224 fn get_bin_index<V: Borrow<$t>>(&self, val: V) -> Option<usize>{
225 self.get_bin_index_native(val)
226 .map(|v| v as usize)
227 }
228
229 #[inline(always)]
231 fn is_inside<V: Borrow<$t>>(&self, val: V) -> bool{
232 (self.start..=self.end_inclusive).contains(val.borrow())
233 }
234
235 #[inline(always)]
238 fn not_inside<V: Borrow<$t>>(&self, val: V) -> bool{
239 !self.is_inside(val)
240 }
241
242 #[inline(always)]
244 fn first_border(&self) -> $t{
245 self.start
246 }
247
248 #[inline(always)]
249 fn last_border(&self) -> $t{
250 self.end_inclusive
251 }
252
253 #[inline(always)]
254 fn last_border_is_inclusive(&self) -> bool
255 {
256 true
257 }
258
259 #[inline(always)]
262 fn distance<V: Borrow<$t>>(&self, v: V) -> f64{
263 let val = v.borrow();
264 if self.is_inside(val){
265 0.0
266 } else {
267 let dist = if *val < self.start {
268 to_u(self.start) - to_u(*val)
269 } else {
270 to_u(*val) - to_u(self.end_inclusive)
271 };
272 dist as f64
273 }
274 }
275
276 fn bin_iter(&self) -> Box<dyn Iterator<Item=Bin<$t>>>{
282 Box::new(
283 self.range_inclusive()
284 .map(|val| Bin::SingleValued(val))
285 )
286 }
287 }
288
289 impl HistogramPartition for paste!{[<FastBinning $t:upper>]}
290 {
291 paste!{
292 #[doc = "# partition the interval\
293 \n* returns Vector of `n` Binnings. Though `n` will be limited by the max value that `" $t "` can hold. \
294 ## parameter \n\
295 * `n` number of resulting intervals. \n\
296 * `overlap` How much overlap should there be? \n\
297 ## To understand overlap, we have to look at the formula for the i_th interval in the result vector: \n\
298 let ``left`` be the left border of ``self`` and ``right`` be the right border of self \n\
299 * left border of interval i = left + i * (right - left) / (n + overlap) \n\
300 * right border of interval i = left + (i + overlap) * (right - left) / (n + overlap) \n\
301 ## What is it for? \
302 \n * This is intended to create multiple overlapping intervals, e.g., for a Wang-Landau simulation\
303 \n # Note\
304 \n * Will fail if `overlap` + `n` are not representable as `" $t "`"]
305 fn overlapping_partition(&self, n: NonZeroUsize, overlap: usize) -> Result<Vec<Self>, HistErrors>
306 {
307 let mut result = Vec::with_capacity(n.get());
308 let right_minus_left = self.bins_m1();
309 let n_native = n.get() as <$t as HasUnsignedVersion>::Unsigned;
310 let overlap_native = overlap as <$t as HasUnsignedVersion>::Unsigned;
311 let denominator = n_native
312 .checked_add(overlap_native)
313 .ok_or(HistErrors::Overflow)?;
314 for c in 0..n_native {
315 let left_distance = paste::item! { [< checked_mul_div_ $t >] }(c, right_minus_left, denominator)
316 .ok_or(HistErrors::Overflow)?;
317 let left = to_u(self.start) + left_distance;
318
319 let right_sum = c.saturating_add(overlap_native)
320 .checked_add(1)
321 .ok_or(HistErrors::Overflow)?;
322
323 let right_distance = paste::item! { [< checked_mul_div_ $t >] }(right_sum, right_minus_left, denominator)
324 .ok_or(HistErrors::Overflow)?;
325 let right = to_u(self.start) + right_distance;
326
327 let left = from_u(left);
328 let right = from_u(right);
329
330 result.push(Self::new_inclusive(left, right));
331 }
332 debug_assert_eq!(
333 self.start,
334 result[0].start,
335 "eq1"
336 );
337 debug_assert_eq!(
338 self.end_inclusive,
339 result.last().unwrap().end_inclusive,
340 "eq2"
341 );
342 Ok(result)
343 }
344 }
345 }
346
347 impl HistogramCombine for GenericHist<paste!{[<FastBinning $t:upper>]}, $t>
348 {
349 fn align<S>(&self, right: S)-> Result<usize, HistErrors>
350 where S: Borrow<Self> {
351 let right = right.borrow();
352
353 self.get_bin_index(right.first_border())
354 }
355
356 fn encapsulating_hist<S>(hists: &[S]) -> Result<Self, HistErrors>
357 where S: Borrow<Self> {
358 if hists.is_empty(){
359 return Err(HistErrors::EmptySlice);
360 }
361 let first_binning = hists[0].borrow().binning();
362 let mut left = first_binning.first_border();
363 let mut right = first_binning.last_border();
364 for other in hists[1..].iter()
365 {
366 let binning = other.borrow().binning();
367 left = left.min(binning.first_border());
368 right = right.max(binning.last_border());
369
370 }
371 let outer_binning = <paste!{[<FastBinning $t:upper>]}>::new_inclusive(left, right);
372 let hist = GenericHist::new(outer_binning);
373 Ok(hist)
374 }
375 }
376 };
377 (
378 $($t:ty),* $(,)?
379 ) => {
380 $(
381 impl_binning!($t);
382 )*
383 }
384}
385
386impl_binning!(
387 u8,
388 i8,
389 u16,
390 i16,
391 u32,
392 i32,
393 u64,
394 i64,
395 u128,
396 i128,
397 usize,
398 isize
399);
400
401#[cfg(test)]
402mod tests{
403 use std::fmt::{Debug, Display};
404
405 use crate::GenericHist;
406 use super::*;
407 use crate::histogram::*;
408 use num_traits::{PrimInt, AsPrimitive};
409
410 fn hist_test_generic_all_inside<T>(left: T, right: T)
411 where FastSingleIntBinning::<T>: Binning::<T>,
412 GenericHist::<FastSingleIntBinning::<T>, T>: Histogram,
413 T: PrimInt,
414 std::ops::RangeInclusive<T>: Iterator<Item=T>,
415 {
416 let binning = FastSingleIntBinning::<T>{start: left, end_inclusive: right};
417 let mut hist =
418 GenericHist::<FastSingleIntBinning::<T>, T>::new(binning);
419
420 for (id, i) in (left..=right).enumerate() {
421 assert!(hist.is_inside(i));
422 assert_eq!(hist.is_inside(i), !hist.not_inside(i));
423 assert!(hist.get_bin_index(i).unwrap() == id);
424 assert_eq!(hist.distance(i), 0.0);
425 hist.count_val(i).unwrap();
426 }
427 assert_eq!(hist.bin_enum_iter().count(), hist.bin_count());
428 }
429
#[test]
fn hist_inside() {
    // Partial ranges.
    hist_test_generic_all_inside(-23_i16, 31_i16);
    hist_test_generic_all_inside(1_u8, 3_u8);
    // Full value ranges of the 8-bit types.
    hist_test_generic_all_inside(u8::MIN, u8::MAX);
    hist_test_generic_all_inside(i8::MIN, i8::MAX);
    // Symmetric signed range.
    hist_test_generic_all_inside(-100_i8, 100_i8);
}
443
444 fn hist_test_generic_all_outside_extensive<T>(left: T, right: T)
445 where FastSingleIntBinning::<T>: Binning::<T>,
446 GenericHist::<FastSingleIntBinning::<T>, T>: Histogram,
447 T: PrimInt,
448 std::ops::Range<T>: Iterator<Item=T>,
449 std::ops::RangeInclusive<T>: Iterator<Item=T>,
450 {
451 let binning = FastSingleIntBinning::<T>{start: left, end_inclusive: right};
452 let hist =
453 GenericHist::<FastSingleIntBinning::<T>, T>::new(binning);
454
455 for i in T::min_value()..left {
456 assert!(hist.not_inside(i));
457 assert_eq!(hist.is_inside(i), !hist.not_inside(i));
458 assert!(matches!(hist.get_bin_index(i), Err(HistErrors::OutsideHist)));
459 assert!(hist.distance(i) > 0.0);
460 }
461 for i in right+T::one()..=T::max_value() {
462 assert!(hist.not_inside(i));
463 assert_eq!(hist.is_inside(i), !hist.not_inside(i));
464 assert!(matches!(hist.get_bin_index(i), Err(HistErrors::OutsideHist)));
465 assert!(hist.distance(i) > 0.0);
466 }
467 assert_eq!(hist.bin_enum_iter().count(), hist.bin_count());
468 }
469
470 fn binning_all_outside_extensive<T>(left: T, right: T)
471 where FastSingleIntBinning::<T>: Binning::<T>,
472 T: PrimInt + Display,
473 std::ops::Range<T>: Iterator<Item=T>,
474 std::ops::RangeInclusive<T>: Iterator<Item=T> + Debug,
475 std::ops::RangeFrom<T>: Iterator<Item=T>,
476 {
477 let binning = FastSingleIntBinning::<T>{start: left, end_inclusive: right};
478
479 let mut last_dist = None;
480 for i in T::min_value()..left {
481 assert!(binning.not_inside(i));
482 assert_eq!(binning.is_inside(i), !binning.not_inside(i));
483 assert!(binning.get_bin_index(i).is_none());
484 let dist = binning.distance(i);
485 assert!(dist > 0.0);
486 match last_dist{
487 None => last_dist = Some(dist),
488 Some(d) => {
489 assert!(d > dist);
490 assert_eq!(d - 1.0, dist);
491 last_dist = Some(dist);
492 }
493 }
494 }
495 if let Some(d) = last_dist
496 {
497 assert_eq!(d, 1.0);
498 }
499
500 last_dist = None;
501 for (i, dist_counter) in (right+T::one()..=T::max_value()).zip(1_u64..) {
502 assert!(binning.not_inside(i));
503 assert_eq!(binning.is_inside(i), !binning.not_inside(i));
504 assert!(binning.get_bin_index(i).is_none());
505 let dist = binning.distance(i);
506 assert!(dist > 0.0);
507 println!("{i}, {:?}", right+T::one()..=T::max_value());
508 let dist_counter_float: f64 = dist_counter.as_();
509 assert_eq!(dist, dist_counter_float);
510 match last_dist{
511 None => last_dist = Some(dist),
512 Some(d) => {
513 assert!(d < dist);
514 last_dist = Some(dist);
515 }
516 }
517 }
518
519 let binning = FastSingleIntBinning::<T>{start: left, end_inclusive: left};
520 assert_eq!(binning.get_bin_len(), 1);
521 assert_eq!(binning.get_bin_index(left), Some(0));
522
523 }
524
#[test]
fn hist_outside() {
    // 8-bit types
    hist_test_generic_all_outside_extensive(10_u8, 20_u8);
    hist_test_generic_all_outside_extensive(-100_i8, 100_i8);
    // 16-bit types
    hist_test_generic_all_outside_extensive(-100_i16, 100_i16);
    hist_test_generic_all_outside_extensive(123_u16, 299_u16);
}
533
#[test]
fn binning_outside() {
    println!("0");
    // Binning that starts at the minimal value: nothing lies below it.
    binning_all_outside_extensive(0_u8, 0_u8);
    println!("2");
    binning_all_outside_extensive(10_u8, 20_u8);
    binning_all_outside_extensive(-100_i8, 100_i8);
    binning_all_outside_extensive(-100_i16, 100_i16);
    binning_all_outside_extensive(123_u16, 299_u16);
}
546
/// Exhaustively compares `checked_mul_div_u8` with a wide-integer reference
/// for every `(a, b, denominator)` combination with a non-zero denominator.
#[test]
fn check_mul_div() {
    // Reference: compute in a type wide enough that nothing can overflow.
    fn check(a: u8, b: u8, denominator: u8) -> Option<u8> {
        (a as u128 * b as u128 / denominator as u128).try_into().ok()
    }

    // Inclusive ranges: 255 is the most overflow-prone value and has to be
    // part of the check.
    for i in 0..=u8::MAX {
        for j in 0..=u8::MAX {
            for k in 1..=u8::MAX {
                assert_eq!(
                    check(i, j, k),
                    checked_mul_div_u8(i, j, k),
                    "Error in {i} {j} {k}"
                );
            }
        }
    }
}
567
/// Random spot checks of `checked_mul_div_*` against a computation in a wider
/// integer type `$o`.
#[test]
fn mul_testing() {
    use rand_pcg::Pcg64Mcg;
    use rand::SeedableRng;
    use rand::distr::Uniform;
    use rand::prelude::*;

    macro_rules! mul_t {
        (
            $t:ty, $o:ty
        ) => {
            paste::item! {
                fn [< mul_tests_ $t >]() {
                    let mut rng = Pcg64Mcg::seed_from_u64(314668);
                    // Denominators must not be 0, factors may be.
                    let uni_one = Uniform::new_inclusive(1, $t::MAX).unwrap();
                    let uni_all = Uniform::new_inclusive(0, $t::MAX).unwrap();
                    let max = <$t as HasUnsignedVersion>::Unsigned::MAX.into();
                    for _ in 0..100 {
                        let a = uni_all.sample(&mut rng);
                        let b = uni_all.sample(&mut rng);
                        let c = uni_one.sample(&mut rng);
                        let result: $o = a as $o * b as $o / c as $o;
                        let mul = paste::item! { [< checked_mul_div_ $t >] }(
                            a as <$t as HasUnsignedVersion>::Unsigned,
                            b as <$t as HasUnsignedVersion>::Unsigned,
                            c as <$t as HasUnsignedVersion>::Unsigned
                        );
                        // `Some` exactly when the wide result fits the unsigned type.
                        let expected = (result <= max)
                            .then(|| result as <$t as HasUnsignedVersion>::Unsigned);
                        assert_eq!(mul, expected);
                    }
                }
            }
        }
    }

    mul_t!(u8, u16);
    mul_tests_u8();
    mul_t!(u16, u64);
    mul_tests_u16();
    mul_t!(u32, u128);
    mul_tests_u32();
    mul_t!(i8, i16);
    mul_tests_i8();
    mul_t!(i32, i128);
    mul_tests_i32();
}
620
621
622
/// The outermost borders of a partition have to coincide with the borders of
/// the binning that was partitioned.
#[test]
fn partion_test() {
    let two = NonZeroUsize::new(2).unwrap();

    let full_u8 = FastBinningU8::new_inclusive(0, u8::MAX);
    for overlap in 0..10 {
        let parts = full_u8.overlapping_partition(two, overlap).unwrap();
        assert_eq!(full_u8.first_border(), parts[0].first_border());
        assert_eq!(full_u8.last_border(), parts.last().unwrap().last_border());
    }

    let full_i8 = FastBinningI8::new_inclusive(i8::MIN, i8::MAX);
    let parts = full_i8.overlapping_partition(two, 0).unwrap();
    assert_eq!(full_i8.first_border(), parts[0].first_border());
    assert_eq!(full_i8.last_border(), parts.last().unwrap().last_border());

    let full_i16 = FastBinningI16::new_inclusive(i16::MIN, i16::MAX);
    let parts = full_i16.overlapping_partition(two, 2).unwrap();
    assert_eq!(full_i16.first_border(), parts[0].first_border());
    assert_eq!(full_i16.last_border(), parts.last().unwrap().last_border());

    // Many intervals: only has to succeed.
    let many = NonZeroUsize::new(2000).unwrap();
    let _ = full_i16.overlapping_partition(many, 0).unwrap();
}
649
/// Partitions random `i8` intervals into three overlapping parts and checks
/// that the outer borders are kept and that all borders are sorted.
#[test]
fn overlapping_partition_test2() {
    use rand_pcg::Pcg64Mcg;
    use rand::distr::Uniform;
    use rand::prelude::*;

    let mut rng = Pcg64Mcg::seed_from_u64(2314668);
    let uni = Uniform::new_inclusive(-100, 100).unwrap();
    let three = NonZeroUsize::new(3).unwrap();

    for overlap in 0..=3 {
        for i in 0..100 {
            // Draw until the interval is non-degenerate and wide enough for
            // the requested overlap.
            let (left, right) = loop {
                let first = uni.sample(&mut rng);
                let second = uni.sample(&mut rng);
                if first == second {
                    continue;
                }
                let (low, high) = if second < first {
                    (second, first)
                } else {
                    (first, second)
                };
                if (high as isize - low as isize) < (overlap as isize + 1) {
                    continue;
                }
                break (low, high);
            };
            println!("iteration {i}");

            let hist_fast = FastBinningI8::new_inclusive(left, right);
            let overlapping = hist_fast
                .overlapping_partition(three, overlap)
                .unwrap();

            assert_eq!(
                overlapping.last().unwrap().last_border(),
                hist_fast.last_border(),
                "overlapping_partition_test2 - last border check"
            );
            assert_eq!(
                overlapping.first().unwrap().first_border(),
                hist_fast.first_border(),
                "overlapping_partition_test2 - first border check"
            );

            for pair in overlapping.windows(2) {
                let (earlier, later) = (&pair[0], &pair[1]);
                assert!(earlier.first_border() <= later.first_border());
                assert!(earlier.last_border() <= later.last_border());
            }
        }
    }
}
704
/// Tests `encapsulating_hist` and `align` for overlapping, touching and
/// nested histograms.
#[test]
fn hist_combine() {
    // Two overlapping histograms.
    let binning_left = FastBinningI8::new_inclusive(-5, 0);
    let binning_right = FastBinningI8::new_inclusive(-1, 2);
    let left = GenericHist::new(binning_left);
    let right = GenericHist::new(binning_right);

    let encapsulating = GenericHist::encapsulating_hist(&[&left, &right]).unwrap();
    let enc_binning = encapsulating.binning();
    assert_eq!(enc_binning.first_border(), binning_left.first_border());
    assert_eq!(enc_binning.last_border(), binning_right.last_border());
    assert_eq!(encapsulating.bin_count(), 8);

    let align = left.align(right).unwrap();
    assert_eq!(align, 4);

    // Two halves of the full i8 range sharing the value 0.
    let left = FastBinningI8::new_inclusive(i8::MIN, 0).to_generic_hist();
    let right = FastBinningI8::new_inclusive(0, i8::MAX).to_generic_hist();

    let en = GenericHist::encapsulating_hist(&[&left, &right]).unwrap();
    assert_eq!(en.bin_count(), 256);

    let align = left.align(right).unwrap();
    assert_eq!(128, align);

    // Full range and a single-valued histogram at its upper end.
    let left = FastBinningI8::new_inclusive(i8::MIN, i8::MAX).to_generic_hist();
    let small = FastBinningI8::new_inclusive(127, 127).to_generic_hist();

    let align = left.align(&small).unwrap();
    assert_eq!(255, align);

    let en = GenericHist::encapsulating_hist(&[&left]).unwrap();
    assert_eq!(en.bin_count(), 256);

    // An empty slice has to be rejected.
    let slice = [&left];
    let en = GenericHist::encapsulating_hist(&slice[1..]);
    assert_eq!(en.err(), Some(HistErrors::EmptySlice));

    // Owned histograms work as well, in any order.
    let en = GenericHist::encapsulating_hist(&[small, left]).unwrap();
    assert_eq!(en.bin_count(), 256);
}
754
/// The distance has to decrease strictly while approaching the binning from
/// below, be 0 inside and increase strictly while moving away above it.
#[test]
fn unit_test_distance() {
    let binning = FastBinningI8::new_inclusive(-50, 50);

    // Below the binning: strictly decreasing.
    let below: Vec<f64> = (i8::MIN..-50)
        .map(|value| binning.distance(value))
        .collect();
    for pair in below.windows(2) {
        assert!(pair[0] > pair[1]);
    }

    // Inside the binning: always zero.
    for value in -50..=50 {
        assert_eq!(binning.distance(value), 0.0);
    }

    // Above the binning: strictly increasing, starting from zero.
    let mut previous = 0.0;
    for value in 51..=i8::MAX {
        let current = binning.distance(value);
        assert!(previous < current);
        previous = current;
    }
}
777}