// sampling/glue/glue_job.rs

1use super::{
2    glue_helper::{ln_to_log10, log10_to_ln},
3    glue_writer::*,
4    LogBase,
5};
6use crate::histogram::*;
7use std::{borrow::Borrow, num::NonZeroUsize};
8
9#[cfg(feature = "serde_support")]
10use serde::{Deserialize, Serialize};
11/// Trait for objects that can contribute to a [GlueJob]
12pub trait GlueAble<H> {
13    /// Add `self` to the [GlueJob]
14    fn push_glue_entry(&self, job: &mut GlueJob<H>) {
15        self.push_glue_entry_ignoring(job, &[])
16    }
17
18    /// Add `self`to the [GlueJob], but ignore some indices
19    fn push_glue_entry_ignoring(&self, job: &mut GlueJob<H>, ignore_idx: &[usize]);
20}
21
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
/// Enum to track which kind of simulation an interval originated from
pub enum SimulationType {
    /// Simulation was a 1/t Wang-Landau simulation
    WangLandau1T = 0,
    /// Simulation was an adaptive 1/t Wang-Landau simulation
    WangLandau1TAdaptive = 1,
    /// Simulation was Entropic sampling
    Entropic = 2,
    /// Simulation was adaptive Entropic sampling
    EntropicAdaptive = 3,
    /// Simulation was Replica exchange Wang Landau (1/t)
    REWL = 4,
    /// Simulation was Replica exchange Entropic sampling
    REES = 5,
    /// Simulation type is unknown
    Unknown = 6,
}

impl SimulationType {
    /// # Name of simulation type as &str
    pub fn name(self) -> &'static str {
        // arms listed in declaration (discriminant) order
        match self {
            Self::WangLandau1T => "WangLandau1T",
            Self::WangLandau1TAdaptive => "WangLandau1TAdaptive",
            Self::Entropic => "Entropic",
            Self::EntropicAdaptive => "EntropicAdaptive",
            Self::REWL => "REWL",
            Self::REES => "REES",
            Self::Unknown => "Unknown",
        }
    }

    /// Inverse of the discriminant conversion: maps 0..=6 back to the variant.
    ///
    /// Callers only pass values obtained via `variant as usize`, hence
    /// anything outside that range is a programming error and panics.
    pub(crate) fn from_usize(num: usize) -> Self {
        const IN_DISCRIMINANT_ORDER: [SimulationType; 7] = [
            SimulationType::WangLandau1T,
            SimulationType::WangLandau1TAdaptive,
            SimulationType::Entropic,
            SimulationType::EntropicAdaptive,
            SimulationType::REWL,
            SimulationType::REES,
            SimulationType::Unknown,
        ];
        match IN_DISCRIMINANT_ORDER.get(num) {
            Some(&sim_type) => sim_type,
            None => unreachable!(),
        }
    }
}
69
/// Statistics accumulated over all intervals of a glue job.
/// Built via [AccumulatedIntervalStats::generate_stats] and written out
/// as '#'-prefixed comment lines via its `write` method.
pub(crate) struct AccumulatedIntervalStats {
    /// Largest log(f) reported by any interval that tracks log f
    /// (f64::NEG_INFINITY until at least one such interval was seen)
    worst_log_progress: f64,
    /// Largest number of missing steps reported by any interval that tracks missing steps
    worst_missing_steps_progress: u64,
    /// How many intervals reported their progress as log(f)
    log_progress_counter: u32,
    /// How many intervals reported their progress as missing steps
    missing_steps_progress_counter: u32,
    /// How many intervals reported unknown progress
    unknown_progress_counter: u32,
    /// Number of intervals per [SimulationType], indexed by the enum discriminant (0..=6)
    interval_sim_type_counter: [usize; 7],
    /// Sum of rejected steps over all intervals
    total_rejected_steps: u64,
    /// Sum of accepted steps over all intervals
    total_accepted_steps: u64,
    /// Sum of proposed replica exchanges; only meaningful if the corresponding flag below is true
    total_proposed_replica_exchanges: u64,
    /// Sum of performed replica exchanges; only meaningful if the corresponding flag below is true
    total_replica_exchanges: u64,
    /// True if at least one interval tracked performed replica exchanges
    potential_for_replica_exchanges: bool,
    /// True if at least one interval tracked proposed replica exchanges
    potential_for_proposed_replica_exchanges: bool,
}
84
85impl AccumulatedIntervalStats {
86    pub(crate) fn write<W: std::io::Write>(&self, mut writer: W) -> std::io::Result<()> {
87        let total_intervals: usize = self.interval_sim_type_counter.iter().sum();
88        writeln!(writer, "#Accumulated Stats of {total_intervals} Intervals")?;
89        if self.log_progress_counter > 0 {
90            writeln!(
91                writer,
92                "#Worst log progress: {} - out of {} intervals that tracked log progress",
93                self.worst_log_progress, self.log_progress_counter
94            )?;
95        }
96        if self.missing_steps_progress_counter > 0 {
97            writeln!(
98                writer,
99                "#Worst missing steps progress: {} - out of {} intervals that tracked missing steps progress",
100                self.worst_missing_steps_progress,
101                self.missing_steps_progress_counter
102            )?;
103        }
104        if self.unknown_progress_counter > 0 {
105            writeln!(
106                writer,
107                "# {} Intervals had unknown progress",
108                self.unknown_progress_counter
109            )?
110        }
111
112        for (index, &amount) in self.interval_sim_type_counter.iter().enumerate() {
113            if amount > 0 {
114                let sim_type = SimulationType::from_usize(index);
115                writeln!(
116                    writer,
117                    "#{} contributed {} intervals",
118                    sim_type.name(),
119                    amount
120                )?;
121            }
122        }
123
124        let a = self.total_accepted_steps;
125        let r = self.total_rejected_steps;
126        let total = a + r;
127        writeln!(
128            writer,
129            "#TOTAL: {a} accepted and {r} rejected steps, which makes a total of {total} steps"
130        )?;
131        let a_rate = a as f64 / total as f64;
132        writeln!(writer, "#TOTAL acceptance rate {a_rate}")?;
133        let r_rate = r as f64 / total as f64;
134        writeln!(writer, "#TOTAL rejection rate {r_rate}")?;
135
136        if self.potential_for_replica_exchanges {
137            writeln!(
138                writer,
139                "#TOTAL performed replica exchanges: {}",
140                self.total_replica_exchanges
141            )?;
142        }
143        if self.potential_for_proposed_replica_exchanges {
144            writeln!(
145                writer,
146                "#TOTAL proposed replica exchanges: {}",
147                self.total_proposed_replica_exchanges
148            )?;
149            if self.potential_for_replica_exchanges {
150                let rate = self.total_replica_exchanges as f64
151                    / self.total_proposed_replica_exchanges as f64;
152                writeln!(writer, "#rate of accepting replica exchanges: {rate}")?;
153            }
154        }
155        Ok(())
156    }
157
158    pub(crate) fn generate_stats(interval_stats: &[IntervalSimStats]) -> Self {
159        let mut acc = AccumulatedIntervalStats {
160            worst_log_progress: f64::NEG_INFINITY,
161            worst_missing_steps_progress: 0,
162            log_progress_counter: 0,
163            missing_steps_progress_counter: 0,
164            unknown_progress_counter: 0,
165            interval_sim_type_counter: [0; 7],
166            total_accepted_steps: 0,
167            total_rejected_steps: 0,
168            total_proposed_replica_exchanges: 0,
169            total_replica_exchanges: 0,
170            potential_for_proposed_replica_exchanges: false,
171            potential_for_replica_exchanges: false,
172        };
173
174        for stats in interval_stats.iter() {
175            acc.interval_sim_type_counter[stats.interval_sim_type as usize] += 1;
176            match stats.sim_progress {
177                SimProgress::LogF(log_f) => {
178                    acc.log_progress_counter += 1;
179                    acc.worst_log_progress = acc.worst_log_progress.max(log_f);
180                }
181                SimProgress::MissingSteps(missing) => {
182                    acc.missing_steps_progress_counter += 1;
183                    acc.worst_missing_steps_progress =
184                        acc.worst_missing_steps_progress.max(missing);
185                }
186                SimProgress::Unknown => {
187                    acc.unknown_progress_counter += 1;
188                }
189            }
190
191            acc.total_accepted_steps += stats.accepted_steps;
192            acc.total_rejected_steps += stats.rejected_steps;
193            if let Some(replica) = stats.replica_exchanges {
194                acc.potential_for_replica_exchanges = true;
195                acc.total_replica_exchanges += replica;
196            }
197            if let Some(proposed) = stats.proposed_replica_exchanges {
198                acc.potential_for_proposed_replica_exchanges = true;
199                acc.total_proposed_replica_exchanges += proposed;
200            }
201        }
202        acc
203    }
204}
205
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
/// Enum which contains information about the current progress of the simulation.
///
/// Used by [IntervalSimStats] to report how far an interval's simulation has progressed.
pub enum SimProgress {
    /// The logarithm of the factor f. Useful, since we often want to simulate until we hit a target value for log(f)
    LogF(f64),
    /// How many steps do we still need to perform?
    MissingSteps(u64),
    /// The simulation progress is unknown
    Unknown,
}
217
/// Statistics of one interval, used to gauge how well
/// the simulation works etc.
///
/// Can be written out in a human-readable, '#'-commented form via
/// [IntervalSimStats::write].
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct IntervalSimStats {
    /// The progress of the interval when its density was extracted
    pub sim_progress: SimProgress,
    /// Which type of simulation did the interval come from
    pub interval_sim_type: SimulationType,
    /// How many steps were rejected in total in the interval
    pub rejected_steps: u64,
    /// How many steps were accepted in total in the interval
    pub accepted_steps: u64,
    /// How many replica exchanges were performed?
    /// None for Simulations that don't do replica exchanges
    pub replica_exchanges: Option<u64>,
    /// How many replica exchanges were proposed?
    /// None for simulations that do not perform replica exchanges
    pub proposed_replica_exchanges: Option<u64>,
    /// The number of walkers used to generate this sim.
    /// In Replica exchange sims you can have more than one walker
    /// per interval, which is where this comes from
    pub merged_over_walkers: NonZeroUsize,
}
242
243impl IntervalSimStats {
244    /// # Write Stats to file
245    /// Use this function to output the simulation statistics of an interval to a file.
246    ///
247    /// Every line will be preceded by an '#' to mark it as comment
248    ///
249    /// # Contained information
250    /// * Simulation type
251    /// * Progress
252    /// * How many walkers were used for this interval?
253    /// * Rejection/Acceptance rate
254    /// * If applicable: Number of replica exchanges and acceptance rate of replica exchanges
255    pub fn write<W: std::io::Write>(&self, mut writer: W) -> std::io::Result<()> {
256        writeln!(
257            writer,
258            "#Simulated via: {:?}",
259            self.interval_sim_type.name()
260        )?;
261        writeln!(writer, "#progress {:?}", self.sim_progress)?;
262        if self.merged_over_walkers.get() == 1 {
263            writeln!(writer, "#created from a single walker")?;
264        } else {
265            writeln!(
266                writer,
267                "#created from merging {} walkers",
268                self.merged_over_walkers
269            )?;
270        }
271
272        let a = self.accepted_steps;
273        let r = self.rejected_steps;
274        let total = a + r;
275        writeln!(
276            writer,
277            "#had {a} accepted and {r} rejected steps, which makes a total of {total} steps"
278        )?;
279        let a_rate = a as f64 / total as f64;
280        writeln!(writer, "#acceptance rate {a_rate}")?;
281        let r_rate = r as f64 / total as f64;
282        writeln!(writer, "#rejection rate {r_rate}")?;
283
284        if let Some(replica) = self.replica_exchanges {
285            writeln!(writer, "#performed replica exchanges: {replica}")?;
286        }
287        if let Some(proposed) = self.proposed_replica_exchanges {
288            writeln!(writer, "#proposed replica exchanges: {proposed}")?;
289            if let Some(replica) = self.replica_exchanges {
290                let rate = replica as f64 / proposed as f64;
291                writeln!(writer, "#rate of accepting replica exchanges: {rate}")?;
292            }
293        }
294        Ok(())
295    }
296}
297
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
/// # Struct that is used to create a glue job
///
/// One entry corresponds to one interval that is to be glued.
pub struct GlueEntry<H> {
    /// The histogram
    pub hist: H,
    /// The probability density distribution
    pub prob: Vec<f64>,
    /// Information about which logarithm base was used to store the probability density distribution
    pub log_base: LogBase,
    /// Statistics of the interval this entry was created from
    pub interval_stats: IntervalSimStats,
}
311
/// Allows a [GlueEntry] to be treated as (a borrow of) its histogram,
/// e.g. in APIs that only care about the histogram part.
impl<H> Borrow<H> for GlueEntry<H> {
    fn borrow(&self) -> &H {
        &self.hist
    }
}
317
/// # Used to merge probability densities from WL, REWL, Entropic or REES simulations
/// * You can also mix those methods and still glue them
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct GlueJob<H> {
    /// Contains all the Intervals to glue
    pub collection: Vec<GlueEntry<H>>,
    /// Contains information about the number of roundtrips of the walkers used for this gluing job
    pub round_trips: Vec<usize>,
    /// The logarithm base that we want our final output to be in
    pub desired_logbase: LogBase,
}
330
331impl<H> GlueJob<H>
332where
333    H: Clone,
334{
335    /// Create a new glue job from something GlueAble See [GlueAble]
336    ///
337    /// You need to specify the desired Logarithm base of the final output
338    pub fn new<B>(to_glue: &B, desired_logbase: LogBase) -> Self
339    where
340        B: GlueAble<H>,
341    {
342        let mut job = Self {
343            collection: Vec::new(),
344            round_trips: Vec::new(),
345            desired_logbase,
346        };
347
348        to_glue.push_glue_entry(&mut job);
349        job
350    }
351
352    /// Create a glue job from a slice of [GlueAble] objects
353    ///
354    /// You need to specify the desired Logarithm base of the final output
355    pub fn new_from_slice<B>(to_glue: &[B], desired_logbase: LogBase) -> Self
356    where
357        B: GlueAble<H>,
358    {
359        Self::new_from_iter(to_glue.iter(), desired_logbase)
360    }
361
362    /// Create a glue job from an iterator of [GlueAble] objects
363    ///
364    /// You need to specify the desired Logarithm base of the final output
365    pub fn new_from_iter<'a, B, I>(to_glue: I, desired_logbase: LogBase) -> Self
366    where
367        B: GlueAble<H> + 'a,
368        I: Iterator<Item = &'a B>,
369    {
370        let mut job = Self {
371            collection: Vec::new(),
372            round_trips: Vec::new(),
373            desired_logbase,
374        };
375
376        job.add_iter(to_glue);
377        job
378    }
379
380    /// Add a slice of [GlueAble] objects to the glue job
381    pub fn add_slice<B>(&mut self, to_glue: &[B])
382    where
383        B: GlueAble<H>,
384    {
385        self.add_iter(to_glue.iter())
386    }
387
388    /// Add [GlueAble] objects via an iterator
389    pub fn add_iter<'a, I, B>(&mut self, to_glue: I)
390    where
391        B: GlueAble<H> + 'a,
392        I: Iterator<Item = &'a B>,
393    {
394        for entry in to_glue {
395            entry.push_glue_entry(self);
396        }
397    }
398
399    /// Get statistics of the current glue job. See [GlueStats]
400    pub fn get_stats(&self) -> GlueStats {
401        let interval_stats = self
402            .collection
403            .iter()
404            .map(|e| e.interval_stats.clone())
405            .collect();
406        GlueStats {
407            interval_stats,
408            roundtrips: self.round_trips.clone(),
409        }
410    }
411
412    /// # Calculate the probability density function from overlapping intervals
413    ///
414    /// This uses a average merge, which first align all intervals and then merges
415    /// the probability densities by averaging in the logarithmic space
416    ///
417    /// The [Glued] allows you to easily write the probability density function to a file
418    pub fn average_merged_and_aligned<T>(&mut self) -> Result<Glued<H, T>, HistErrors>
419    where
420        H: Histogram + HistogramCombine + HistogramVal<T>,
421        T: PartialOrd,
422    {
423        let log_prob = self.prepare_for_merge()?;
424        let mut res = average_merged_and_aligned(log_prob, &self.collection, self.desired_logbase)?;
425        let stats = self.get_stats();
426        res.set_stats(stats);
427        Ok(res)
428    }
429
430    /// # Calculate the probability density function from overlapping intervals
431    ///
432    /// This uses a derivative merge
433    ///
434    /// The [Glued] allows you to easily write the probability density function to a file
435    pub fn derivative_glue_and_align<T>(&mut self) -> Result<Glued<H, T>, HistErrors>
436    where
437        H: Histogram + HistogramCombine + HistogramVal<T>,
438        T: PartialOrd,
439    {
440        let log_prob = self.prepare_for_merge()?;
441        let mut res =
442            derivative_merged_and_aligned(log_prob, &self.collection, self.desired_logbase)?;
443        let stats = self.get_stats();
444        res.set_stats(stats);
445        Ok(res)
446    }
447
448    fn prepare_for_merge<T>(&mut self) -> Result<Vec<Vec<f64>>, HistErrors>
449    where
450        H: Histogram + HistogramCombine + HistogramVal<T>,
451        T: PartialOrd,
452    {
453        self.make_entries_desired_logbase();
454
455        let mut encountered_invalid = false;
456
457        self.collection.sort_unstable_by(|a, b| {
458            match a.hist.first_border().partial_cmp(&b.hist.first_border()) {
459                None => {
460                    encountered_invalid = true;
461                    std::cmp::Ordering::Less
462                }
463                Some(o) => o,
464            }
465        });
466        if encountered_invalid {
467            return Err(HistErrors::InvalidVal);
468        }
469
470        Ok(self.collection.iter().map(|e| e.prob.clone()).collect())
471    }
472
473    fn make_entries_desired_logbase(&mut self) {
474        for e in self.collection.iter_mut() {
475            match self.desired_logbase {
476                LogBase::Base10 => {
477                    if e.log_base.is_base_e() {
478                        e.log_base = LogBase::Base10;
479                        ln_to_log10(&mut e.prob)
480                    }
481                }
482                LogBase::BaseE => {
483                    if e.log_base.is_base10() {
484                        e.log_base = LogBase::BaseE;
485                        log10_to_ln(&mut e.prob)
486                    }
487                }
488            }
489        }
490    }
491}