1use {
2 crate::{traits::*, *},
3 num_traits::{identities::*, ops::wrapping::*, Bounded},
4 rand::Rng,
5 std::{io::Write, marker::PhantomData, num::*},
6};
7
8#[cfg(feature = "serde_support")]
9use serde::{Deserialize, Serialize};
10
/// Wang-Landau sampler with a 1/t refinement phase.
///
/// In `WangLandauMode::RefineOriginal` the modification factor `log_f` is halved
/// whenever no histogram bin is empty (checked every `check_refine_every` steps).
/// Once it would drop below the 1/t reference value it switches to
/// `WangLandauMode::Refine1T` and follows that reference instead (see `check_refine`).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct WangLandau1T<Hist, Rng, Ensemble, S, Res, Energy> {
    /// The ensemble / Markov chain that is being sampled.
    pub(crate) ensemble: Ensemble,
    /// Random number generator used for the Metropolis acceptance draw.
    pub(crate) rng: Rng,
    /// Zero-sized marker for the `Res` type parameter (the step result type
    /// appearing in the `MarkovChain<S, Res>` bound of the stepping methods).
    pub(crate) marker_res: PhantomData<Res>,
    /// Reused buffer holding the steps of the most recent proposal, so a
    /// rejected proposal can be undone.
    pub(crate) steps: Vec<S>,
    /// Current refinement mode.
    mode: WangLandauMode,
    /// Current estimate of the (natural) log density of states, one entry per histogram bin.
    pub(crate) log_density: Vec<f64>,
    /// Current log modification factor, added to the visited bin's log density every step.
    pub(crate) log_f: f64,
    /// Target value for `log_f`; presumably the simulation counts as finished
    /// once `log_f` drops below it (see `is_finished` of the `WangLandau` trait).
    pub(crate) log_f_threshold: f64,
    /// Number of Markov steps performed per Wang-Landau step.
    pub(crate) step_size: usize,
    /// Number of Wang-Landau steps performed so far.
    step_count: usize,
    /// Accepted Wang-Landau steps since creation / last `set_initial_probability_guess`.
    accepted_steps_total: usize,
    /// Rejected Wang-Landau steps since creation / last `set_initial_probability_guess`.
    recected_steps_total: usize,
    /// Accepted steps since the last refinement of `log_f`.
    accepted_steps_current: usize,
    /// Rejected steps since the last refinement of `log_f`.
    recected_steps_current: usize,
    /// Histogram bin of the current state; `usize::MAX` until initialized.
    pub(crate) old_bin: usize,
    /// Visit histogram; reset whenever `log_f` is refined in `RefineOriginal` mode.
    pub(crate) hist: Hist,
    /// Energy of the current state; `None` until an `init*` method found a valid state.
    pub(crate) old_energy: Option<Energy>,
    /// In `RefineOriginal` mode the histogram is only inspected every this many steps.
    check_refine_every: usize,
}
42
43impl<Hist, R, E, S, Res, Energy> GlueAble<Hist> for WangLandau1T<Hist, R, E, S, Res, Energy>
44where
45 Hist: Clone,
46{
47 fn push_glue_entry_ignoring(&self, job: &mut GlueJob<Hist>, ignore_idx: &[usize]) {
48 if !ignore_idx.contains(&0) {
49 let sim_progress = SimProgress::LogF(self.log_f);
50 let rejected = self.total_steps_rejected() as u64;
51 let accepted = self.total_steps_accepted() as u64;
52
53 let stats = IntervalSimStats {
54 sim_progress,
55 interval_sim_type: SimulationType::WangLandau1T,
56 rejected_steps: rejected,
57 accepted_steps: accepted,
58 replica_exchanges: None,
59 proposed_replica_exchanges: None,
60 merged_over_walkers: NonZeroUsize::new(1).unwrap(),
61 };
62
63 let glue_entry = GlueEntry {
64 hist: self.hist.clone(),
65 prob: self.log_density.clone(),
66 log_base: LogBase::BaseE,
67 interval_stats: stats,
68 };
69 job.collection.push(glue_entry);
70 }
71 }
72}
73
74impl<Hist, Rng, Ensemble, S, Res, Energy> WangLandau1T<Hist, Rng, Ensemble, S, Res, Energy> {
75 pub fn into_inner(self) -> (Ensemble, Hist, Rng) {
77 (self.ensemble, self.hist, self.rng)
78 }
79}
80
81impl<Hist, R, E, S, Res, Energy> WangLandau for WangLandau1T<Hist, R, E, S, Res, Energy> {
82 #[inline(always)]
83 fn log_f(&self) -> f64 {
84 self.log_f
85 }
86
87 #[inline(always)]
88 fn log_f_threshold(&self) -> f64 {
89 self.log_f_threshold
90 }
91
92 fn set_log_f_threshold(&mut self, log_f_threshold: f64) -> Result<f64, WangLandauErrors> {
93 if !log_f_threshold.is_finite() || log_f_threshold.is_sign_negative() {
94 return Err(WangLandauErrors::InvalidLogFThreshold);
95 }
96 let old_threshold = self.log_f_threshold;
97 self.log_f_threshold = log_f_threshold;
98 Ok(old_threshold)
99 }
100
101 #[inline(always)]
102 fn log_density(&self) -> &Vec<f64> {
103 &self.log_density
104 }
105
106 fn write_log<W: Write>(&self, mut writer: W) -> Result<(), std::io::Error> {
107 writeln!(writer,
108 "#Acceptance prob_total: {}\n#Acceptance prob current: {}\n#total_steps: {}\n#log_f: {:e}\n#Current_mode {:?}",
109 self.fraction_accepted_total(),
110 self.fraction_accepted_current(),
111 self.step_counter(),
112 self.log_f(),
113 self.mode
114 )?;
115 writeln!(
116 writer,
117 "#total_steps_accepted: {}\n#total_steps_rejected: {}\n#current_accepted_steps: {}\n#current_rejected_steps: {}",
118 self.accepted_steps_total,
119 self.recected_steps_total,
120 self.accepted_steps_current,
121 self.recected_steps_current
122 )
123 }
124
125 #[inline(always)]
126 fn mode(&self) -> WangLandauMode {
127 self.mode
128 }
129
130 #[inline(always)]
131 fn step_counter(&self) -> usize {
132 self.step_count
133 }
134
135 #[inline(always)]
136 fn total_steps_rejected(&self) -> usize {
137 self.recected_steps_total
138 }
139
140 #[inline(always)]
141 fn total_steps_accepted(&self) -> usize {
142 self.accepted_steps_total
143 }
144}
145
146impl<Hist, R, E, S, Res, Energy> WangLandau1T<Hist, R, E, S, Res, Energy>
147where
148 Hist: Histogram + HistogramVal<Energy>,
149{
150 pub fn is_initialized(&self) -> bool {
154 match &self.old_energy {
155 None => false,
156 Some(e) => self.hist.is_inside(e),
157 }
158 }
159}
160
/// Reasons why `WangLandau1T::set_initial_probability_guess` rejected its input.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers and tests can compare and match
/// on the returned error directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SetInitialError {
    /// The guess does not have one entry per histogram bin
    /// (its length differs from the current log density).
    DimensionError,
    /// The guess contains a NaN or infinite value.
    NonFiniteEncountered,
    /// The requested `log_f` was rejected; valid values lie in `(0, 10]`.
    InvalidLogF,
}
172
173impl<Hist, R, E, S, Res, Energy> WangLandauEnsemble<E>
174 for WangLandau1T<Hist, R, E, S, Res, Energy>
175{
176 #[inline(always)]
177 fn ensemble(&self) -> &E {
178 &self.ensemble
179 }
180
181 #[inline(always)]
182 unsafe fn ensemble_mut(&mut self) -> &mut E {
183 &mut self.ensemble
184 }
185}
186
187impl<Hist, R, E, S, Res, Energy> WangLandauEnergy<Energy>
188 for WangLandau1T<Hist, R, E, S, Res, Energy>
189{
190 #[inline(always)]
191 fn energy(&self) -> Option<&Energy> {
192 self.old_energy.as_ref()
193 }
194}
195
196impl<Hist, R, E, S, Res, Energy> WangLandauHist<Hist> for WangLandau1T<Hist, R, E, S, Res, Energy> {
197 #[inline(always)]
198 fn hist(&self) -> &Hist {
199 &self.hist
200 }
201}
202
203impl<Hist, R, E, S, Res, Energy> WangLandau1T<Hist, R, E, S, Res, Energy> {
204 fn fraction_accepted_total(&self) -> f64 {
207 let sum = self.accepted_steps_total + self.recected_steps_total;
208 self.accepted_steps_total as f64 / sum as f64
209 }
210
211 fn fraction_accepted_current(&self) -> f64 {
215 let total = self.accepted_steps_current + self.recected_steps_current;
216 if total == 0 {
217 f64::NAN
218 } else {
219 self.accepted_steps_current as f64 / total as f64
220 }
221 }
222
223 pub fn set_initial_probability_guess(
233 mut self,
234 new_guess: Vec<f64>,
235 new_log_f: f64,
236 ) -> Result<Self, SetInitialError>
237 where
238 Hist: Histogram,
239 {
240 if 0.0 >= new_log_f || new_log_f > 10.0 {
241 Err(SetInitialError::InvalidLogF)
242 } else if new_guess.len() != self.log_density.len() {
243 Err(SetInitialError::DimensionError)
244 } else if new_guess.iter().any(|val| !val.is_finite()) {
245 Err(SetInitialError::NonFiniteEncountered)
246 } else {
247 self.log_density = new_guess;
248 self.log_f = new_log_f;
249 self.step_count = 0;
250 self.accepted_steps_current = 0;
251 self.accepted_steps_total = 0;
252 self.recected_steps_current = 0;
253 self.recected_steps_total = 0;
254 self.mode = WangLandauMode::RefineOriginal;
255 self.hist.reset();
256 self.old_energy = None;
257 self.old_bin = usize::MAX;
258 Ok(self)
259 }
260 }
261}
262
263impl<Hist, R, E, S, Res, Energy> WangLandau1T<Hist, R, E, S, Res, Energy>
264where
265 R: Rng,
266 E: MarkovChain<S, Res>,
267 Energy: Clone,
268 Hist: Histogram + HistogramVal<Energy>,
269{
270 pub fn new(
287 log_f_threshold: f64,
288 ensemble: E,
289 rng: R,
290 step_size: usize,
291 histogram: Hist,
292 check_refine_every: usize,
293 ) -> Result<Self, WangLandauErrors> {
294 if !log_f_threshold.is_finite() || log_f_threshold.is_sign_negative() {
295 return Err(WangLandauErrors::InvalidLogFThreshold);
296 } else if check_refine_every == 0 {
297 return Err(WangLandauErrors::CheckRefineEvery0);
298 }
299 let log_density = vec![0.0; histogram.bin_count()];
300 let steps = Vec::with_capacity(step_size);
301
302 Ok(Self {
303 ensemble,
304 step_count: 0,
305 step_size,
306 hist: histogram,
307 rng,
308 marker_res: PhantomData::<Res>,
309 log_f: 1.0,
310 log_density,
311 log_f_threshold,
312 mode: WangLandauMode::RefineOriginal,
313 recected_steps_current: 0,
314 recected_steps_total: 0,
315 accepted_steps_current: 0,
316 accepted_steps_total: 0,
317 old_bin: usize::MAX,
318 old_energy: None,
319 check_refine_every,
320 steps,
321 })
322 }
323
324 fn init<F>(&mut self, energy_fn: F, step_limit: Option<u64>) -> Result<(), WangLandauErrors>
325 where
326 F: Fn(&mut E) -> Option<Energy>,
327 {
328 self.old_energy = energy_fn(&mut self.ensemble);
329 if self.old_energy.is_some() {
330 return Ok(());
331 }
332
333 match step_limit {
334 None => loop {
335 self.ensemble.m_steps_quiet(self.step_size);
336 self.old_energy = energy_fn(&mut self.ensemble);
337
338 if self.old_energy.is_some() {
339 self.count_accepted();
340 return Ok(());
341 }
342 self.count_rejected();
343 },
344 Some(limit) => {
345 for _ in 0..limit {
346 self.ensemble.m_steps_quiet(self.step_size);
347 self.old_energy = energy_fn(&mut self.ensemble);
348
349 if self.old_energy.is_some() {
350 self.count_accepted();
351 return Ok(());
352 }
353 self.count_rejected();
354 }
355 Err(WangLandauErrors::InitFailed)
356 }
357 }
358 }
359
360 fn greedy_helper<F, H, J>(&mut self, old_distance: &mut J, energy_fn: F, distance_fn: H)
361 where
362 F: Fn(&mut E) -> Option<Energy> + Copy,
363 H: Fn(&Hist, &Energy) -> J,
364 J: PartialOrd,
365 {
366 self.ensemble.m_steps(self.step_size, &mut self.steps);
367
368 if let Some(energy) = energy_fn(&mut self.ensemble) {
369 let distance = distance_fn(&self.hist, &energy);
370 if distance <= *old_distance {
371 self.old_energy = Some(energy);
372 *old_distance = distance;
373 self.count_accepted();
374
375 return;
376 }
377 }
378
379 self.count_rejected();
380 self.ensemble.undo_steps_quiet(&self.steps);
381 }
382
383 pub fn init_greedy_heuristic<F>(
398 &mut self,
399 energy_fn: F,
400 step_limit: Option<u64>,
401 ) -> Result<(), WangLandauErrors>
402 where
403 F: Fn(&mut E) -> Option<Energy>,
404 {
405 self.init(&energy_fn, step_limit)?;
406 let mut old_distance = self.hist.distance(self.old_energy_ref());
407 let mut step_count = 0;
408 while old_distance != 0.0 {
409 self.greedy_helper(&mut old_distance, &energy_fn, |hist, energy| {
410 hist.distance(energy)
411 });
412 if let Some(limit) = step_limit {
413 if limit == step_count {
414 return Err(WangLandauErrors::InitFailed);
415 }
416 step_count += 1;
417 }
418 }
419 self.end_init();
420 Ok(())
421 }
422
423 pub fn init_interval_heuristik<F>(
439 &mut self,
440 overlap: NonZeroUsize,
441 energy_fn: F,
442 step_limit: Option<u64>,
443 ) -> Result<(), WangLandauErrors>
444 where
445 F: Fn(&mut E) -> Option<Energy>,
446 Hist: HistogramIntervalDistance<Energy>,
447 {
448 self.init(&energy_fn, step_limit)?;
449 let mut old_dist = self
450 .hist
451 .interval_distance_overlap(self.old_energy_ref(), overlap);
452
453 let dist = |h: &Hist, val: &Energy| h.interval_distance_overlap(val, overlap);
454 let mut step_count = 0;
455 while old_dist != 0 {
456 self.greedy_helper(&mut old_dist, &energy_fn, dist);
457 if let Some(limit) = step_limit {
458 if limit == step_count {
459 return Err(WangLandauErrors::InitFailed);
460 }
461 step_count += 1;
462 }
463 }
464 self.end_init();
465 Ok(())
466 }
467
468 pub fn init_mixed_heuristik<F, U>(
486 &mut self,
487 overlap: NonZeroUsize,
488 mid: U,
489 energy_fn: F,
490 step_limit: Option<u64>,
491 ) -> Result<(), WangLandauErrors>
492 where
493 F: Fn(&mut E) -> Option<Energy>,
494 Hist: HistogramIntervalDistance<Energy>,
495 U: One + Bounded + WrappingAdd + Eq + PartialOrd,
496 {
497 self.init(&energy_fn, step_limit)?;
498 if self.hist.is_inside(self.old_energy_ref()) {
499 self.end_init();
500 return Ok(());
501 }
502
503 let mut old_dist = f64::INFINITY;
504 let mut old_dist_interval = usize::MAX;
505 let mut counter: U = U::min_value();
506 let min_val = U::min_value();
507 let one = U::one();
508 let dist_interval = |h: &Hist, val: &Energy| h.interval_distance_overlap(val, overlap);
509 let mut step_count = 0;
510 loop {
511 if counter == min_val {
512 let current_energy = self.old_energy_ref();
513 old_dist = self.hist.distance(current_energy);
514 } else if counter == mid {
515 let current_energy = self.old_energy_ref();
516 old_dist_interval = dist_interval(&self.hist, current_energy);
517 }
518 if counter < mid {
519 self.greedy_helper(&mut old_dist, &energy_fn, |hist, val| hist.distance(val));
520 if old_dist == 0.0 {
521 break;
522 }
523 } else {
524 self.greedy_helper(&mut old_dist_interval, &energy_fn, dist_interval);
525 if old_dist_interval == 0 {
526 break;
527 }
528 }
529 counter = counter.wrapping_add(&one);
530 if let Some(limit) = step_limit {
531 if limit == step_count {
532 return Err(WangLandauErrors::InitFailed);
533 }
534 step_count += 1;
535 }
536 }
537 self.end_init();
538 Ok(())
539 }
540
541 fn end_init(&mut self) {
542 self.old_bin = self
543 .hist
544 .get_bin_index(self.old_energy_ref())
545 .expect("Error in heuristic - old bin invalid");
546 }
547
548 fn old_energy_clone(&self) -> Energy {
549 self.old_energy_ref().clone()
550 }
551
552 fn old_energy_ref(&self) -> &Energy {
553 self.old_energy.as_ref().unwrap()
554 }
555
556 fn count_accepted(&mut self) {
557 self.ensemble.steps_accepted(&self.steps);
558 self.accepted_steps_current += 1;
559 self.accepted_steps_total += 1;
560 }
561
562 fn count_rejected(&mut self) {
563 self.ensemble.steps_rejected(&self.steps);
564 self.recected_steps_current += 1;
565 self.recected_steps_total += 1;
566 }
567
568 fn check_refine(&mut self) {
569 match self.mode {
570 WangLandauMode::Refine1T => {
571 self.log_f = self.log_f_1_t();
572 }
573 WangLandauMode::RefineOriginal => {
574 if self.step_count % self.check_refine_every == 0 && !self.hist.any_bin_zero() {
575 self.recected_steps_current = 0;
576 self.accepted_steps_current = 0;
577 let ref_1_t = self.log_f_1_t();
578 self.log_f *= 0.5;
579 if self.log_f < ref_1_t {
580 self.log_f = ref_1_t;
581 self.mode = WangLandauMode::Refine1T;
582 }
583 self.hist.reset();
584 }
585 }
586 }
587 }
588
589 fn wl_step_helper(&mut self, energy: Option<Energy>) {
590 let current_energy = match energy {
591 Some(energy) => energy,
592 None => {
593 self.count_rejected();
594 self.hist.increment_index(self.old_bin).unwrap();
595 self.log_density[self.old_bin] += self.log_f;
596 self.ensemble.undo_steps_quiet(&self.steps);
597 return;
598 }
599 };
600
601 match self.hist.get_bin_index(¤t_energy) {
602 Ok(current_bin) => {
603 let accept_prob = self.metropolis_acception_prob(current_bin);
604
605 if self.rng.random::<f64>() > accept_prob {
606 self.count_rejected();
608 self.ensemble.undo_steps_quiet(&self.steps);
609 } else {
610 self.count_accepted();
612
613 self.old_energy = Some(current_energy);
614 self.old_bin = current_bin;
615 }
616 }
617 _ => {
618 self.count_rejected();
620 self.ensemble.undo_steps_quiet(&self.steps);
621 }
622 };
623
624 self.hist.increment_index(self.old_bin).unwrap();
625 self.log_density[self.old_bin] += self.log_f;
626 }
627
628 pub fn wang_landau_step<F>(&mut self, energy_fn: F)
640 where
641 F: Fn(&E) -> Option<Energy>,
642 {
643 unsafe { self.wang_landau_step_unsafe(|e| energy_fn(e)) }
644 }
645
646 pub unsafe fn wang_landau_step_unsafe<F>(&mut self, mut energy_fn: F)
662 where
663 F: FnMut(&mut E) -> Option<Energy>,
664 {
665 debug_assert!(
666 self.old_energy.is_some(),
667 "Error - self.old_energy invalid - Did you forget to call one of the `self.init*` members for initialization?"
668 );
669
670 self.step_count += 1;
671
672 self.ensemble.m_steps(self.step_size, &mut self.steps);
673
674 self.check_refine();
675 let current_energy = energy_fn(&mut self.ensemble);
676
677 self.wl_step_helper(current_energy);
678 }
679
680 pub fn wang_landau_step_acc<F>(&mut self, energy_fn: F)
689 where
690 F: FnMut(&E, &S, &mut Energy),
691 {
692 debug_assert!(
693 self.old_energy.is_some(),
694 "Error - self.old_energy invalid - Did you forget to call one of the `self.init*` members for initialization?"
695 );
696
697 self.step_count += 1;
698
699 let mut new_energy = self.old_energy_clone();
700
701 self.ensemble
702 .m_steps_acc(self.step_size, &mut self.steps, &mut new_energy, energy_fn);
703
704 self.check_refine();
705
706 self.wl_step_helper(Some(new_energy));
707 }
708
709 pub fn wang_landau_convergence<F>(&mut self, energy_fn: F)
713 where
714 F: Fn(&E) -> Option<Energy>,
715 {
716 while !self.is_finished() {
717 self.wang_landau_step(&energy_fn);
718 }
719 }
720
721 pub fn wang_landau_convergence_acc<F>(&mut self, mut energy_fn: F)
725 where
726 F: FnMut(&E, &S, &mut Energy),
727 {
728 while !self.is_finished() {
729 self.wang_landau_step_acc(&mut energy_fn);
730 }
731 }
732
733 pub unsafe fn wang_landau_convergence_unsafe<F>(&mut self, mut energy_fn: F)
742 where
743 F: FnMut(&mut E) -> Option<Energy>,
744 {
745 while !self.is_finished() {
746 self.wang_landau_step_unsafe(&mut energy_fn);
747 }
748 }
749
750 pub fn wang_landau_while<F, W>(&mut self, energy_fn: F, mut condition: W)
755 where
756 F: Fn(&E) -> Option<Energy>,
757 W: FnMut(&Self) -> bool,
758 {
759 while !self.is_finished() && condition(self) {
760 self.wang_landau_step(&energy_fn);
761 }
762 }
763
764 pub fn wang_landau_while_acc<F, W>(&mut self, mut energy_fn: F, mut condition: W)
769 where
770 F: FnMut(&E, &S, &mut Energy),
771 W: FnMut(&Self) -> bool,
772 {
773 while !self.is_finished() && condition(self) {
774 self.wang_landau_step_acc(&mut energy_fn);
775 }
776 }
777
778 pub unsafe fn wang_landau_while_unsafe<F, W>(&mut self, mut energy_fn: F, mut condition: W)
788 where
789 F: FnMut(&mut E) -> Option<Energy>,
790 W: FnMut(&Self) -> bool,
791 {
792 while !self.is_finished() && condition(self) {
793 self.wang_landau_step_unsafe(&mut energy_fn);
794 }
795 }
796
797 #[inline(always)]
799 fn metropolis_acception_prob(&self, new_bin: usize) -> f64 {
800 (self.log_density[self.old_bin] - self.log_density[new_bin]).exp()
801 }
802}
803
#[cfg(test)]
mod tests {
    use super::*;
    use crate::examples::coin_flips::*;
    use rand::SeedableRng;
    use rand_pcg::Pcg64Mcg;

    /// Both stepping flavours must produce bit-identical log densities when
    /// started from the same (cloned) initialized state:
    /// * re-evaluating the energy after every proposal, and
    /// * updating the energy incrementally per step (`_acc`).
    #[test]
    #[cfg_attr(miri, ignore)]
    fn wl_simulations_equal() {
        let mut rng = Pcg64Mcg::seed_from_u64(2239790);
        let ensemble = CoinFlipSequence::new(100, Pcg64Mcg::from_rng(&mut rng));
        let histogram = HistogramFast::new_inclusive(0, 100).unwrap();
        let mut wl = WangLandau1T::new(0.0075, ensemble, rng, 1, histogram, 30).unwrap();

        wl.init_mixed_heuristik(
            NonZeroUsize::new(3).unwrap(),
            6400i16,
            |e| Some(e.head_count()),
            None,
        )
        .unwrap();

        let mut wl_backup = wl.clone();
        let start_wl = std::time::Instant::now();
        wl.wang_landau_convergence(|e| Some(e.head_count()));
        let dur_1 = start_wl.elapsed();
        let start_wl_acc = std::time::Instant::now();
        wl_backup.wang_landau_convergence_acc(CoinFlipSequence::update_head_count);
        let dur_2 = start_wl_acc.elapsed();

        // `as_nanos` returns u128: subtracting directly underflows (and panics in
        // debug test builds) whenever the accumulating run happens to be slower.
        // Realistic durations fit comfortably into i128.
        let nanos_1 = dur_1.as_nanos() as i128;
        let nanos_2 = dur_2.as_nanos() as i128;
        println!(
            "WL: {}, WL_ACC: {}, difference: {}",
            nanos_1,
            nanos_2,
            nanos_1 - nanos_2
        );

        // `zip` stops at the shorter slice, so compare lengths explicitly.
        assert_eq!(wl.log_density().len(), wl_backup.log_density().len());
        for (&log_value, &log_value_acc) in
            wl.log_density().iter().zip(wl_backup.log_density().iter())
        {
            assert_eq!(log_value, log_value_acc);
        }
    }
}