From a2fe4b1616a1b2318f70e842382818ee44aea9e6 Mon Sep 17 00:00:00 2001 From: Douglas Rumbaugh Date: Tue, 7 Nov 2023 12:29:03 -0500 Subject: Alias shard fixes --- CMakeLists.txt | 6 +++--- include/query/wss.h | 28 ++++++++++++++-------------- include/shard/Alias.h | 13 +++++++++++-- tests/alias_tests.cpp | 8 ++++---- 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c8188a9..13f24d9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,9 +69,9 @@ if (tests) target_link_libraries(memisam_tests PUBLIC gsl check subunit pthread) target_include_directories(memisam_tests PRIVATE include external/psudb-common/cpp/include) - add_executable(wss_tests ${CMAKE_CURRENT_SOURCE_DIR}/tests/wss_tests.cpp) - target_link_libraries(wss_tests PUBLIC gsl check subunit pthread) - target_include_directories(wss_tests PRIVATE include external/psudb-common/cpp/include) + add_executable(alias_tests ${CMAKE_CURRENT_SOURCE_DIR}/tests/alias_tests.cpp) + target_link_libraries(alias_tests PUBLIC gsl check subunit pthread) + target_include_directories(alias_tests PRIVATE include external/psudb-common/cpp/include) add_executable(triespline_tests ${CMAKE_CURRENT_SOURCE_DIR}/tests/triespline_tests.cpp) target_link_libraries(triespline_tests PUBLIC gsl check subunit pthread) diff --git a/include/query/wss.h b/include/query/wss.h index b8a5d54..794485c 100644 --- a/include/query/wss.h +++ b/include/query/wss.h @@ -90,15 +90,19 @@ public: static void process_query_states(void *query_parms, std::vector &shard_states, std::vector &buffer_states) { auto p = (Parms *) query_parms; - auto bs = (BufferState *) buffer_states[0]; - std::vector shard_sample_sizes(shard_states.size()+1, 0); + std::vector shard_sample_sizes(shard_states.size()+buffer_states.size(), 0); size_t buffer_sz = 0; std::vector weights; - weights.push_back(bs->total_weight); decltype(R::weight) total_weight = 0; + for (auto &s : buffer_states) { + auto bs = (BufferState *) s; + total_weight += bs->total_weight; + weights.push_back(bs->total_weight); + } + for (auto &s : shard_states) { auto state = (State *) s; total_weight += state->total_weight; @@ -113,19 +117,15 @@ public: auto shard_alias = psudb::Alias(normalized_weights); for (size_t i=0; isample_size; i++) { auto idx = shard_alias.get(p->rng); - if (idx == 0) { - buffer_sz++; + + if (idx < buffer_states.size()) { + auto state = (BufferState *) buffer_states[idx]; + state->sample_size++; } else { - shard_sample_sizes[idx - 1]++; + auto state = (State *) shard_states[idx - buffer_states.size()]; + state->sample_size++; } } - - - bs->sample_size = buffer_sz; - for (size_t i=0; i *) shard_states[i]; - state->sample_size = shard_sample_sizes[i+1]; - } } static std::vector> query(S *shard, void *q_state, void *parms) { @@ -142,7 +142,7 @@ public: size_t attempts = 0; do { attempts++; - size_t idx = shard->m_alias->get(rng); + size_t idx = shard->get_weighted_sample(rng); result_set.emplace_back(*shard->get_record_at(idx)); } while (attempts < sample_size); diff --git a/include/shard/Alias.h b/include/shard/Alias.h index b6b16c5..a4a7d02 100644 --- a/include/shard/Alias.h +++ b/include/shard/Alias.h @@ -19,7 +19,7 @@ #include "psu-ds/PriorityQueue.h" #include "util/Cursor.h" -#include "psu-ds/psudb::Alias.h" +#include "psu-ds/Alias.h" #include "psu-ds/BloomFilter.h" #include "util/bf_config.h" @@ -207,7 +207,13 @@ public: return 0; } -private: + W get_total_weight() { + return m_total_weight; + } + + size_t get_weighted_sample(gsl_rng *rng) const { + return m_alias->get(rng); + } size_t get_lower_bound(const K& key) const { size_t min = 0; @@ -227,6 +233,8 @@ private: return min; } +private: + void build_alias_structure(std::vector &weights) { // normalize the weights vector @@ -249,3 +257,4 @@ private: size_t m_alloc_size; BloomFilter *m_bf; }; +} diff --git a/tests/alias_tests.cpp b/tests/alias_tests.cpp index b9e678b..c4a302d 100644 --- a/tests/alias_tests.cpp +++ b/tests/alias_tests.cpp @@ -180,7 +180,7 @@ START_TEST(t_alias_query) size_t k = 1000; size_t cnt[3] = {0}; - wss:Parms parms = {k}; + wss::Parms parms = {k}; parms.rng = gsl_rng_alloc(gsl_rng_mt19937); size_t total_samples = 0; @@ -223,7 +223,7 @@ START_TEST(t_alias_query_merge) size_t k = 1000; size_t cnt[3] = {0}; - wss:Parms parms = {k}; + wss::Parms parms = {k}; parms.rng = gsl_rng_alloc(gsl_rng_mt19937); std::vector>> results(2); @@ -267,7 +267,7 @@ START_TEST(t_alias_buffer_query_scan) size_t k = 1000; size_t cnt[3] = {0}; - wss:Parms parms = {k}; + wss::Parms parms = {k}; parms.rng = gsl_rng_alloc(gsl_rng_mt19937); size_t total_samples = 0; @@ -306,7 +306,7 @@ START_TEST(t_alias_buffer_query_rejection) size_t k = 1000; size_t cnt[3] = {0}; - wss:Parms parms = {k}; + wss::Parms parms = {k}; parms.rng = gsl_rng_alloc(gsl_rng_mt19937); size_t total_samples = 0; -- cgit v1.2.3