1use super::binning::BinDisplay;
2use {
3 crate::histogram::*,
4 num_traits::{
5 cast::*,
6 identities::*,
7 int::*,
8 ops::{checked::*, wrapping::*},
9 Bounded,
10 },
11 std::{borrow::*, num::*, ops::*},
12};
13
14#[cfg(feature = "serde_support")]
15use serde::{Deserialize, Serialize};
16
/// Two-state result of an operation that is either carried out or rejected,
/// e.g. the return value of `HistogramFast::try_add`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Outcome {
    /// The operation was carried out.
    Success,
    /// The operation was not carried out.
    Failure,
}

impl Outcome {
    /// Returns `true` exactly when `self` is [`Outcome::Success`].
    pub fn is_success(self) -> bool {
        matches!(self, Self::Success)
    }

    /// Returns `true` exactly when `self` is [`Outcome::Failure`].
    pub fn is_failure(self) -> bool {
        matches!(self, Self::Failure)
    }
}
37
/// Histogram over a contiguous range of integers with exactly one bin per integer.
///
/// Both borders are inclusive, i.e. the histogram covers `left..=right` and
/// `hist` contains one hit counter for each value in that range.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct HistogramFast<T> {
    /// Smallest value covered by the histogram (inclusive).
    left: T,
    /// Largest value covered by the histogram (inclusive).
    right: T,
    /// Hit counters; the counter at index `i` belongs to the value `left + i`.
    hist: Vec<usize>,
}
48
/// Textual output of the bins: every bin is a single integer, so a bin is
/// written using the `Display` implementation of `T`.
impl<T> BinDisplay for HistogramFast<T>
where
    T: PrimInt + HasUnsignedVersion + Copy + std::fmt::Display + WrappingAdd,
    T::Unsigned: Bounded
        + HasUnsignedVersion<LeBytes = T::LeBytes>
        + WrappingAdd
        + ToPrimitive
        + Sub<Output = T::Unsigned>,
{
    /// A bin is represented by the one value it contains.
    type BinEntry = T;

    /// Boxed version of [`HistogramFast::bin_iter`]: yields every value from
    /// the left to the right border (both inclusive), in ascending order.
    fn display_bin_iter(&'_ self) -> Box<dyn Iterator<Item = Self::BinEntry> + '_> {
        Box::new(self.bin_iter())
    }

    /// Writes `entry` to `writer` via its `Display` implementation.
    /// Nothing else (no separator, no newline) is written.
    fn write_bin<W: std::io::Write>(entry: &Self::BinEntry, mut writer: W) -> std::io::Result<()> {
        write!(writer, "{entry}")
    }

    /// Writes the column header `Bin` to `writer`.
    /// Nothing else (no separator, no newline) is written.
    fn write_header<W: std::io::Write>(&self, mut writer: W) -> std::io::Result<()> {
        write!(writer, "Bin")
    }
}
72
/// Read-only access to the borders; only requires `T: Copy`.
impl<T> HistogramFast<T>
where
    T: Copy,
{
    /// Returns the left border, i.e. the smallest value covered by the histogram (inclusive).
    pub fn left(&self) -> T {
        self.left
    }

    /// Returns the right border, i.e. the largest value covered by the histogram (inclusive).
    pub fn right(&self) -> T {
        self.right
    }

    /// Returns the covered range `left..=right`.
    pub fn range_inclusive(&self) -> RangeInclusive<T> {
        self.left..=self.right
    }
}
92
93impl<T> HistogramFast<T>
94where
95 T: PrimInt + HasUnsignedVersion + WrappingAdd,
96 T::Unsigned: Bounded
97 + HasUnsignedVersion<LeBytes = T::LeBytes>
98 + WrappingAdd
99 + ToPrimitive
100 + Sub<Output = T::Unsigned>,
101{
102 pub fn new(left: T, right: T) -> Result<Self, HistErrors> {
106 let right = match right.checked_sub(&T::one()) {
107 Some(res) => res,
108 None => return Err(HistErrors::Underflow),
109 };
110 Self::new_inclusive(left, right)
111 }
112
113 pub fn new_inclusive(left: T, right: T) -> Result<Self, HistErrors> {
118 if left > right {
119 Err(HistErrors::OutsideHist)
120 } else {
121 let left_u = to_u(left);
122 let right_u = to_u(right);
123 let size = match (right_u - left_u).to_usize() {
124 Some(val) => match val.checked_add(1) {
125 Some(val) => val,
126 None => return Err(HistErrors::Overflow),
127 },
128 None => return Err(HistErrors::UsizeCastError),
129 };
130
131 Ok(Self {
132 left,
133 right,
134 hist: vec![0; size],
135 })
136 }
137 }
138
139 pub fn bin_iter(&self) -> impl Iterator<Item = T> {
154 HistFastIterHelper {
155 current: self.left,
156 right: self.right,
157 invalid: false,
158 }
159 }
160
161 pub fn bin_hits_iter(&'_ self) -> impl Iterator<Item = (T, usize)> + '_ {
180 self.bin_iter().zip(self.hist.iter().copied())
181 }
182
183 pub fn equal_range(&self, other: &Self) -> bool
186 where
187 T: Eq,
188 {
189 self.left.eq(&other.left) && self.right.eq(&other.right)
190 }
191
192 pub fn try_add(&mut self, other: &Self) -> Outcome
199 where
200 T: Eq,
201 {
202 if self.equal_range(other) {
203 self.hist
204 .iter_mut()
205 .zip(other.hist().iter())
206 .for_each(|(s, o)| *s += o);
207 Outcome::Success
208 } else {
209 Outcome::Failure
210 }
211 }
212
213 #[inline]
214 pub fn increment<V: Borrow<T>>(&mut self, val: V) -> Result<usize, HistErrors> {
221 self.count_val(val)
222 }
223
224 #[inline]
225 pub fn increment_quiet<V: Borrow<T>>(&mut self, val: V) {
229 let _ = self.increment(val);
230 }
231}
232
/// Iterator state for walking over all values `current..=right` in steps of one.
///
/// A dedicated `invalid` flag is used to mark the end, so that the last value
/// may be the largest value of `T` (the successor is computed with wrapping addition).
pub(crate) struct HistFastIterHelper<T> {
    /// The value that is returned by the next call of `next`.
    pub(crate) current: T,
    /// The last value that will be yielded (inclusive).
    pub(crate) right: T,
    /// `true` once `right` was yielded; from then on the iterator returns `None`.
    pub(crate) invalid: bool,
}
238
239impl<T> Iterator for HistFastIterHelper<T>
240where
241 T: PrimInt + WrappingAdd,
242{
243 type Item = T;
244
245 #[inline]
246 fn next(&mut self) -> Option<T> {
247 if self.invalid {
248 return None;
249 }
250
251 let next = self.current.wrapping_add(&T::one());
252 let current = std::mem::replace(&mut self.current, next);
253 self.invalid = current == self.right;
254 Some(current)
255 }
256}
257
/// Iterator state for walking over consecutive bins of width `step_by`,
/// where each bin is yielded as an inclusive range.
pub(crate) struct BinModIterHelper<T> {
    /// Left border of the bin that is returned by the next call of `next`.
    pub(crate) current: T,
    /// Inclusive right border of the last bin; reaching it ends the iteration.
    pub(crate) right: T,
    /// Number of values per bin, i.e. the distance between the left borders of neighboring bins.
    pub(crate) step_by: T,
    /// `true` once the last bin was yielded; from then on the iterator returns `None`.
    pub(crate) invalid: bool,
}
264
impl<T> BinModIterHelper<T> {
    /// Creates the iterator state starting at `left`, without validating the arguments.
    ///
    /// "Unchecked" because nothing is verified here: the iteration only stops when
    /// the right border of a yielded bin is exactly `right`. The caller therefore has
    /// to make sure that `right` is reachable from `left` in steps of `step_by`,
    /// otherwise the iterator does not terminate as intended.
    pub(crate) fn new_unchecked(left: T, right: T, step_by: T) -> Self {
        Self {
            current: left,
            right,
            step_by,
            invalid: false,
        }
    }
}
275
276impl<T> Iterator for BinModIterHelper<T>
277where
278 T: Add<T, Output = T> + Ord + Copy + WrappingAdd + WrappingSub + One,
279{
280 type Item = RangeInclusive<T>;
281
282 #[inline]
283 fn next(&mut self) -> Option<RangeInclusive<T>> {
284 if self.invalid {
285 return None;
286 }
287
288 let next = self.current.wrapping_add(&self.step_by);
289 let right = next.wrapping_sub(&T::one());
290 self.invalid = right == self.right;
291 let left = std::mem::replace(&mut self.current, next);
292 Some(left..=right)
293 }
294}
295
296impl<T> HistogramPartition for HistogramFast<T>
297where
298 T: PrimInt
299 + CheckedSub
300 + ToPrimitive
301 + CheckedAdd
302 + One
303 + FromPrimitive
304 + HasUnsignedVersion
305 + Bounded
306 + WrappingAdd,
307 T::Unsigned: Bounded
308 + HasUnsignedVersion<LeBytes = T::LeBytes, Unsigned = T::Unsigned>
309 + WrappingAdd
310 + ToPrimitive
311 + Sub<Output = T::Unsigned>
312 + FromPrimitive
313 + WrappingSub,
314{
315 fn overlapping_partition(
316 &self,
317 n: NonZeroUsize,
318 overlap: usize,
319 ) -> Result<Vec<Self>, HistErrors> {
320 let mut result = Vec::with_capacity(n.get());
321 let size = self.bin_count() - 1;
322 let denominator = n.get() + overlap;
323 for c in 0..n.get() {
324 let left_distance = c.checked_mul(size).ok_or(HistErrors::Overflow)? / denominator;
325
326 let left = to_u(self.left)
327 + T::Unsigned::from_usize(left_distance).ok_or(HistErrors::CastError)?;
328
329 let right_distance = (c + overlap + 1)
330 .checked_mul(size)
331 .ok_or(HistErrors::Overflow)?
332 / denominator;
333
334 let right = to_u(self.left)
335 + T::Unsigned::from_usize(right_distance).ok_or(HistErrors::CastError)?;
336
337 let left = from_u(left);
338 let right = from_u(right);
339
340 result.push(Self::new_inclusive(left, right)?);
341 if result.last().unwrap().hist.is_empty() {
342 return Err(HistErrors::IntervalWidthZero);
343 }
344 }
345 Ok(result)
346 }
347}
348
/// Type alias for a [`HistogramFast`] over `usize` values
pub type HistUsizeFast = HistogramFast<usize>;
/// Type alias for a [`HistogramFast`] over `u128` values
pub type HistU128Fast = HistogramFast<u128>;
/// Type alias for a [`HistogramFast`] over `u64` values
pub type HistU64Fast = HistogramFast<u64>;
/// Type alias for a [`HistogramFast`] over `u32` values
pub type HistU32Fast = HistogramFast<u32>;
/// Type alias for a [`HistogramFast`] over `u16` values
pub type HistU16Fast = HistogramFast<u16>;
/// Type alias for a [`HistogramFast`] over `u8` values
pub type HistU8Fast = HistogramFast<u8>;

/// Type alias for a [`HistogramFast`] over `isize` values
pub type HistIsizeFast = HistogramFast<isize>;
/// Type alias for a [`HistogramFast`] over `i128` values
pub type HistI128Fast = HistogramFast<i128>;
/// Type alias for a [`HistogramFast`] over `i64` values
pub type HistI64Fast = HistogramFast<i64>;
/// Type alias for a [`HistogramFast`] over `i32` values
pub type HistI32Fast = HistogramFast<i32>;
/// Type alias for a [`HistogramFast`] over `i16` values
pub type HistI16Fast = HistogramFast<i16>;
/// Type alias for a [`HistogramFast`] over `i8` values
pub type HistI8Fast = HistogramFast<i8>;
374
375impl<T> Histogram for HistogramFast<T> {
376 #[inline]
377 fn increment_index_by(&mut self, index: usize, count: usize) -> Result<(), HistErrors> {
378 match self.hist.get_mut(index) {
379 None => Err(HistErrors::OutsideHist),
380 Some(val) => {
381 *val += count;
382 Ok(())
383 }
384 }
385 }
386
387 #[inline]
388 fn hist(&self) -> &Vec<usize> {
389 &self.hist
390 }
391
392 #[inline]
393 fn bin_count(&self) -> usize {
394 self.hist.len()
395 }
396
397 #[inline]
398 fn reset(&mut self) {
399 self.hist.iter_mut().for_each(|h| *h = 0);
401 }
402}
403
404impl<T> HistogramVal<T> for HistogramFast<T>
405where
406 T: PrimInt + HasUnsignedVersion + WrappingAdd,
407 T::Unsigned: Bounded
408 + HasUnsignedVersion<LeBytes = T::LeBytes>
409 + WrappingAdd
410 + ToPrimitive
411 + Sub<Output = T::Unsigned>,
412{
413 #[inline]
414 fn first_border(&self) -> T {
415 self.left
416 }
417
418 fn last_border(&self) -> T {
419 self.right
420 }
421
422 #[inline(always)]
423 fn last_border_is_inclusive(&self) -> bool {
424 true
425 }
426
427 fn distance<V: Borrow<T>>(&self, val: V) -> f64 {
428 let val = val.borrow();
429 if self.not_inside(val) {
430 let dist = if *val < self.first_border() {
431 self.first_border() - *val
432 } else {
433 val.saturating_sub(self.right)
434 };
435 dist.to_f64().unwrap_or(f64::INFINITY)
436 } else {
437 0.0
438 }
439 }
440
441 #[inline(always)]
442 fn get_bin_index<V: Borrow<T>>(&self, val: V) -> Result<usize, HistErrors> {
443 let val = *val.borrow();
444 if val <= self.right {
445 match val.checked_sub(&self.left) {
446 None => {
447 let left = self.left.to_isize().ok_or(HistErrors::CastError)?;
448 let val = val.to_isize().ok_or(HistErrors::CastError)?;
449 match val.checked_sub(left) {
450 None => Err(HistErrors::OutsideHist),
451 Some(index) => index.to_usize().ok_or(HistErrors::OutsideHist),
452 }
453 }
454 Some(index) => index.to_usize().ok_or(HistErrors::OutsideHist),
455 }
456 } else {
457 Err(HistErrors::OutsideHist)
458 }
459 }
460
461 fn bin_enum_iter(&self) -> Box<dyn Iterator<Item = Bin<T>> + '_> {
465 let iter = self.bin_iter().map(|bin| Bin::SingleValued(bin));
466 Box::new(iter)
467 }
468
469 #[inline]
470 fn is_inside<V: Borrow<T>>(&self, val: V) -> bool {
471 let val = *val.borrow();
472 val >= self.left && val <= self.right
473 }
474
475 #[inline]
476 fn not_inside<V: Borrow<T>>(&self, val: V) -> bool {
477 let val = *val.borrow();
478 val > self.right || val < self.left
479 }
480
481 #[inline]
482 fn count_val<V: Borrow<T>>(&mut self, val: V) -> Result<usize, HistErrors> {
483 let index = self.get_bin_index(val)?;
484 self.hist[index] += 1;
485 Ok(index)
486 }
487}
488
489impl<T> HistogramIntervalDistance<T> for HistogramFast<T>
490where
491 Self: HistogramVal<T>,
492 T: PartialOrd + std::ops::Sub<Output = T> + NumCast + Copy,
493{
494 fn interval_distance_overlap<V: Borrow<T>>(&self, val: V, overlap: NonZeroUsize) -> usize {
495 let val = val.borrow();
496 if self.not_inside(val) {
497 let num_bins_overlap = 1usize.max(self.bin_count() / overlap.get());
498 let dist = if *val < self.left {
499 self.left - *val
500 } else {
501 *val - self.right
502 };
503 1 + dist.to_usize().unwrap() / num_bins_overlap
504 } else {
505 0
506 }
507 }
508}
509
510impl<T> HistogramCombine for HistogramFast<T>
511where
512 Self: HistogramVal<T>,
513 T: PrimInt + HasUnsignedVersion + WrappingAdd,
514 T::Unsigned: Bounded
515 + HasUnsignedVersion<LeBytes = T::LeBytes>
516 + WrappingAdd
517 + ToPrimitive
518 + Sub<Output = T::Unsigned>,
519{
520 fn encapsulating_hist<S>(hists: &[S]) -> Result<Self, HistErrors>
521 where
522 S: Borrow<Self>,
523 {
524 if hists.is_empty() {
525 Err(HistErrors::EmptySlice)
526 } else if hists.len() == 1 {
527 let h = hists[0].borrow();
528 Ok(Self {
529 left: h.left,
530 right: h.right,
531 hist: vec![0; h.hist.len()],
532 })
533 } else {
534 let mut min = hists[0].borrow().left;
535 let mut max = hists[0].borrow().right;
536 hists[1..].iter().for_each(|h| {
537 let h = h.borrow();
538 if h.left < min {
539 min = h.left;
540 }
541 if h.right > max {
542 max = h.right;
543 }
544 });
545 Self::new_inclusive(min, max)
546 }
547 }
548
549 fn align<S>(&self, right: S) -> Result<usize, HistErrors>
550 where
551 S: Borrow<Self>,
552 {
553 let right = right.borrow();
554
555 if self.is_inside(right.left) {
556 (to_u(right.left) - to_u(self.left))
557 .to_usize()
558 .ok_or(HistErrors::UsizeCastError)
559 } else {
560 Err(HistErrors::OutsideHist)
561 }
562 }
563}
564
565impl<T> IntervalOrder for HistogramFast<T>
566where
567 T: PrimInt,
568{
569 fn left_compare(&self, other: &Self) -> std::cmp::Ordering {
570 let order = self.left.cmp(&other.left);
571 if order.is_eq() {
572 return self.right.cmp(&other.right);
573 }
574 order
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{distr::*, SeedableRng};
    use rand_pcg::Pcg64Mcg;

    /// Checks the basic invariants of a histogram covering `left..=right`:
    /// every value of the range is inside, has distance 0 and is mapped to
    /// consecutive bin indices, while the direct neighbors of the borders are
    /// outside and have distance 1.
    ///
    /// `left - 1` and `right + 1` have to be representable in `T`,
    /// and the borders must differ from the extreme values of `T`.
    fn hist_test_fast<T>(left: T, right: T)
    where
        T: PrimInt
            + num_traits::Bounded
            + PartialOrd
            + CheckedSub
            + One
            + NumCast
            + Copy
            + FromPrimitive
            + HasUnsignedVersion
            + WrappingAdd,
        std::ops::RangeInclusive<T>: Iterator<Item = T>,
        T::Unsigned: Bounded
            + HasUnsignedVersion<LeBytes = T::LeBytes>
            + WrappingAdd
            + ToPrimitive
            + Sub<Output = T::Unsigned>,
    {
        let mut hist = HistogramFast::<T>::new_inclusive(left, right).unwrap();
        assert!(hist.not_inside(T::max_value()));
        assert!(hist.not_inside(T::min_value()));
        // safe constructor - no need for `unsafe` to create a non-zero constant
        let two = NonZeroUsize::new(2).expect("2 is not zero");
        for (id, i) in (left..=right).enumerate() {
            assert!(hist.is_inside(i));
            assert_eq!(hist.is_inside(i), !hist.not_inside(i));
            assert_eq!(hist.get_bin_index(i).unwrap(), id);
            assert_eq!(hist.distance(i), 0.0);
            assert_eq!(hist.interval_distance_overlap(i, two), 0);
            hist.count_val(i).unwrap();
        }
        // direct neighbors of the borders
        let lm1 = left - T::one();
        let rp1 = right + T::one();
        assert!(hist.not_inside(lm1));
        assert!(hist.not_inside(rp1));
        assert_eq!(hist.is_inside(lm1), !hist.not_inside(lm1));
        assert_eq!(hist.is_inside(rp1), !hist.not_inside(rp1));
        assert_eq!(hist.distance(lm1), 1.0);
        assert_eq!(hist.distance(rp1), 1.0);
        let one = NonZeroUsize::new(1).expect("1 is not zero");
        assert_eq!(hist.interval_distance_overlap(rp1, one), 1);
        assert_eq!(hist.interval_distance_overlap(lm1, one), 1);
        assert_eq!(hist.bin_enum_iter().count(), hist.bin_count());
    }

    #[test]
    fn hist_fast() {
        hist_test_fast(20usize, 31usize);
        hist_test_fast(-23isize, 31isize);
        hist_test_fast(-23i16, 31);
        hist_test_fast(1u8, 3u8);
        hist_test_fast(123u128, 300u128);
        hist_test_fast(-123i128, 300i128);

        // width of the range exceeds i8::MAX
        hist_test_fast(-100i8, 100i8);
    }

    /// Histograms covering the whole range of a type have to be constructible
    #[test]
    fn hist_creation() {
        let _ = HistU8Fast::new_inclusive(0, u8::MAX).unwrap();
        let _ = HistI8Fast::new_inclusive(i8::MIN, i8::MAX).unwrap();
    }

    /// The partition has to start at the left border and end at the right
    /// border of the original histogram
    #[test]
    fn partion_test() {
        let n = NonZeroUsize::new(2).unwrap();
        let h = HistU8Fast::new_inclusive(0, u8::MAX).unwrap();
        let h_part = h.overlapping_partition(n, 0).unwrap();
        assert_eq!(h.left, h_part[0].left);
        assert_eq!(h.right, h_part.last().unwrap().right);

        let h = HistI8Fast::new_inclusive(i8::MIN, i8::MAX).unwrap();
        let h_part = h.overlapping_partition(n, 0).unwrap();
        assert_eq!(h.left, h_part[0].left);
        assert_eq!(h.right, h_part.last().unwrap().right);

        let h = HistI16Fast::new_inclusive(i16::MIN, i16::MAX).unwrap();
        let h_part = h.overlapping_partition(n, 2).unwrap();
        assert_eq!(h.left, h_part[0].left);
        assert_eq!(h.right, h_part.last().unwrap().right);

        let _ = h
            .overlapping_partition(NonZeroUsize::new(2000).unwrap(), 0)
            .unwrap();
    }

    /// Same border checks as `partion_test`, but for randomly drawn ranges
    /// and different overlaps
    #[test]
    fn overlapping_partition_test2() {
        let mut rng = Pcg64Mcg::seed_from_u64(2314668);
        let uni = Uniform::new_inclusive(-100, 100).unwrap();
        for overlap in 0..=3 {
            for _ in 0..100 {
                // draw until the range is wide enough for the requested overlap
                let (left, right) = loop {
                    let mut num_1 = uni.sample(&mut rng);
                    let mut num_2 = uni.sample(&mut rng);

                    if num_1 != num_2 {
                        if num_2 < num_1 {
                            std::mem::swap(&mut num_1, &mut num_2);
                        }
                        if (num_2 as isize - num_1 as isize) < (overlap as isize + 1) {
                            continue;
                        }
                        break (num_1, num_2);
                    }
                };
                let hist_fast = HistI8Fast::new_inclusive(left, right).unwrap();
                let overlapping = hist_fast
                    .overlapping_partition(NonZeroUsize::new(3).unwrap(), overlap)
                    .unwrap();

                assert_eq!(
                    overlapping.last().unwrap().last_border(),
                    hist_fast.last_border()
                );

                assert_eq!(
                    overlapping.first().unwrap().first_border(),
                    hist_fast.first_border()
                );
            }
        }
    }

    #[test]
    fn hist_combine() {
        let left = HistI8Fast::new_inclusive(-5, 0).unwrap();
        let right = HistI8Fast::new_inclusive(-1, 2).unwrap();

        let en = HistI8Fast::encapsulating_hist(&[&left, &right]).unwrap();

        assert_eq!(en.left, left.left);
        assert_eq!(en.right, right.right);
        assert_eq!(en.bin_count(), 8);

        let align = left.align(right).unwrap();

        assert_eq!(align, 4);

        let left = HistI8Fast::new_inclusive(i8::MIN, 0).unwrap();
        let right = HistI8Fast::new_inclusive(0, i8::MAX).unwrap();

        let en = HistI8Fast::encapsulating_hist(&[&left, &right]).unwrap();

        assert_eq!(en.bin_count(), 256);

        let align = left.align(right).unwrap();

        assert_eq!(128, align);

        let left = HistI8Fast::new_inclusive(i8::MIN, i8::MAX).unwrap();
        let small = HistI8Fast::new_inclusive(127, 127).unwrap();

        let align = left.align(&small).unwrap();

        assert_eq!(255, align);

        // slice with a single histogram
        let en = HistI8Fast::encapsulating_hist(&[&left]).unwrap();
        assert_eq!(en.bin_count(), 256);
        // empty slice
        let slice = [&left];
        let en = HistI8Fast::encapsulating_hist(&slice[1..]);
        assert_eq!(en.err(), Some(HistErrors::EmptySlice));
        let en = HistI8Fast::encapsulating_hist(&[small, left]).unwrap();

        assert_eq!(en.bin_count(), 256);
    }

    #[test]
    fn hist_try_add() {
        let mut first = HistU8Fast::new_inclusive(0, 23).unwrap();
        let mut second = HistU8Fast::new_inclusive(0, 23).unwrap();

        for i in 0..=23 {
            first.increment(i).unwrap();
        }
        for i in 0..=11 {
            second.increment(i).unwrap();
        }

        let outcome = first.try_add(&second);
        assert!(outcome.is_success());

        let hist = first.hist();

        // bins that were hit in both histograms
        for &hits in &hist[0..=11] {
            assert_eq!(hits, 2);
        }
        // bins that were only hit in `first`
        for &hits in &hist[12..=23] {
            assert_eq!(hits, 1);
        }

        // exclusive right border -> covers 0..=22, which differs from the range of `first`
        let third = HistU8Fast::new(0, 23).unwrap();

        let outcome = first.try_add(&third);
        assert!(
            outcome.is_failure(),
            "Needs to be Err because ranges do not match"
        )
    }
}