// sampling/glue/glue_job.rs

use std::{borrow::Borrow, num::NonZeroUsize};
use crate::histogram::*;
use super::{
    glue_writer::*,
    glue_helper::{
        ln_to_log10,
        log10_to_ln
    },
    LogBase
};

#[cfg(feature = "serde_support")]
use serde::{Serialize, Deserialize};
/// Trait for objects that can contribute to a [GlueJob]
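///
/// Below is a minimal sketch of how an implementation could look for a
/// hypothetical wrapper type that already stores one finished interval.
/// `FinishedInterval` and its fields are assumptions made purely for
/// illustration; only the items of this module ([GlueJob], [GlueEntry],
/// [IntervalSimStats], [SimProgress], [SimulationType], [LogBase]) are real.
/// ```ignore
/// struct FinishedInterval<H> {
///     hist: H,
///     ln_prob: Vec<f64>,
/// }
///
/// impl<H: Clone> GlueAble<H> for FinishedInterval<H> {
///     fn push_glue_entry_ignoring(&self, job: &mut GlueJob<H>, _ignore_idx: &[usize]) {
///         // hand the histogram, its log probability density and some statistics to the job
///         job.collection.push(
///             GlueEntry{
///                 hist: self.hist.clone(),
///                 prob: self.ln_prob.clone(),
///                 log_base: LogBase::BaseE,
///                 interval_stats: IntervalSimStats{
///                     sim_progress: SimProgress::Unknown,
///                     interval_sim_type: SimulationType::Unknown,
///                     rejected_steps: 0,
///                     accepted_steps: 0,
///                     replica_exchanges: None,
///                     proposed_replica_exchanges: None,
///                     merged_over_walkers: std::num::NonZeroUsize::new(1).unwrap()
///                 }
///             }
///         );
///     }
/// }
/// ```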
pub trait GlueAble<H>{
    /// Add `self` to the [GlueJob]
    fn push_glue_entry(&self, job: &mut GlueJob<H>)
    {
        self.push_glue_entry_ignoring(job, &[])
    }

    /// Add `self` to the [GlueJob], but ignore some indices
    fn push_glue_entry_ignoring(
        &self,
        job: &mut GlueJob<H>,
        ignore_idx: &[usize]
    );
}

#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
/// Enum to track simulation type
pub enum SimulationType{
    /// Simulation was a 1/t Wang-Landau simulation
    WangLandau1T = 0,
    /// Simulation was an adaptive 1/t Wang-Landau simulation
    WangLandau1TAdaptive = 1,
    /// Simulation was Entropic sampling
    Entropic = 2,
    /// Simulation was adaptive Entropic sampling
    EntropicAdaptive = 3,
    /// Simulation was Replica exchange Wang-Landau (1/t)
    REWL = 4,
    /// Simulation was Replica exchange Entropic sampling
    REES = 5,
    /// Simulation type is unknown
    Unknown = 6
}

impl SimulationType
{
    /// # Name of simulation type as &str
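    ///
    /// A quick sketch of the mapping (not compiled here, since the `use` path
    /// of the crate is not part of this file):
    /// ```ignore
    /// assert_eq!(SimulationType::WangLandau1T.name(), "WangLandau1T");
    /// assert_eq!(SimulationType::REWL.name(), "REWL");
    /// assert_eq!(SimulationType::Unknown.name(), "Unknown");
    /// ```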
    pub fn name(self) -> &'static str
    {
        match self{
            Self::Entropic => "Entropic",
            Self::WangLandau1T => "WangLandau1T",
            Self::EntropicAdaptive => "EntropicAdaptive",
            Self::WangLandau1TAdaptive => "WangLandau1TAdaptive",
            Self::REES => "REES",
            Self::REWL => "REWL",
            Self::Unknown => "Unknown"
        }
    }

    pub(crate) fn from_usize(num: usize) -> Self
    {
        match num{
            0 => Self::WangLandau1T,
            1 => Self::WangLandau1TAdaptive,
            2 => Self::Entropic,
            3 => Self::EntropicAdaptive,
            4 => Self::REWL,
            5 => Self::REES,
            6 => Self::Unknown,
            _ => unreachable!()
        }
    }
}

pub(crate) struct AccumulatedIntervalStats{
    worst_log_progress: f64,
    worst_missing_steps_progress: u64,
    log_progress_counter: u32,
    missing_steps_progress_counter: u32,
    unknown_progress_counter: u32,
    interval_sim_type_counter: [usize; 7],
    total_rejected_steps: u64,
    total_accepted_steps: u64,
    total_proposed_replica_exchanges: u64,
    total_replica_exchanges: u64,
    potential_for_replica_exchanges: bool,
    potential_for_proposed_replica_exchanges: bool
}

impl AccumulatedIntervalStats{

    pub(crate) fn write<W: std::io::Write>(&self, mut writer: W) -> std::io::Result<()>
    {
        let total_intervals: usize = self
            .interval_sim_type_counter
            .iter()
            .sum();
        writeln!(writer, "#Accumulated Stats of {total_intervals} Intervals")?;
        if self.log_progress_counter > 0 {
            writeln!(
                writer,
                "#Worst log progress: {} - out of {} intervals that tracked log progress",
                self.worst_log_progress,
                self.log_progress_counter
            )?;
        }
        if self.missing_steps_progress_counter > 0 {
            writeln!(
                writer,
                "#Worst missing steps progress: {} - out of {} intervals that tracked missing steps progress",
                self.worst_missing_steps_progress,
                self.missing_steps_progress_counter
            )?;
        }
        if self.unknown_progress_counter > 0 {
            writeln!(writer, "# {} Intervals had unknown progress", self.unknown_progress_counter)?
        }

        for (index, &amount) in self.interval_sim_type_counter.iter().enumerate()
        {
            if amount > 0 {
                let sim_type = SimulationType::from_usize(index);
                writeln!(writer, "#{} contributed {} intervals", sim_type.name(), amount)?;
            }
        }

        let a = self.total_accepted_steps;
        let r = self.total_rejected_steps;
        let total = a + r;
        writeln!(writer, "#TOTAL: {a} accepted and {r} rejected steps, which makes a total of {total} steps")?;
        let a_rate = a as f64 / total as f64;
        writeln!(writer, "#TOTAL acceptance rate {a_rate}")?;
        let r_rate = r as f64 / total as f64;
        writeln!(writer, "#TOTAL rejection rate {r_rate}")?;

        if self.potential_for_replica_exchanges {
            writeln!(writer, "#TOTAL performed replica exchanges: {}", self.total_replica_exchanges)?;
        }
        if self.potential_for_proposed_replica_exchanges
        {
            writeln!(writer, "#TOTAL proposed replica exchanges: {}", self.total_proposed_replica_exchanges)?;
            if self.potential_for_replica_exchanges{
                let rate = self.total_replica_exchanges as f64 / self.total_proposed_replica_exchanges as f64;
                writeln!(writer, "#rate of accepting replica exchanges: {rate}")?;
            }
        }
        Ok(())
    }

    pub(crate) fn generate_stats(interval_stats: &[IntervalSimStats]) -> Self
    {
        let mut acc = AccumulatedIntervalStats{
            worst_log_progress: f64::NEG_INFINITY,
            worst_missing_steps_progress: 0,
            log_progress_counter: 0,
            missing_steps_progress_counter: 0,
            unknown_progress_counter: 0,
            interval_sim_type_counter: [0;7],
            total_accepted_steps: 0,
            total_rejected_steps: 0,
            total_proposed_replica_exchanges: 0,
            total_replica_exchanges: 0,
            potential_for_proposed_replica_exchanges: false,
            potential_for_replica_exchanges: false
        };

        for stats in interval_stats.iter()
        {
            acc.interval_sim_type_counter[stats.interval_sim_type as usize] += 1;
            match stats.sim_progress{
                SimProgress::LogF(log_f) => {
                    acc.log_progress_counter += 1;
                    acc.worst_log_progress = acc.worst_log_progress.max(log_f);
                },
                SimProgress::MissingSteps(missing) => {
                    acc.missing_steps_progress_counter += 1;
                    acc.worst_missing_steps_progress = acc.worst_missing_steps_progress.max(missing);
                },
                SimProgress::Unknown => {
                    acc.unknown_progress_counter += 1;
                }
            }

            acc.total_accepted_steps += stats.accepted_steps;
            acc.total_rejected_steps += stats.rejected_steps;
            if let Some(replica) = stats.replica_exchanges{
                acc.potential_for_replica_exchanges = true;
                acc.total_replica_exchanges += replica;
            }
            if let Some(proposed) = stats.proposed_replica_exchanges
            {
                acc.potential_for_proposed_replica_exchanges = true;
                acc.total_proposed_replica_exchanges += proposed;
            }
        }
        acc
    }
}

#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
/// Enum which contains information about the current progress of the simulation
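///
/// A sketch of how one might react to it. `progress` is an assumed local
/// variable of type [SimProgress], which is why the example is not compiled:
/// ```ignore
/// match progress {
///     SimProgress::LogF(log_f) => println!("current log(f): {log_f}"),
///     SimProgress::MissingSteps(steps) => println!("{steps} steps still missing"),
///     SimProgress::Unknown => println!("progress unknown"),
/// }
/// ```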
pub enum SimProgress{
    /// The logarithm of the factor f. Useful, since we often want to simulate until we hit a target value for log(f)
    LogF(f64),
    /// How many steps do we still need to perform?
    MissingSteps(u64),
    /// The simulation progress is unknown
    Unknown
}

/// Statistics of one interval, used to gauge how well
/// the simulation performed
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct IntervalSimStats{
    /// The progress of the interval
    pub sim_progress: SimProgress,
    /// Which type of simulation did the interval come from
    pub interval_sim_type: SimulationType,
    /// How many steps were rejected in total in the interval
    pub rejected_steps: u64,
    /// How many steps were accepted in total in the interval
    pub accepted_steps: u64,
    /// How many replica exchanges were performed?
    /// None for simulations that do not perform replica exchanges
    pub replica_exchanges: Option<u64>,
    /// How many replica exchanges were proposed?
    /// None for simulations that do not perform replica exchanges
    pub proposed_replica_exchanges: Option<u64>,
    /// The number of walkers this interval was merged over.
    /// Replica exchange simulations can use more than one walker
    /// per interval, which is where values larger than one come from
    pub merged_over_walkers: NonZeroUsize
}

impl IntervalSimStats{
    /// # Write Stats to file
    /// Use this function to output the simulation statistics of an interval to a file.
    ///
    /// Every line will be preceded by a '#' to mark it as a comment
    ///
    /// # Contained information
    /// * Simulation type
    /// * Progress
    /// * How many walkers were used for this interval?
    /// * Rejection/Acceptance rate
    /// * If applicable: Number of replica exchanges and acceptance rate of replica exchanges
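    ///
    /// Usage sketch. The file name and the `stats` variable are assumptions,
    /// so the example is not compiled:
    /// ```ignore
    /// use std::{fs::File, io::BufWriter};
    /// let file = File::create("interval_stats.dat")?;
    /// stats.write(BufWriter::new(file))?;
    /// ```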
    pub fn write<W: std::io::Write>(&self, mut writer: W) -> std::io::Result<()>
    {
        writeln!(writer, "#Simulated via: {}", self.interval_sim_type.name())?;
        writeln!(writer, "#progress {:?}", self.sim_progress)?;
        if self.merged_over_walkers.get() == 1 {
            writeln!(writer, "#created from a single walker")?;
        } else {
            writeln!(writer, "#created from merging {} walkers", self.merged_over_walkers)?;
        }

        let a = self.accepted_steps;
        let r = self.rejected_steps;
        let total = a + r;
        writeln!(writer, "#had {a} accepted and {r} rejected steps, which makes a total of {total} steps")?;
        let a_rate = a as f64 / total as f64;
        writeln!(writer, "#acceptance rate {a_rate}")?;
        let r_rate = r as f64 / total as f64;
        writeln!(writer, "#rejection rate {r_rate}")?;

        if let Some(replica) = self.replica_exchanges {
            writeln!(writer, "#performed replica exchanges: {replica}")?;
        }
        if let Some(proposed) = self.proposed_replica_exchanges
        {
            writeln!(writer, "#proposed replica exchanges: {proposed}")?;
            if let Some(replica) = self.replica_exchanges{
                let rate = replica as f64 / proposed as f64;
                writeln!(writer, "#rate of accepting replica exchanges: {rate}")?;
            }
        }
        Ok(())
    }
}

#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
/// # Entry used to create a glue job
/// Bundles the histogram of one interval with its probability density and simulation statistics
pub struct GlueEntry<H>{
    /// The histogram
    pub hist: H,
    /// The probability density distribution
    pub prob: Vec<f64>,
    /// Information about which logarithm base was used to store the probability density distribution
    pub log_base: LogBase,
    /// Statistics about the interval
    pub interval_stats: IntervalSimStats
}

impl<H> Borrow<H> for GlueEntry<H>
{
    fn borrow(&self) -> &H {
        &self.hist
    }
}

/// # Used to merge probability densities from WL, REWL, Entropic or REES simulations
/// * You can also mix those methods and still glue them
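///
/// A high-level usage sketch. `rewl` (some finished simulation implementing
/// [GlueAble]) and the energy type `MyEnergy` are assumptions made for
/// illustration, so the example is not compiled:
/// ```ignore
/// // collect the intervals of a finished simulation
/// let mut job = GlueJob::new(&rewl, LogBase::Base10);
/// // merge the overlapping intervals into a single probability density
/// let glued = job.average_merged_and_aligned::<MyEnergy>()?;
/// // the returned Glued can then be written to a file, see its documentation
/// ```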
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct GlueJob<H>
{
    /// Contains all the Intervals to glue
    pub collection: Vec<GlueEntry<H>>,
    /// Contains information about the number of roundtrips of the walkers used for this gluing job
    pub round_trips: Vec<usize>,
    /// The logarithm base that we want our final output to be in
    pub desired_logbase: LogBase
}

impl<H> GlueJob<H>
    where H: Clone
{
    /// Create a new glue job from something that implements [GlueAble]
    ///
    /// You need to specify the desired logarithm base of the final output
    pub fn new<B>(
        to_glue: &B,
        desired_logbase: LogBase
    ) -> Self
    where B: GlueAble<H>
    {
        let mut job = Self {
            collection: Vec::new(),
            round_trips: Vec::new(),
            desired_logbase
        };

        to_glue.push_glue_entry(&mut job);
        job
    }

    /// Create a glue job from a slice of [GlueAble] objects
    ///
    /// You need to specify the desired logarithm base of the final output
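    ///
    /// Sketch, assuming `walkers` is a `Vec` of objects implementing [GlueAble]
    /// (not compiled, since `walkers` is an assumption):
    /// ```ignore
    /// let job = GlueJob::new_from_slice(&walkers, LogBase::BaseE);
    /// ```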
    pub fn new_from_slice<B>(to_glue: &[B], desired_logbase: LogBase) -> Self
        where B: GlueAble<H>
    {
        Self::new_from_iter(to_glue.iter(), desired_logbase)
    }

    /// Create a glue job from an iterator of [GlueAble] objects
    ///
    /// You need to specify the desired logarithm base of the final output
    pub fn new_from_iter<'a, B, I>(
        to_glue: I,
        desired_logbase: LogBase
    ) -> Self
    where B: GlueAble<H> + 'a,
    I: Iterator<Item=&'a B>
    {
        let mut job = Self {
            collection: Vec::new(),
            round_trips: Vec::new(),
            desired_logbase
        };

        job.add_iter(to_glue);
        job
    }

    /// Add a slice of [GlueAble] objects to the glue job
    pub fn add_slice<B>(&mut self, to_glue: &[B])
        where B: GlueAble<H>
    {
        self.add_iter(to_glue.iter())
    }

    /// Add [GlueAble] objects via an iterator
    pub fn add_iter<'a, I, B>(&mut self, to_glue: I)
    where B: GlueAble<H> + 'a,
        I: Iterator<Item=&'a B>
    {
        for entry in to_glue {
            entry.push_glue_entry(self);
        }
    }

    /// Get statistics of the current glue job. See [GlueStats]
    pub fn get_stats(&self) -> GlueStats
    {
        let interval_stats = self
            .collection
            .iter()
            .map(|e| e.interval_stats.clone())
            .collect();
        GlueStats{
            interval_stats,
            roundtrips: self.round_trips.clone()
        }
    }

    /// # Calculate the probability density function from overlapping intervals
    ///
    /// This uses an average merge, which first aligns all intervals and then merges
    /// the probability densities by averaging in logarithmic space
    ///
    /// The returned [Glued] allows you to easily write the probability density function to a file
    pub fn average_merged_and_aligned<T>(&mut self) -> Result<Glued<H, T>, HistErrors>
    where H: Histogram + HistogramCombine + HistogramVal<T>,
        T: PartialOrd{

        let log_prob = self.prepare_for_merge()?;
        let mut res = average_merged_and_aligned(
            log_prob,
            &self.collection,
            self.desired_logbase
        )?;
        let stats = self.get_stats();
        res.set_stats(stats);
        Ok(res)
    }

    /// # Calculate the probability density function from overlapping intervals
    ///
    /// This uses a derivative merge
    ///
    /// The returned [Glued] allows you to easily write the probability density function to a file
    pub fn derivative_glue_and_align<T>(&mut self) -> Result<Glued<H, T>, HistErrors>
    where H: Histogram + HistogramCombine + HistogramVal<T>,
        T: PartialOrd{

        let log_prob = self.prepare_for_merge()?;
        let mut res = derivative_merged_and_aligned(
            log_prob,
            &self.collection,
            self.desired_logbase
        )?;
        let stats = self.get_stats();
        res.set_stats(stats);
        Ok(res)
    }

    fn prepare_for_merge<T>(
        &mut self
    ) -> Result<Vec<Vec<f64>>, HistErrors>
    where H: Histogram + HistogramCombine + HistogramVal<T>,
    T: PartialOrd
    {
        self.make_entries_desired_logbase();

        let mut encountered_invalid = false;

        self.collection
            .sort_unstable_by(
                |a, b|
                {
                    match a.hist
                        .first_border()
                        .partial_cmp(
                            &b.hist.first_border()
                        ){
                        None => {
                            encountered_invalid = true;
                            std::cmp::Ordering::Less
                        },
                        Some(o) => o
                    }
                }
            );
        if encountered_invalid {
            return Err(HistErrors::InvalidVal);
        }

        Ok(
            self.collection
                .iter()
                .map(|e| e.prob.clone())
                .collect()
        )
    }

    fn make_entries_desired_logbase(&mut self)
    {
        for e in self.collection.iter_mut()
        {
            match self.desired_logbase{
                LogBase::Base10 => {
                    if e.log_base.is_base_e(){
                        e.log_base = LogBase::Base10;
                        ln_to_log10(&mut e.prob)
                    }
                },
                LogBase::BaseE => {
                    if e.log_base.is_base10() {
                        e.log_base = LogBase::BaseE;
                        log10_to_ln(&mut e.prob)
                    }
                }
            }
        }
    }
}