sampling/rewl/
walker.rs

1use{
2    crate::{*, wang_landau::WangLandauMode},
3    rand::Rng,
4    std::{marker::PhantomData, mem::*, num::*, sync::*}
5};
6
7#[cfg(feature = "sweep_time_optimization")]
8use std::time::*;
9
10#[cfg(feature = "serde_support")]
11use serde::{Serialize, Deserialize};
12
/// Errors that can occur while creating a Rewl struct
/// (**R**eplica **e**xchange **W**ang **L**andau)
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub enum RewlCreationErrors
{
    /// Every histogram needs at least two bins - anything smaller is meaningless
    HistsizeError,

    /// An empty slice was passed where at least one element is required
    EmptySlice,

    /// The number of histograms differs from the number of ensembles - both have to be equal
    LenMissmatch,
}
27
28
/// Converts a non-normalized natural-log density into a normalized log10 density.
///
/// * All entries are shifted by the maximum entry before exponentiation, so that
///   the exponentiated values stay at most 1 (for a finite maximum)
/// * Entries that are not finite after the shift are ignored when computing the
///   normalization constant (e.g. `-inf` for bins that were never visited)
/// * The result satisfies `sum_i 10^{result_i} ≈ 1` within numerical precision
/// * An empty input yields an empty output; degenerate inputs (every entry `-inf`,
///   or any entry `+inf`) yield `NaN` entries
pub(crate) fn log_density_to_log10_density(log_density: &[f64]) -> Vec<f64>
{
    let max = log_density
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, f64::max);

    let shifted: Vec<f64> = log_density
        .iter()
        .map(|&value| value - max)
        .collect();

    // Normalization constant in linear space; summed in input order starting at 0.0
    let mut norm = 0.0;
    for &value in &shifted {
        if value.is_finite() {
            norm += value.exp();
        }
    }
    let offset = -norm.log10();

    // ln(x) * log10(e) == log10(x); adding `offset` normalizes the result
    shifted
        .into_iter()
        .map(|value| value.mul_add(std::f64::consts::LOG10_E, offset))
        .collect()
}
58
59
60/// # Walker for Replica exchange Wang Landau
61/// * used by [`Rewl`](`crate::rewl::Rewl`)
62/// * performes the random walk in its respective domain 
63#[derive(Debug, Clone)]
64#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
65pub struct RewlWalker<R, Hist, Energy, S, Res>
66{
67    pub(crate) id: usize,
68    pub(crate) sweep_size: NonZeroUsize,
69    pub(crate) rng: R,
70    pub(crate) hist: Hist,
71    pub(crate) log_density: Vec<f64>,
72    log_f: f64,
73    pub(crate) step_count: u64,
74    pub(crate) proposed_re: u64,
75    pub(crate) re: usize,
76    pub(crate) rejected_markov_steps: u64,
77    mode: WangLandauMode,
78    pub(crate) old_energy: Energy,
79    pub(crate) bin: usize,
80    pub(crate) markov_steps: Vec<S>,
81    marker_res: PhantomData<Res>,
82    pub(crate) step_size: usize,
83    #[cfg(feature = "sweep_time_optimization")]
84    pub(crate) duration: Duration,
85    #[cfg(feature = "sweep_stats")]
86    pub(crate) sweep_stats: SweepStats,
87}
88
89impl<R, Hist, Energy, S, Res> RewlWalker<R, Hist, Energy, S, Res>{
90    /// # Returns id of walker
91    /// * important for mapping the ensemble to the walker
92    pub fn id(&self) -> usize
93    {
94        self.id
95    }
96
97    /// # Which mode is this walker currently in?
98    /// see [WangLandauMode]
99    pub fn wang_landau_mode(&self) -> WangLandauMode
100    {
101        self.mode
102    }
103
104    /// # Returns duration of last sweep that was performed
105    #[cfg(feature = "sweep_time_optimization")]
106    pub fn duration(&self) -> Duration
107    {
108        self.duration
109    }
110
111    /// # Returns average sweep duration
112    /// * Averages over all sweep durations that are stored in the buffer
113    /// * There are up to 512 durations in the buffer
114    #[cfg(feature = "sweep_stats")]
115    pub fn average_sweep_duration(&self) -> Duration
116    {
117        self.sweep_stats.averag_duration()
118    }
119
120    /// # Returns hightest and lowest 10 percent
121    /// * returns Duration, where only 10 percent of the durations in 
122    /// the buffer took longer
123    /// * returns Duration, where only 10 percent of the durations in the
124    /// buffer finished quicker
125    #[cfg(feature = "sweep_stats")]
126    pub fn high_low_10_percent(&self) -> (Duration, Duration)
127    {
128        self.sweep_stats.percent_high_low()
129    }
130
131    /// Durations stored in the Buffer
132    #[cfg(feature = "sweep_stats")]
133    pub fn last_durations(&self) -> &[Duration]
134    {
135        self.sweep_stats.buf()
136    }
137
138    /// Returns reference of current energy
139    pub fn energy(&self) -> &Energy
140    {
141        &self.old_energy
142    }
143
144    /// Returns current energy
145    pub fn energy_copy(&self) -> Energy
146    where Energy: Copy
147    {
148        self.old_energy
149    }
150
151    /// Returns current energy
152    pub fn energy_clone(&self) -> Energy
153    where Energy: Clone
154    {
155        self.old_energy.clone()
156    }
157
158    /// # Reference to internal histogram
159    pub fn hist(&self) -> &Hist
160    {
161        &self.hist
162    }
163
164    /// # Current (logarithm of) factor f
165    /// * See the paper for more info
166    pub fn log_f(&self) -> f64
167    {
168        self.log_f
169    }
170
171    /// # how many steps per sweep
172    pub fn sweep_size(&self) -> NonZeroUsize
173    {
174        self.sweep_size
175    }
176
177    /// # change how many steps per sweep are performed
178    pub fn sweep_size_change(&mut self, sweep_size: NonZeroUsize)
179    {
180        self.sweep_size = sweep_size;
181    }
182
183    /// # step size for markov steps
184    pub fn step_size(&self) -> usize 
185    {
186        self.step_size
187    }
188
189    /// # Change step sitze for markov steps
190    pub fn step_size_change(&mut self, step_size: usize)
191    {
192        self.step_size = step_size;
193    }
194
195    /// # How many steps were performed until now?
196    #[inline(always)]
197    pub fn step_count(&self) -> u64
198    {
199        self.step_count
200    }
201
202    /// # How many successful replica exchanges were performed until now?
203    pub fn replica_exchanges(&self) -> usize
204    {
205        self.re
206    }
207
208    /// # How many replica exchanges were proposed until now?
209    pub fn proposed_replica_exchanges(&self) -> u64
210    {
211        self.proposed_re
212    }
213
214    /// fraction of how many replica exchanges were accepted and how many were proposed
215    pub fn replica_exchange_frac(&self) -> f64
216    {
217        self.re as f64 / self.proposed_re as f64
218    }
219
220    /// # How many markov steps were rejected until now
221    #[inline(always)]
222    pub fn rejected_markov_steps(&self) -> u64
223    {
224        self.rejected_markov_steps
225    }
226
227    /// # rate/fraction of acceptance
228    pub fn acceptance_rate_markov(&self) -> f64
229    {
230        let rej = self.rejected_markov_steps() as f64 / self.step_count() as f64;
231        1.0 - rej
232    }
233
234    /// Current non normalized estimate of the natural logarithm of the probability density function
235    pub fn log_density(&self) -> &[f64]
236    {
237        &self.log_density
238    }
239
240    fn count_rejected(&mut self)
241    {
242        self.rejected_markov_steps += 1;
243    }
244}
245
246impl <R, Hist, Energy, S, Res> RewlWalker<R, Hist, Energy, S, Res> 
247    where Hist: Histogram + HistogramVal<Energy>,
248{
249    fn log_f_1_t(&self) -> f64
250    { 
251        self.hist.bin_count() as f64 / self.step_count as f64
252    }
253}
254
255impl<R, Hist, Energy, S, Res> RewlWalker<R, Hist, Energy, S, Res> 
256where R: Rng + Send + Sync,
257    Self: Send + Sync,
258    Hist: Histogram + HistogramVal<Energy>,
259    Energy: Send + Sync,
260    S: Send + Sync,
261    Res: Send + Sync
262{
263    pub(crate) fn new
264    (
265        id: usize,
266        rng: R,
267        hist: Hist,
268        sweep_size: NonZeroUsize,
269        step_size: usize,
270        old_energy: Energy,
271    ) -> RewlWalker<R, Hist, Energy, S, Res>
272    {
273        let log_density = vec![0.0; hist.bin_count()];
274        let bin = hist.get_bin_index(&old_energy).unwrap();
275        let markov_steps = Vec::with_capacity(step_size);
276        RewlWalker{
277            id,
278            rng,
279            hist,
280            log_density,
281            sweep_size,
282            log_f: 1.0,
283            step_count: 0,
284            re: 0,
285            proposed_re: 0,
286            mode: WangLandauMode::RefineOriginal,
287            old_energy,
288            bin,
289            marker_res: PhantomData::<Res>,
290            markov_steps,
291            step_size,
292            rejected_markov_steps: 0,
293            #[cfg(feature = "sweep_time_optimization")]
294            duration: Duration::from_millis(0),
295            #[cfg(feature = "sweep_stats")]
296            sweep_stats: SweepStats::new(),
297        }
298    }
299
300    
301
302    /// # Current estimate of log10 of probability density
303    /// * normalized (sum over non log values is 1 (within numerical precision))
304    pub fn log10_density(&self) -> Vec<f64>
305    {
306        log_density_to_log10_density(self.log_density())
307    }
308
309    pub(crate) fn all_bins_reached(&self) -> bool
310    {
311        !self.hist.any_bin_zero()
312    }
313
314    pub(crate) fn refine_f_reset_hist(&mut self)
315    {
316        // Check if log_f should be halfed or mode should be changed
317        if self.mode.is_mode_original() && !self.hist.any_bin_zero() {
318            let ref_1_t = self.log_f_1_t();
319            self.log_f *= 0.5;
320
321            if self.log_f < ref_1_t {
322                self.log_f = ref_1_t;
323                self.mode = WangLandauMode::Refine1T;
324            }
325        }
326        self.hist.reset();
327    }
328
329    pub(crate) fn check_energy_fn<F, Ensemble>(
330        &self,
331        ensemble_vec: &[RwLock<Ensemble>],
332        energy_fn: F
333    )   -> bool
334    where Energy: PartialEq,F: Fn(&mut Ensemble) -> Option<Energy>,
335        
336    {
337        let mut e = ensemble_vec[self.id]
338            .write()
339            .expect("Fatal Error encountered; ERRORCODE 0x5 - this should be \
340                impossible to reach. If you are using the latest version of the \
341                'sampling' library, please contact the library author via github by opening an \
342                issue! https://github.com/Pardoxa/sampling/issues");
343        
344        let energy = match energy_fn(&mut e){
345            Some(energy) => energy,
346            None => {
347                return false;
348            }
349        };
350        energy == self.old_energy
351    }
352
353    pub(crate) fn wang_landau_sweep<Ensemble, F>
354    (
355        &mut self,
356        ensemble_vec: &[RwLock<Ensemble>],
357        energy_fn: F
358    )
359    where F: Fn(&mut Ensemble) -> Option<Energy>,
360        Ensemble: MarkovChain<S, Res>
361    {
362        #[cfg(feature = "sweep_time_optimization")]
363        let start = Instant::now();
364
365        let mut e = ensemble_vec[self.id]
366            .write()
367            .expect("Fatal Error encountered; ERRORCODE 0x6 - this should be \
368                impossible to reach. If you are using the latest version of the \
369                'sampling' library, please contact the library author via github by opening an \
370                issue! https://github.com/Pardoxa/sampling/issues");
371        
372        for _ in 0..self.sweep_size.get()
373        {   
374            self.step_count = self.step_count.saturating_add(1);
375            e.m_steps(self.step_size, &mut self.markov_steps);
376
377
378            let energy = match energy_fn(&mut e){
379                Some(energy) => energy,
380                None => {
381                    self.count_rejected();
382                    e.undo_steps_quiet(&self.markov_steps);
383                    continue;
384                }
385            };
386
387            
388            if self.mode.is_mode_1_t() {
389                self.log_f = self.log_f_1_t();
390            }
391
392            match self.hist.get_bin_index(&energy) 
393            {
394                Ok(current_bin) => {
395                    // metropolis hastings
396                    let acception_prob = (self.log_density[self.bin] - self.log_density[current_bin])
397                        .exp();
398                    if self.rng.random::<f64>() > acception_prob 
399                    {
400                        e.steps_rejected(&self.markov_steps);
401                        self.count_rejected();
402                        e.undo_steps_quiet(&self.markov_steps);
403                    } else {
404                        self.old_energy = energy;
405                        self.bin = current_bin;
406                        e.steps_accepted(&self.markov_steps);
407                    }
408                },
409                _ => {
410                    e.steps_rejected(&self.markov_steps);
411                    self.count_rejected();
412                    e.undo_steps_quiet(&self.markov_steps);
413                }
414            }
415
416            self.hist.increment_index(self.bin)
417                .expect("Histogram index Error, ERRORCODE 0x7");
418            
419            self.log_density[self.bin] += self.log_f;
420
421        }
422        #[cfg(feature = "sweep_time_optimization")]
423            {
424                self.duration = start.elapsed();
425                #[cfg(feature = "sweep_stats")]
426                self.sweep_stats.push(self.duration);
427            }
428    }
429}
430
431
432pub(crate) fn merge_walker_prob<R, Hist, Energy, S, Res>(walker: &mut [RewlWalker<R, Hist, Energy, S, Res>])
433{
434    
435    if walker.len() < 2 {
436        return;
437    }
438    let averaged = get_merged_walker_prob(walker);
439    
440    walker.iter_mut()
441        .skip(1)
442        .for_each(
443            |w|
444            {
445                w.log_density
446                    .copy_from_slice(&averaged)
447            }
448        );
449    walker[0].log_density = averaged;
450}
451
452pub(crate) fn get_merged_walker_prob<R, Hist, Energy, S, Res>(walker: &[RewlWalker<R, Hist, Energy, S, Res>]) -> Vec<f64>
453{
454    let log_len = walker[0].log_density.len();
455    debug_assert!(
456        walker.iter()
457            .all(|w| w.log_density.len() == log_len)
458    );
459
460    let mut averaged_log_density = walker[0].log_density
461        .clone();
462
463    if walker.len() > 1 {
464    
465        walker[1..].iter()
466            .for_each(
467                |w|
468                {
469                    averaged_log_density.iter_mut()
470                        .zip(w.log_density.iter())
471                        .for_each(
472                            |(average, other)|
473                            {
474                                *average += other;
475                            }
476                        )
477                }
478            );
479    
480        let number_of_walkers = walker.len() as f64;
481        averaged_log_density.iter_mut()
482            .for_each(|average| *average /= number_of_walkers);
483    }
484
485    averaged_log_density
486}
487
488pub(crate) fn replica_exchange<R, Hist, Energy, S, Res>
489(
490    walker_a: &mut RewlWalker<R, Hist, Energy, S, Res>,
491    walker_b: &mut RewlWalker<R, Hist, Energy, S, Res>
492) where Hist: Histogram + HistogramVal<Energy>,
493    R: Rng
494{
495    walker_a.proposed_re += 1;
496    walker_b.proposed_re += 1;
497    // check if exchange is even possible
498    let new_bin_a = match walker_a.hist.get_bin_index(&walker_b.old_energy)
499    {
500        Ok(bin) => bin,
501        _ => return,
502    };
503
504    let new_bin_b = match walker_b.hist.get_bin_index(&walker_a.old_energy)
505    {
506        Ok(bin) => bin,
507        _ => return,
508    };
509
510    // see paper equation 1
511    let log_gi_x = walker_a.log_density[walker_a.bin];
512    let log_gi_y = walker_a.log_density[new_bin_a];
513
514    let log_gj_y = walker_b.log_density[walker_b.bin];
515    let log_gj_x = walker_b.log_density[new_bin_b];
516
517    let log_prob = log_gi_x + log_gj_y - log_gi_y - log_gj_x;
518
519    let prob = log_prob.exp();
520
521    // if exchange is accepted
522    if walker_b.rng.random::<f64>() < prob 
523    {
524        swap(&mut walker_b.id, &mut walker_a.id);
525        swap(&mut walker_b.old_energy, &mut walker_a.old_energy);
526        walker_b.bin = new_bin_b;
527        walker_a.bin = new_bin_a;
528        walker_b.re +=1;
529        walker_a.re +=1;
530
531    }
532    {
533        if walker_a.mode.is_mode_1_t() {
534            walker_a.log_f =  walker_a.log_f_1_t();
535        }
536    
537        if walker_b.mode.is_mode_1_t() {
538            walker_b.log_f =  walker_b.log_f_1_t();
539        }
540    
541        walker_a.hist.increment_index(walker_a.bin)
542                    .expect("Histogram index Error, ERRORCODE 0x8");
543        walker_a.log_density[walker_a.bin] += walker_a.log_f;
544    
545        walker_b.hist.increment_index(walker_b.bin)
546                    .expect("Histogram index Error, ERRORCODE 0x8");
547        walker_b.log_density[walker_b.bin] += walker_b.log_f;
548    }
549}