From fdf5ef27fec1368a33801a98bbc5ed3556476979 Mon Sep 17 00:00:00 2001 From: Douglas Rumbaugh Date: Mon, 8 May 2023 12:59:46 -0400 Subject: Began porting source files over from other repository --- include/ds/Alias.h | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 include/ds/Alias.h (limited to 'include/ds/Alias.h') diff --git a/include/ds/Alias.h b/include/ds/Alias.h new file mode 100644 index 0000000..855dc75 --- /dev/null +++ b/include/ds/Alias.h @@ -0,0 +1,72 @@ +/* + * include/ds/Alias.h + * + * Copyright (C) 2023 Douglas Rumbaugh + * Dong Xie + * + * All rights reserved. Published under the Modified BSD License. + * + */ +#pragma once + +#include +#include + +namespace de { + +/* + * An implementation of Walker's Alias Structure for Weighted Set Sampling. Requires + * that the input weight vector is already normalized. + */ +class Alias { +public: + Alias(const std::vector& weights) + : m_alias(weights.size()), m_cutoff(weights.size()) { + size_t n = weights.size(); + auto overfull = std::vector(); + auto underfull = std::vector(); + overfull.reserve(n); + underfull.reserve(n); + + // initialize the probability_table with n*p(i) as well as the overfull and + // underfull lists. + for (size_t i = 0; i < n; i++) { + m_cutoff[i] = (double) n * weights[i]; + if (m_cutoff[i] > 1) { + overfull.emplace_back(i); + } else if (m_cutoff[i] < 1) { + underfull.emplace_back(i); + } else { + m_alias[i] = i; + } + } + + while (overfull.size() > 0 && underfull.size() > 0) { + auto i = overfull.back(); overfull.pop_back(); + auto j = underfull.back(); underfull.pop_back(); + + m_alias[j] = i; + m_cutoff[i] = m_cutoff[i] + m_cutoff[j] - 1.0; + + if (m_cutoff[i] > 1.0) { + overfull.emplace_back(i); + } else if (m_cutoff[i] < 1.0) { + underfull.emplace_back(i); + } + } + } + + size_t get(const gsl_rng* rng) { + double coin1 = gsl_rng_uniform(rng); + double coin2 = gsl_rng_uniform(rng); + + size_t k = ((double) m_alias.size()) * coin1; + return coin2 < m_cutoff[k] ? k : m_alias[k]; + } + +private: + std::vector m_alias; + std::vector m_cutoff; +}; + +} -- cgit v1.2.3