use{
crate::{*, rewl::log_density_to_log10_density},
rand::Rng,
std::{
marker::PhantomData,
mem::*,
num::*,
sync::*,
}
};
#[cfg(feature = "sweep_time_optimization")]
use std::time::*;
#[cfg(feature = "serde_support")]
use serde::{Serialize, Deserialize};
/// Walker used for the REES (refinement) phase, typically created from a
/// [`RewlWalker`] via the `From` impl below.
///
/// The walker does not own its ensemble; it stores an index (`id`) into a
/// slice of `RwLock<Ensemble>` that is passed to `sweep` and friends.
/// Replica exchanges swap these indices (and energies) between walkers.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct ReesWalker<R, Hist, Energy, S, Res>
{
/// Index into the ensemble slice this walker currently works on; swapped on accepted replica exchanges.
id: usize,
/// Number of markov steps performed per call to `sweep`.
sweep_size: NonZeroUsize,
/// Random number generator used for acceptance decisions.
pub(crate) rng: R,
/// Histogram of the bins the walker has been counted in since the last reset.
hist: Hist,
/// Current estimate of the (natural) log density of states, one entry per histogram bin.
log_density: Vec<f64>,
/// Markov steps performed since creation or the last `refine`.
step_count: u64,
/// Number of accepted replica exchanges.
re: u64,
/// Number of proposed replica exchanges.
proposed_re: u64,
/// Number of rejected markov steps.
rejected_markov_steps: u64,
/// Energy of the walker's current state.
old_energy: Energy,
/// Histogram bin corresponding to `old_energy`.
bin: usize,
/// Buffer filled by `m_steps` and reused to undo/accept/reject those steps.
markov_steps: Vec<S>,
/// Ties the otherwise unused `Res` type parameter to the struct.
marker_res: PhantomData<Res>,
/// Number of elementary moves passed to `m_steps` per markov step.
step_size: usize,
/// `is_finished` returns true once `step_count` reaches this value
/// (initialised from the step count of the originating `RewlWalker`).
step_threshold: u64,
/// Wall-clock duration of the most recent sweep.
#[cfg(feature = "sweep_time_optimization")]
duration: Duration
}
impl<R, Hist, Energy, S, Res> From<RewlWalker<R, Hist, Energy, S, Res>> for ReesWalker<R, Hist, Energy, S, Res>
where Hist: Histogram
{
    /// Turns a finished REWL walker into a REES walker.
    ///
    /// The log density, current state (energy and bin), rng, step and sweep
    /// sizes are carried over unchanged. The histogram is cleared, and all
    /// counters (steps, rejections, replica exchanges) start at zero. The
    /// number of steps the REWL walker performed becomes the step threshold
    /// of the new walker.
    fn from(rewl_walker: RewlWalker<R, Hist, Energy, S, Res>) -> Self
    {
        // The REES phase is meant to run for as many steps as the REWL phase did.
        let step_threshold = rewl_walker.step_count;

        let mut hist = rewl_walker.hist;
        hist.reset();

        Self{
            id: rewl_walker.id,
            rng: rewl_walker.rng,
            hist,
            log_density: rewl_walker.log_density,
            old_energy: rewl_walker.old_energy,
            bin: rewl_walker.bin,
            sweep_size: rewl_walker.sweep_size,
            step_size: rewl_walker.step_size,
            markov_steps: rewl_walker.markov_steps,
            marker_res: PhantomData::<Res>,
            step_threshold,
            step_count: 0,
            rejected_markov_steps: 0,
            re: 0,
            proposed_re: 0,
            #[cfg(feature = "sweep_time_optimization")]
            duration: rewl_walker.duration,
        }
    }
}
impl<R, Hist, Energy, S, Res> ReesWalker<R, Hist, Energy, S, Res>
{
    /// Index of the ensemble this walker currently operates on.
    ///
    /// Changes whenever an accepted replica exchange swaps the ensembles of
    /// two walkers.
    #[inline(always)]
    pub fn id(&self) -> usize
    {
        self.id
    }

    /// Wall-clock time the most recent sweep took.
    #[cfg(feature = "sweep_time_optimization")]
    pub fn duration(&self) -> Duration
    {
        self.duration
    }

    /// Energy of the walker's current state.
    #[inline(always)]
    pub fn energy(&self) -> &Energy
    {
        &self.old_energy
    }

    /// Copy of the energy of the walker's current state.
    #[inline(always)]
    pub fn energy_copy(&self) -> Energy
    where Energy: Copy
    {
        self.old_energy
    }

    /// Histogram of the bins visited since the last reset.
    #[inline(always)]
    pub fn hist(&self) -> &Hist
    {
        &self.hist
    }

    /// Number of markov steps performed per sweep.
    #[inline(always)]
    pub fn sweep_size(&self) -> NonZeroUsize
    {
        self.sweep_size
    }

    /// Changes the number of markov steps performed per sweep.
    pub fn sweep_size_change(&mut self, sweep_size: NonZeroUsize)
    {
        self.sweep_size = sweep_size;
    }

    /// Number of elementary moves that make up one markov step.
    #[inline(always)]
    pub fn step_size(&self) -> usize
    {
        self.step_size
    }

    /// Changes the number of elementary moves per markov step.
    pub fn step_size_change(&mut self, step_size: usize)
    {
        self.step_size = step_size;
    }

    /// Markov steps performed since creation or the last [`refine`](Self::refine).
    #[inline(always)]
    pub fn step_count(&self) -> u64
    {
        self.step_count
    }

    /// Number of accepted replica exchanges.
    #[inline(always)]
    pub fn replica_exchanges(&self) -> u64
    {
        self.re
    }

    /// Number of proposed replica exchanges.
    #[inline(always)]
    pub fn proposed_replica_exchanges(&self) -> u64
    {
        self.proposed_re
    }

    /// Fraction of proposed replica exchanges that were accepted.
    ///
    /// Returns NaN if no exchange has been proposed yet.
    #[inline(always)]
    pub fn replica_exchange_frac(&self) -> f64
    {
        self.re as f64 / self.proposed_re as f64
    }

    /// Rejected markov steps since creation or the last [`refine`](Self::refine).
    #[inline(always)]
    pub fn rejected_markov_steps(&self) -> u64
    {
        self.rejected_markov_steps
    }

    /// Fraction of markov steps that were accepted since creation or the
    /// last [`refine`](Self::refine).
    ///
    /// Returns NaN if no step has been performed yet.
    pub fn acceptance_rate_markov(&self) -> f64
    {
        let rej = self.rejected_markov_steps() as f64 / self.step_count() as f64;
        1.0 - rej
    }

    /// Current estimate of the natural log of the density of states, one
    /// entry per histogram bin.
    #[inline(always)]
    pub fn log_density(&self) -> &[f64]
    {
        &self.log_density
    }

    /// Log density refined with the current histogram.
    ///
    /// Every bin with a non-zero count `h` gets `ln(h)` added to its log
    /// density. Bins that were never visited keep their current value.
    pub fn log_density_refined(&self) -> Vec<f64>
    where Hist: Histogram
    {
        self.log_density
            .iter()
            .zip(self.hist.hist().iter())
            .map(
                |(&log_d, &h)|
                {
                    if h == 0 {
                        log_d
                    } else {
                        log_d + (h as f64).ln()
                    }
                }
            )
            .collect()
    }

    /// Current log density converted to base 10.
    pub fn log10_density(&self) -> Vec<f64>
    {
        log_density_to_log10_density(self.log_density())
    }

    /// Refined log density (see [`log_density_refined`](Self::log_density_refined))
    /// converted to base 10.
    pub fn log10_density_refined(&self) -> Vec<f64>
    where Hist: Histogram
    {
        let density = self.log_density_refined();
        log_density_to_log10_density(&density)
    }

    /// True once the walker has performed at least `step_threshold` steps
    /// since creation or the last refinement.
    #[inline(always)]
    pub fn is_finished(&self) -> bool
    {
        self.step_count >= self.step_threshold
    }

    /// Number of steps after which [`is_finished`](Self::is_finished) returns true.
    #[inline(always)]
    pub fn step_threshold(&self) -> u64
    {
        self.step_threshold
    }

    /// Replaces the log density by its refined version and starts a new
    /// refinement phase.
    ///
    /// Resets the histogram, the step counter and the rejected-step counter.
    /// The replica exchange statistics are kept.
    pub fn refine(&mut self)
    where Hist: Histogram
    {
        self.log_density = self.log_density_refined();
        self.hist.reset();
        self.step_count = 0;
        // `acceptance_rate_markov` divides `rejected_markov_steps` by
        // `step_count`; both counters have to restart together, otherwise the
        // rate mixes phases (and is -inf/NaN directly after a refinement).
        self.rejected_markov_steps = 0;
    }

    /// Records one rejected markov step.
    #[inline(always)]
    fn count_rejected(&mut self)
    {
        self.rejected_markov_steps += 1;
    }
}
impl<R, Hist, Energy, S, Res> ReesWalker<R, Hist, Energy, S, Res>
where Hist: HistogramVal<Energy>,
    R: Rng
{
    /// Recomputes the energy of this walker's ensemble and compares it with
    /// the energy the walker has stored.
    ///
    /// Takes the write lock of `ensemble_vec[self.id]`, because `energy_fn`
    /// needs `&mut Ensemble`. Returns `false` if `energy_fn` returns `None`.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    pub(crate) fn check_energy_fn<F, Ensemble>(
        &self,
        ensemble_vec: &[RwLock<Ensemble>],
        energy_fn: F
    ) -> bool
    where Energy: PartialEq,F: Fn(&mut Ensemble) -> Option<Energy>,
    {
        let mut e = ensemble_vec[self.id]
            .write()
            .expect("Fatal Error encountered; ERRORCODE 0x1 - this should be \
                impossible to reach. If you are using the latest version of the \
                'sampling' library, please contact the library author via github by opening an \
                issue! https://github.com/Pardoxa/sampling/issues");
        let energy = match energy_fn(&mut e){
            Some(energy) => energy,
            None => {
                return false;
            }
        };
        energy == self.old_energy
    }

    /// Performs `sweep_size` markov steps on this walker's ensemble.
    ///
    /// Each step proposes `step_size` moves. The proposal is rejected if:
    /// * `energy_fn` returns `None` (invalid state),
    /// * the energy lies outside the histogram range, or
    /// * it fails the acceptance test
    ///   `rand > exp(log_density[old_bin] - log_density[new_bin])`.
    ///
    /// Rejected proposals are reported via `steps_rejected` and undone;
    /// accepted ones via `steps_accepted`. After every step, regardless of
    /// the outcome, the current bin is counted in the histogram and
    /// `extra_fn` is invoked.
    ///
    /// # Panics
    /// Panics if the ensemble lock is poisoned or the current bin is not a
    /// valid histogram index.
    pub(crate) fn sweep<Ensemble, F, Extra, P>
    (
        &mut self,
        ensemble_vec: &[RwLock<Ensemble>],
        extra: &mut Extra,
        extra_fn: P,
        energy_fn: F,
    )
    where F: Fn(&mut Ensemble) -> Option<Energy>,
        P: Fn(&Self, &mut Ensemble, &mut Extra),
        Ensemble: MarkovChain<S, Res>,
        Hist: Histogram
    {
        #[cfg(feature = "sweep_time_optimization")]
        let start = Instant::now();

        let mut e = ensemble_vec[self.id]
            .write()
            .expect("Fatal Error encountered; ERRORCODE 0x3 - this should be \
                impossible to reach. If you are using the latest version of the \
                'sampling' library, please contact the library author via github by opening an \
                issue! https://github.com/Pardoxa/sampling/issues");

        for _ in 0..self.sweep_size.get()
        {
            self.step_count += 1;
            e.m_steps(self.step_size, &mut self.markov_steps);

            let accepted = match energy_fn(&mut e) {
                Some(energy) => match self.hist.get_bin_index(&energy)
                {
                    Ok(current_bin) => {
                        let acception_prob = (self.log_density[self.bin] - self.log_density[current_bin])
                            .exp();
                        if self.rng.gen::<f64>() > acception_prob
                        {
                            false
                        } else {
                            self.old_energy = energy;
                            self.bin = current_bin;
                            true
                        }
                    },
                    // energy outside of the histogram range
                    _ => false
                },
                // An invalid state is an ordinary rejection: the chain stays
                // where it is, so the current bin must still be counted below.
                // Skipping the count would bias the histogram against states
                // that often propose invalid moves.
                None => false
            };

            if accepted {
                e.steps_accepted(&self.markov_steps);
            } else {
                self.count_rejected();
                e.steps_rejected(&self.markov_steps);
                e.undo_steps_quiet(&self.markov_steps);
            }

            self.hist.count_index(self.bin)
                .expect("Histogram index Error, ERRORCODE 0x4");
            extra_fn(self, &mut e, extra);
        }

        #[cfg(feature = "sweep_time_optimization")]
        {
            self.duration = start.elapsed();
        }
    }
}
/// Proposes a replica exchange between two walkers.
///
/// The exchange is only possible if each walker's energy lies inside the
/// histogram range of the other walker. If so, it is accepted with
/// probability
/// `min(1, exp(log_gi(x) + log_gj(y) - log_gi(y) - log_gj(x)))`.
/// On acceptance, the ensemble indices and energies are swapped and the bins
/// are updated.
///
/// Every proposal counts as a step for both walkers, whether it is accepted,
/// rejected by probability, or impossible. The current bin of each walker is
/// counted in its histogram in all three cases, so exchange attempts do not
/// over-weight the overlap region of the two histograms.
///
/// # Panics
/// Panics if a walker's current bin is not a valid index of its histogram.
pub(crate) fn replica_exchange<R, Hist, Energy, S, Res>
(
    walker_a: &mut ReesWalker<R, Hist, Energy, S, Res>,
    walker_b: &mut ReesWalker<R, Hist, Energy, S, Res>
) where Hist: HistogramVal<Energy> + Histogram,
    R: Rng
{
    walker_a.proposed_re += 1;
    walker_b.proposed_re += 1;

    // Bins the walkers would occupy after the exchange, if it is possible at all.
    let new_bins = walker_a.hist
        .get_bin_index(&walker_b.old_energy)
        .ok()
        .and_then(
            |new_bin_a| walker_b.hist
                .get_bin_index(&walker_a.old_energy)
                .ok()
                .map(|new_bin_b| (new_bin_a, new_bin_b))
        );

    if let Some((new_bin_a, new_bin_b)) = new_bins
    {
        let log_gi_x = walker_a.log_density[walker_a.bin];
        let log_gi_y = walker_a.log_density[new_bin_a];

        let log_gj_y = walker_b.log_density[walker_b.bin];
        let log_gj_x = walker_b.log_density[new_bin_b];

        let log_prob = log_gi_x + log_gj_y - log_gi_y - log_gj_x;
        let prob = log_prob.exp();

        if walker_b.rng.gen::<f64>() < prob
        {
            swap(&mut walker_b.id, &mut walker_a.id);
            swap(&mut walker_b.old_energy, &mut walker_a.old_energy);
            walker_b.bin = new_bin_b;
            walker_a.bin = new_bin_a;
            walker_a.re += 1;
            walker_b.re += 1;
        }
    }

    walker_a.hist.count_index(walker_a.bin)
        .expect("current bin of walker_a has to be a valid histogram index");
    walker_b.hist.count_index(walker_b.bin)
        .expect("current bin of walker_b has to be a valid histogram index");
}
/// Averages the refined log densities of all walkers, bin by bin.
///
/// Each walker contributes its `log_density_refined()`. The result is the
/// arithmetic mean of these vectors; no renormalization is applied.
///
/// # Panics
/// Panics if `walker` is empty. In debug builds it also asserts that all
/// walkers have log densities of the same length.
pub(crate) fn get_merged_refined_walker_prob<R, Hist, Energy, S, Res>(walker: &[ReesWalker<R, Hist, Energy, S, Res>]) -> Vec<f64>
where Hist: Histogram
{
    let expected_len = walker[0].log_density.len();
    debug_assert!(
        walker.iter().all(|w| w.log_density.len() == expected_len)
    );

    let mut sum = walker[0].log_density_refined();
    let walker_count = walker.len();

    // A single walker needs no averaging.
    if walker_count == 1 {
        return sum;
    }

    for other in &walker[1..] {
        for (acc, value) in sum.iter_mut().zip(other.log_density_refined()) {
            *acc += value;
        }
    }

    let divisor = walker_count as f64;
    for acc in sum.iter_mut() {
        *acc /= divisor;
    }
    sum
}