1use super::{
2 glue_helper::{ln_to_log10, log10_to_ln},
3 glue_writer::*,
4 LogBase,
5};
6use crate::histogram::*;
7use std::{borrow::Borrow, num::NonZeroUsize};
8
9#[cfg(feature = "serde_support")]
10use serde::{Deserialize, Serialize};
11pub trait GlueAble<H> {
13 fn push_glue_entry(&self, job: &mut GlueJob<H>) {
15 self.push_glue_entry_ignoring(job, &[])
16 }
17
18 fn push_glue_entry_ignoring(&self, job: &mut GlueJob<H>, ignore_idx: &[usize]);
20}
21
/// Kind of simulation that produced the data of an interval.
///
/// The explicit discriminants double as array indices (see
/// [`SimulationType::from_usize`]), so they must stay contiguous and start at 0.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub enum SimulationType {
    /// Wang-Landau, 1/t variant
    WangLandau1T = 0,
    /// Adaptive Wang-Landau, 1/t variant
    WangLandau1TAdaptive = 1,
    /// Entropic sampling
    Entropic = 2,
    /// Adaptive entropic sampling
    EntropicAdaptive = 3,
    /// REWL (presumably replica-exchange Wang-Landau)
    REWL = 4,
    /// REES (presumably replica-exchange entropic sampling)
    REES = 5,
    /// Simulation type not known
    Unknown = 6,
}

impl SimulationType {
    /// Name of the simulation type, identical to the variant identifier.
    pub fn name(self) -> &'static str {
        match self {
            Self::WangLandau1T => "WangLandau1T",
            Self::WangLandau1TAdaptive => "WangLandau1TAdaptive",
            Self::Entropic => "Entropic",
            Self::EntropicAdaptive => "EntropicAdaptive",
            Self::REWL => "REWL",
            Self::REES => "REES",
            Self::Unknown => "Unknown",
        }
    }

    /// Inverse of `self as usize`: map a discriminant back to its variant.
    ///
    /// # Panics
    /// Panics if `num` is not a valid discriminant, i.e. if `num > 6`.
    pub(crate) fn from_usize(num: usize) -> Self {
        match num {
            0 => Self::WangLandau1T,
            1 => Self::WangLandau1TAdaptive,
            2 => Self::Entropic,
            3 => Self::EntropicAdaptive,
            4 => Self::REWL,
            5 => Self::REES,
            6 => Self::Unknown,
            // Reachable through a caller bug, so give a useful message instead
            // of claiming the branch is impossible.
            _ => panic!("invalid SimulationType discriminant: {num}"),
        }
    }
}
69
70pub(crate) struct AccumulatedIntervalStats {
71 worst_log_progress: f64,
72 worst_missing_steps_progress: u64,
73 log_progress_counter: u32,
74 missing_steps_progress_counter: u32,
75 unknown_progress_counter: u32,
76 interval_sim_type_counter: [usize; 7],
77 total_rejected_steps: u64,
78 total_accepted_steps: u64,
79 total_proposed_replica_exchanges: u64,
80 total_replica_exchanges: u64,
81 potential_for_replica_exchanges: bool,
82 potential_for_proposed_replica_exchanges: bool,
83}
84
85impl AccumulatedIntervalStats {
86 pub(crate) fn write<W: std::io::Write>(&self, mut writer: W) -> std::io::Result<()> {
87 let total_intervals: usize = self.interval_sim_type_counter.iter().sum();
88 writeln!(writer, "#Accumulated Stats of {total_intervals} Intervals")?;
89 if self.log_progress_counter > 0 {
90 writeln!(
91 writer,
92 "#Worst log progress: {} - out of {} intervals that tracked log progress",
93 self.worst_log_progress, self.log_progress_counter
94 )?;
95 }
96 if self.missing_steps_progress_counter > 0 {
97 writeln!(
98 writer,
99 "#Worst missing steps progress: {} - out of {} intervals that tracked missing steps progress",
100 self.worst_missing_steps_progress,
101 self.missing_steps_progress_counter
102 )?;
103 }
104 if self.unknown_progress_counter > 0 {
105 writeln!(
106 writer,
107 "# {} Intervals had unknown progress",
108 self.unknown_progress_counter
109 )?
110 }
111
112 for (index, &amount) in self.interval_sim_type_counter.iter().enumerate() {
113 if amount > 0 {
114 let sim_type = SimulationType::from_usize(index);
115 writeln!(
116 writer,
117 "#{} contributed {} intervals",
118 sim_type.name(),
119 amount
120 )?;
121 }
122 }
123
124 let a = self.total_accepted_steps;
125 let r = self.total_rejected_steps;
126 let total = a + r;
127 writeln!(
128 writer,
129 "#TOTAL: {a} accepted and {r} rejected steps, which makes a total of {total} steps"
130 )?;
131 let a_rate = a as f64 / total as f64;
132 writeln!(writer, "#TOTAL acceptance rate {a_rate}")?;
133 let r_rate = r as f64 / total as f64;
134 writeln!(writer, "#TOTAL rejection rate {r_rate}")?;
135
136 if self.potential_for_replica_exchanges {
137 writeln!(
138 writer,
139 "#TOTAL performed replica exchanges: {}",
140 self.total_replica_exchanges
141 )?;
142 }
143 if self.potential_for_proposed_replica_exchanges {
144 writeln!(
145 writer,
146 "#TOTAL proposed replica exchanges: {}",
147 self.total_proposed_replica_exchanges
148 )?;
149 if self.potential_for_replica_exchanges {
150 let rate = self.total_replica_exchanges as f64
151 / self.total_proposed_replica_exchanges as f64;
152 writeln!(writer, "#rate of accepting replica exchanges: {rate}")?;
153 }
154 }
155 Ok(())
156 }
157
158 pub(crate) fn generate_stats(interval_stats: &[IntervalSimStats]) -> Self {
159 let mut acc = AccumulatedIntervalStats {
160 worst_log_progress: f64::NEG_INFINITY,
161 worst_missing_steps_progress: 0,
162 log_progress_counter: 0,
163 missing_steps_progress_counter: 0,
164 unknown_progress_counter: 0,
165 interval_sim_type_counter: [0; 7],
166 total_accepted_steps: 0,
167 total_rejected_steps: 0,
168 total_proposed_replica_exchanges: 0,
169 total_replica_exchanges: 0,
170 potential_for_proposed_replica_exchanges: false,
171 potential_for_replica_exchanges: false,
172 };
173
174 for stats in interval_stats.iter() {
175 acc.interval_sim_type_counter[stats.interval_sim_type as usize] += 1;
176 match stats.sim_progress {
177 SimProgress::LogF(log_f) => {
178 acc.log_progress_counter += 1;
179 acc.worst_log_progress = acc.worst_log_progress.max(log_f);
180 }
181 SimProgress::MissingSteps(missing) => {
182 acc.missing_steps_progress_counter += 1;
183 acc.worst_missing_steps_progress =
184 acc.worst_missing_steps_progress.max(missing);
185 }
186 SimProgress::Unknown => {
187 acc.unknown_progress_counter += 1;
188 }
189 }
190
191 acc.total_accepted_steps += stats.accepted_steps;
192 acc.total_rejected_steps += stats.rejected_steps;
193 if let Some(replica) = stats.replica_exchanges {
194 acc.potential_for_replica_exchanges = true;
195 acc.total_replica_exchanges += replica;
196 }
197 if let Some(proposed) = stats.proposed_replica_exchanges {
198 acc.potential_for_proposed_replica_exchanges = true;
199 acc.total_proposed_replica_exchanges += proposed;
200 }
201 }
202 acc
203 }
204}
205
/// Progress indicator reported by the simulation of a single interval.
///
/// Do not reorder the variants: with `serde_support`, non-self-describing
/// formats may encode the variant index.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub enum SimProgress {
    /// Progress given as a `log f` value
    LogF(f64),
    /// Progress given as a number of missing steps
    MissingSteps(u64),
    /// No progress information available
    Unknown,
}
217
218#[derive(Clone, Debug)]
221#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
222pub struct IntervalSimStats {
223 pub sim_progress: SimProgress,
225 pub interval_sim_type: SimulationType,
227 pub rejected_steps: u64,
229 pub accepted_steps: u64,
231 pub replica_exchanges: Option<u64>,
234 pub proposed_replica_exchanges: Option<u64>,
237 pub merged_over_walkers: NonZeroUsize,
241}
242
243impl IntervalSimStats {
244 pub fn write<W: std::io::Write>(&self, mut writer: W) -> std::io::Result<()> {
256 writeln!(
257 writer,
258 "#Simulated via: {:?}",
259 self.interval_sim_type.name()
260 )?;
261 writeln!(writer, "#progress {:?}", self.sim_progress)?;
262 if self.merged_over_walkers.get() == 1 {
263 writeln!(writer, "#created from a single walker")?;
264 } else {
265 writeln!(
266 writer,
267 "#created from merging {} walkers",
268 self.merged_over_walkers
269 )?;
270 }
271
272 let a = self.accepted_steps;
273 let r = self.rejected_steps;
274 let total = a + r;
275 writeln!(
276 writer,
277 "#had {a} accepted and {r} rejected steps, which makes a total of {total} steps"
278 )?;
279 let a_rate = a as f64 / total as f64;
280 writeln!(writer, "#acceptance rate {a_rate}")?;
281 let r_rate = r as f64 / total as f64;
282 writeln!(writer, "#rejection rate {r_rate}")?;
283
284 if let Some(replica) = self.replica_exchanges {
285 writeln!(writer, "#performed replica exchanges: {replica}")?;
286 }
287 if let Some(proposed) = self.proposed_replica_exchanges {
288 writeln!(writer, "#proposed replica exchanges: {proposed}")?;
289 if let Some(replica) = self.replica_exchanges {
290 let rate = replica as f64 / proposed as f64;
291 writeln!(writer, "#rate of accepting replica exchanges: {rate}")?;
292 }
293 }
294 Ok(())
295 }
296}
297
298#[derive(Clone, Debug)]
299#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
300pub struct GlueEntry<H> {
302 pub hist: H,
304 pub prob: Vec<f64>,
306 pub log_base: LogBase,
308 pub interval_stats: IntervalSimStats,
310}
311
312impl<H> Borrow<H> for GlueEntry<H> {
313 fn borrow(&self) -> &H {
314 &self.hist
315 }
316}
317
318#[derive(Clone, Debug)]
321#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
322pub struct GlueJob<H> {
323 pub collection: Vec<GlueEntry<H>>,
325 pub round_trips: Vec<usize>,
327 pub desired_logbase: LogBase,
329}
330
331impl<H> GlueJob<H>
332where
333 H: Clone,
334{
335 pub fn new<B>(to_glue: &B, desired_logbase: LogBase) -> Self
339 where
340 B: GlueAble<H>,
341 {
342 let mut job = Self {
343 collection: Vec::new(),
344 round_trips: Vec::new(),
345 desired_logbase,
346 };
347
348 to_glue.push_glue_entry(&mut job);
349 job
350 }
351
352 pub fn new_from_slice<B>(to_glue: &[B], desired_logbase: LogBase) -> Self
356 where
357 B: GlueAble<H>,
358 {
359 Self::new_from_iter(to_glue.iter(), desired_logbase)
360 }
361
362 pub fn new_from_iter<'a, B, I>(to_glue: I, desired_logbase: LogBase) -> Self
366 where
367 B: GlueAble<H> + 'a,
368 I: Iterator<Item = &'a B>,
369 {
370 let mut job = Self {
371 collection: Vec::new(),
372 round_trips: Vec::new(),
373 desired_logbase,
374 };
375
376 job.add_iter(to_glue);
377 job
378 }
379
380 pub fn add_slice<B>(&mut self, to_glue: &[B])
382 where
383 B: GlueAble<H>,
384 {
385 self.add_iter(to_glue.iter())
386 }
387
388 pub fn add_iter<'a, I, B>(&mut self, to_glue: I)
390 where
391 B: GlueAble<H> + 'a,
392 I: Iterator<Item = &'a B>,
393 {
394 for entry in to_glue {
395 entry.push_glue_entry(self);
396 }
397 }
398
399 pub fn get_stats(&self) -> GlueStats {
401 let interval_stats = self
402 .collection
403 .iter()
404 .map(|e| e.interval_stats.clone())
405 .collect();
406 GlueStats {
407 interval_stats,
408 roundtrips: self.round_trips.clone(),
409 }
410 }
411
412 pub fn average_merged_and_aligned<T>(&mut self) -> Result<Glued<H, T>, HistErrors>
419 where
420 H: Histogram + HistogramCombine + HistogramVal<T>,
421 T: PartialOrd,
422 {
423 let log_prob = self.prepare_for_merge()?;
424 let mut res = average_merged_and_aligned(log_prob, &self.collection, self.desired_logbase)?;
425 let stats = self.get_stats();
426 res.set_stats(stats);
427 Ok(res)
428 }
429
430 pub fn derivative_glue_and_align<T>(&mut self) -> Result<Glued<H, T>, HistErrors>
436 where
437 H: Histogram + HistogramCombine + HistogramVal<T>,
438 T: PartialOrd,
439 {
440 let log_prob = self.prepare_for_merge()?;
441 let mut res =
442 derivative_merged_and_aligned(log_prob, &self.collection, self.desired_logbase)?;
443 let stats = self.get_stats();
444 res.set_stats(stats);
445 Ok(res)
446 }
447
448 fn prepare_for_merge<T>(&mut self) -> Result<Vec<Vec<f64>>, HistErrors>
449 where
450 H: Histogram + HistogramCombine + HistogramVal<T>,
451 T: PartialOrd,
452 {
453 self.make_entries_desired_logbase();
454
455 let mut encountered_invalid = false;
456
457 self.collection.sort_unstable_by(|a, b| {
458 match a.hist.first_border().partial_cmp(&b.hist.first_border()) {
459 None => {
460 encountered_invalid = true;
461 std::cmp::Ordering::Less
462 }
463 Some(o) => o,
464 }
465 });
466 if encountered_invalid {
467 return Err(HistErrors::InvalidVal);
468 }
469
470 Ok(self.collection.iter().map(|e| e.prob.clone()).collect())
471 }
472
473 fn make_entries_desired_logbase(&mut self) {
474 for e in self.collection.iter_mut() {
475 match self.desired_logbase {
476 LogBase::Base10 => {
477 if e.log_base.is_base_e() {
478 e.log_base = LogBase::Base10;
479 ln_to_log10(&mut e.prob)
480 }
481 }
482 LogBase::BaseE => {
483 if e.log_base.is_base10() {
484 e.log_base = LogBase::BaseE;
485 log10_to_ln(&mut e.prob)
486 }
487 }
488 }
489 }
490 }
491}