1use {
2 crate::histogram::*,
3 num_traits::{cast::*, float::*, identities::*},
4 std::{borrow::*, num::*, sync::atomic::AtomicUsize},
5};
6
7#[cfg(feature = "serde_support")]
8use serde::{Deserialize, Serialize};
9
/// Histogram for floating point values.
///
/// The bin borders are stored explicitly, together with one hit counter per bin.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct HistogramFloat<T> {
    /// Borders of the bins. When built through `new`, this holds `bins + 1`
    /// values: the left border of every bin followed by the right border of
    /// the histogram.
    pub(crate) bin_borders: Vec<T>,
    /// Hit counter of each bin.
    pub(crate) hist: Vec<usize>,
}
17
18impl<T> From<AtomicHistogramFloat<T>> for HistogramFloat<T> {
19 fn from(other: AtomicHistogramFloat<T>) -> Self {
20 let hist = other
21 .hist
22 .into_iter()
23 .map(AtomicUsize::into_inner)
24 .collect();
25 Self {
26 hist,
27 bin_borders: other.bin_borders,
28 }
29 }
30}
31
impl<T> HistogramFloat<T> {
    /// Returns a reference to the stored bin borders.
    ///
    /// For a histogram built with `new`, this contains one more entry than
    /// there are bins; the final entry is the right border of the histogram.
    pub fn borders(&self) -> &Vec<T> {
        &self.bin_borders
    }
}
38
39impl<T> HistogramFloat<T>
40where
41 T: Copy,
42{
43 fn get_right(&self) -> T {
44 self.bin_borders[self.bin_borders.len() - 1]
45 }
46}
47
48impl<T> HistogramFloat<T>
49where
50 T: Float + PartialOrd + FromPrimitive,
51{
52 pub fn new(left: T, right: T, bins: usize) -> Result<Self, HistErrors> {
57 if left >= right {
58 return Err(HistErrors::IntervalWidthZero);
59 } else if bins < 1 {
60 return Err(HistErrors::NoBins);
61 }
62 if !left.is_finite() || !right.is_finite() {
63 return Err(HistErrors::InvalidVal);
64 }
65
66 let bins_as_t = match T::from_usize(bins) {
67 Some(val) => val,
68 None => return Err(HistErrors::UsizeCastError),
69 };
70
71 let bin_size = (right - left) / bins_as_t;
72 let hist = vec![0; bins];
73 let mut bin_borders = Vec::with_capacity(bins + 1);
74 bin_borders
75 .extend((0..bins).map(|val| bin_size.mul_add(T::from_usize(val).unwrap(), left)));
76 bin_borders.push(right);
77 Ok(Self { bin_borders, hist })
78 }
79
80 pub fn interval_length(&self) -> T {
82 self.get_right() - self.first_border()
83 }
84
85 pub fn bin_iter(&self) -> impl Iterator<Item = &[T; 2]> {
103 BorderWindow::new(self.bin_borders.as_slice())
104 }
105
106 pub fn bin_hits_iter(&self) -> impl Iterator<Item = (&[T; 2], usize)> {
129 self.bin_iter().zip(self.hist.iter().copied())
130 }
131
132 #[inline]
133 pub fn increment<B: Borrow<T>>(&mut self, val: B) -> Result<usize, HistErrors> {
137 self.count_val(val)
138 }
139
140 #[inline]
141 pub fn increment_quiet<B: Borrow<T>>(&mut self, val: B) {
145 let _ = self.increment(val);
146 }
147}
148
149impl<T> Histogram for HistogramFloat<T> {
150 #[inline(always)]
151 fn bin_count(&self) -> usize {
152 self.hist.len()
153 }
154
155 #[inline(always)]
156 fn hist(&self) -> &Vec<usize> {
157 &self.hist
158 }
159
160 #[inline]
161 fn increment_index_by(&mut self, index: usize, count: usize) -> Result<(), HistErrors> {
162 match self.hist.get_mut(index) {
163 None => Err(HistErrors::OutsideHist),
164 Some(val) => {
165 *val += count;
166 Ok(())
167 }
168 }
169 }
170
171 #[inline]
172 fn reset(&mut self) {
173 self.hist.iter_mut().for_each(|h| *h = 0);
175 }
176}
177
178impl<T> HistogramVal<T> for HistogramFloat<T>
179where
180 T: Float + Zero + NumCast + PartialOrd + FromPrimitive,
181{
182 fn count_val<V: Borrow<T>>(&mut self, val: V) -> Result<usize, HistErrors> {
183 let id = self.get_bin_index(val)?;
184 self.increment_index(id).map(|_| id)
185 }
186
187 fn distance<V: Borrow<T>>(&self, val: V) -> f64 {
188 let val = val.borrow();
189 if self.is_inside(val) {
190 0.0
191 } else if !val.is_finite() {
192 f64::INFINITY
193 } else if *val < self.first_border() {
194 (self.first_border() - *val).to_f64().unwrap()
195 } else {
196 (*val - self.get_right() + T::epsilon()).to_f64().unwrap()
197 }
198 }
199
200 #[inline]
201 fn first_border(&self) -> T {
202 self.bin_borders[0]
203 }
204
205 #[inline]
206 fn last_border(&self) -> T {
207 self.bin_borders[self.bin_borders.len() - 1]
208 }
209
210 #[inline(always)]
211 fn last_border_is_inclusive(&self) -> bool {
212 false
213 }
214
215 fn is_inside<V: Borrow<T>>(&self, val: V) -> bool {
216 *val.borrow() >= self.bin_borders[0]
217 && *val.borrow() < self.bin_borders[self.bin_borders.len() - 1]
218 }
219
220 fn not_inside<V: Borrow<T>>(&self, val: V) -> bool {
221 !(*val.borrow()).is_finite()
222 || *val.borrow() < self.bin_borders[0]
223 || *val.borrow() >= self.bin_borders[self.bin_borders.len() - 1]
224 }
225
226 fn get_bin_index<V: Borrow<T>>(&self, val: V) -> Result<usize, HistErrors> {
227 let val = val.borrow();
228 if !val.is_finite() {
229 Err(HistErrors::InvalidVal)
230 } else if self.is_inside(val) {
231 let search_res = self
232 .bin_borders
233 .binary_search_by(|v| v.partial_cmp(val).expect("Should never be NaN"));
234 match search_res {
235 Result::Ok(index) => Ok(index),
236 Result::Err(index_p1) => Ok(index_p1 - 1),
237 }
238 } else {
239 Err(HistErrors::OutsideHist)
240 }
241 }
242
243 fn bin_enum_iter(&self) -> Box<dyn Iterator<Item = Bin<T>> + '_> {
248 let iter = self
249 .bin_iter()
250 .map(|slice| Bin::InclusiveExclusive(slice[0], slice[1]));
251 Box::new(iter)
252 }
253}
254
255impl<T> HistogramIntervalDistance<T> for HistogramFloat<T>
256where
257 T: Float + FromPrimitive + Zero + NumCast,
258{
259 fn interval_distance_overlap<V: Borrow<T>>(&self, val: V, overlap: NonZeroUsize) -> usize {
260 let val = val.borrow();
261
262 debug_assert!(self.interval_length() > T::zero());
263 debug_assert!(val.is_finite());
264 if self.not_inside(val) {
265 let num_bins_overlap = self.bin_count() / overlap.get();
266 let dist = if *val < self.first_border() {
267 let tmp = self.first_border() - *val;
268 (tmp / self.interval_length()).floor()
269 } else {
270 let tmp = *val - self.get_right();
271 (tmp / self.interval_length()).ceil()
272 };
273 let int_dist = dist.to_usize().unwrap();
274 1 + int_dist / num_bins_overlap
275 } else {
276 0
277 }
278 }
279}
280
/// Shorthand for a [`HistogramFloat`] over `f32` values.
pub type HistF32 = HistogramFloat<f32>;
283
/// Shorthand for a [`HistogramFloat`] over `f64` values.
pub type HistF64 = HistogramFloat<f64>;
286
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::Bounded;
    use rand::{distr::*, SeedableRng};
    use rand_pcg::Pcg64Mcg;
    /// Builds `f64` histograms with 1 to 99 bins over random intervals and
    /// checks that the outer borders are stored exactly and that there is one
    /// more border than bins.
    #[test]
    fn f64_hist() {
        // Fixed seed keeps the test deterministic.
        let rng = Pcg64Mcg::new(0xcafef00dd15ea5e5);
        let dist = Uniform::new(f64::EPSILON, 1.0).unwrap();
        let mut iter = dist.sample_iter(rng);

        for i in 1..100 {
            let left = iter.next().unwrap();
            // The offset is drawn from the same distribution, whose lower
            // bound is `f64::EPSILON`.
            let right = left + iter.next().unwrap();

            let hist = HistogramFloat::<f64>::new(left, right, i).unwrap();

            assert_eq!(left, hist.first_border(), "i={}", i);
            assert_eq!(right, hist.get_right(), "i={}", i);
            assert_eq!(i + 1, hist.borders().len(), "i={}", i);
        }
    }

    /// Shared checks for a histogram over `[left, right)` with `bin_count` bins:
    /// border membership, bin lookup directly at and just below each border,
    /// and rejection of non-finite interval bounds by `new`.
    fn hist_test_float<T>(left: T, right: T, bin_count: usize)
    where
        T: Float
            + num_traits::Bounded
            + PartialOrd
            + One
            + NumCast
            + Copy
            + FromPrimitive
            + Bounded
            + std::fmt::Debug
            + PartialOrd,
    {
        let hist_wrapped = HistogramFloat::<T>::new(left, right, bin_count);
        // Print the error before the `unwrap` below panics on it.
        if hist_wrapped.is_err() {
            dbg!(&hist_wrapped);
        }
        let hist = hist_wrapped.unwrap();
        // Non-finite values are never inside.
        assert!(hist.not_inside(T::infinity()));
        assert!(hist.not_inside(T::nan()));
        let len = hist.borders().len();

        // Every border except the last is the inclusive left border of the
        // bin with the same index.
        for (id, border) in hist.borders().iter().take(len - 1).enumerate() {
            assert!(hist.is_inside(border));
            assert_eq!(hist.is_inside(border), !hist.not_inside(border));
            assert_eq!(hist.get_bin_index(border).unwrap(), id);
        }

        // The last border is exclusive.
        let last_border = hist.borders()[len - 1];
        assert!(hist.not_inside(last_border));
        assert_eq!(hist.is_inside(last_border), !hist.not_inside(last_border));
        assert!(hist.get_bin_index(last_border).is_err());

        // For every border except the first: a value slightly below it must
        // fall into the bin to its left. Because of `skip(1)`, `id` is the
        // border index minus one, i.e. the index of that left bin.
        for (id, border) in hist.borders().iter().skip(1).enumerate() {
            let mut m_epsilon = *border;
            // Subtract growing multiples of epsilon until the result is really
            // smaller than the border. Past 100 the multiplier is squared for
            // the current iteration only, to speed up the search for large borders.
            for mut i in 1.. {
                if i > 100 {
                    i = i * i;
                }
                m_epsilon = T::epsilon().mul_add(T::from_isize(-i).unwrap(), *border);
                if m_epsilon < *border {
                    break;
                }
            }
            assert!(hist.is_inside(m_epsilon));
            assert_eq!(hist.get_bin_index(m_epsilon).unwrap(), id);
        }

        // `new` must reject NaN and infinite interval bounds.
        assert_eq!(
            HistErrors::InvalidVal,
            HistogramFloat::<T>::new(T::nan(), right, bin_count).unwrap_err()
        );
        assert_eq!(
            HistErrors::InvalidVal,
            HistogramFloat::<T>::new(left, T::nan(), bin_count).unwrap_err()
        );
        assert_eq!(
            HistErrors::InvalidVal,
            HistogramFloat::<T>::new(left, T::infinity(), bin_count).unwrap_err()
        );
        assert_eq!(
            HistErrors::InvalidVal,
            HistogramFloat::<T>::new(T::neg_infinity(), right, bin_count).unwrap_err()
        );
    }

    /// Runs `hist_test_float` on a few fixed intervals and on ten random
    /// `f64` intervals, each with a randomly drawn bin count.
    #[test]
    fn hist_float() {
        // Fixed seed keeps the test deterministic.
        let mut rng = Pcg64Mcg::new(0xcafef00dd15ea5e5);
        let dist = Uniform::new(1usize, 111).unwrap();
        // Bin counts come from a second generator that is seeded from `rng`.
        let mut iter = dist.sample_iter(Pcg64Mcg::from_rng(&mut rng));
        hist_test_float(20.0, 31.0, iter.next().unwrap());
        hist_test_float(-23.0f32, 31.1232f32, iter.next().unwrap());
        hist_test_float(-13.0f32, 31.4657f32, iter.next().unwrap());
        hist_test_float(1.0f64, 3f64, iter.next().unwrap());

        let dist2 = Uniform::new(0.0, 76257f64).unwrap();
        for _ in 0..10 {
            // Redraw until both interval bounds are finite.
            let (left, right) = loop {
                let left = dist2.sample(&mut rng);
                let right = left + dist2.sample(&mut rng);
                if left.is_finite() && right.is_finite() {
                    break (left, right);
                }
            };
            hist_test_float(left, right, iter.next().unwrap());
        }
    }
}