From e932e12dad4c9d91cd0b9123f04a6ce14f051815 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Thu, 24 Aug 2023 17:42:53 +0200 Subject: [PATCH 01/18] Add template for TestingCriteria to be apply for LocationType/Location --- cpp/examples/abm_history_object.cpp | 16 +-- cpp/examples/abm_minimal.cpp | 16 +-- cpp/models/abm/testing_strategy.cpp | 106 +------------- cpp/models/abm/testing_strategy.h | 134 ++++++++++++++---- cpp/simulations/abm.cpp | 16 +-- cpp/tests/test_abm_testing_strategy.cpp | 30 ++-- cpp/tests/test_abm_world.cpp | 5 +- .../memilio/simulation/abm.cpp | 2 +- 8 files changed, 156 insertions(+), 169 deletions(-) diff --git a/cpp/examples/abm_history_object.cpp b/cpp/examples/abm_history_object.cpp index eebb701b62..7b763af899 100644 --- a/cpp/examples/abm_history_object.cpp +++ b/cpp/examples/abm_history_object.cpp @@ -116,14 +116,14 @@ int main() world.get_individualized_location(work).get_infection_parameters().set(10); // People can get tested at work (and do this with 0.5 probability) from time point 0 to day 30. - auto testing_min_time = mio::abm::days(1); - auto probability = 0.5; - auto start_date = mio::abm::TimePoint(0); - auto end_date = mio::abm::TimePoint(0) + mio::abm::days(30); - auto test_type = mio::abm::AntigenTest(); - auto test_at_work = std::vector{mio::abm::LocationType::Work}; - auto testing_criteria_work = - std::vector{mio::abm::TestingCriteria({}, test_at_work, {})}; + auto testing_min_time = mio::abm::days(1); + auto probability = 0.5; + auto start_date = mio::abm::TimePoint(0); + auto end_date = mio::abm::TimePoint(0) + mio::abm::days(30); + auto test_type = mio::abm::AntigenTest(); + auto test_at_work = std::vector{mio::abm::LocationType::Work}; + auto testing_criteria_work = std::vector>{ + mio::abm::TestingCriteria({}, test_at_work, {})}; auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, testing_min_time, start_date, end_date, test_type, probability); world.get_testing_strategy().add_testing_scheme(testing_scheme_work); diff --git a/cpp/examples/abm_minimal.cpp b/cpp/examples/abm_minimal.cpp index a0fe88067e..a075c38f45 100644 --- a/cpp/examples/abm_minimal.cpp +++ b/cpp/examples/abm_minimal.cpp @@ -111,14 +111,14 @@ int main() world.get_individualized_location(work).get_infection_parameters().set(10); // People can get tested at work (and do this with 0.5 probability) from time point 0 to day 30. - auto testing_min_time = mio::abm::days(1); - auto probability = 0.5; - auto start_date = mio::abm::TimePoint(0); - auto end_date = mio::abm::TimePoint(0) + mio::abm::days(30); - auto test_type = mio::abm::AntigenTest(); - auto test_at_work = std::vector{mio::abm::LocationType::Work}; - auto testing_criteria_work = - std::vector{mio::abm::TestingCriteria({}, test_at_work, {})}; + auto testing_min_time = mio::abm::days(1); + auto probability = 0.5; + auto start_date = mio::abm::TimePoint(0); + auto end_date = mio::abm::TimePoint(0) + mio::abm::days(30); + auto test_type = mio::abm::AntigenTest(); + auto test_at_work = std::vector{mio::abm::LocationType::Work}; + auto testing_criteria_work = std::vector>{ + mio::abm::TestingCriteria({}, test_at_work, {})}; auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, testing_min_time, start_date, end_date, test_type, probability); world.get_testing_strategy().add_testing_scheme(testing_scheme_work); diff --git a/cpp/models/abm/testing_strategy.cpp b/cpp/models/abm/testing_strategy.cpp index e6360a7958..18cdc8b6d6 100644 --- a/cpp/models/abm/testing_strategy.cpp +++ b/cpp/models/abm/testing_strategy.cpp @@ -1,8 +1,7 @@ /* -* Copyright (C) 2020-2021 German Aerospace Center (DLR-SC) -* & Helmholtz Centre for Infection Research (HZI) +* Copyright (C) 2020-2023 German Aerospace Center (DLR-SC) * -* Authors: Elisabeth Kluth, David Kerkmann, Sascha Korf, Martin J. Kuehn +* Authors: Elisabeth Kluth, David Kerkmann, Sascha Korf, Martin J. Kuehn, Khoa Nguyen * * Contact: Martin J. Kuehn * @@ -27,100 +26,7 @@ namespace mio namespace abm { -TestingCriteria::TestingCriteria(const std::vector& ages, const std::vector& location_types, - const std::vector& infection_states) - : m_ages(ages) - , m_location_types(location_types) - , m_infection_states(infection_states) -{ -} - -bool TestingCriteria::operator==(TestingCriteria other) const -{ - auto to_compare_ages = this->m_ages; - auto to_compare_infection_states = this->m_infection_states; - auto to_compare_location_types = this->m_location_types; - - std::sort(to_compare_ages.begin(), to_compare_ages.end()); - std::sort(other.m_ages.begin(), other.m_ages.end()); - std::sort(to_compare_infection_states.begin(), to_compare_infection_states.end()); - std::sort(other.m_infection_states.begin(), other.m_infection_states.end()); - std::sort(to_compare_location_types.begin(), to_compare_location_types.end()); - std::sort(other.m_location_types.begin(), other.m_location_types.end()); - - return to_compare_ages == other.m_ages && to_compare_location_types == other.m_location_types && - to_compare_infection_states == other.m_infection_states; -} - -void TestingCriteria::add_age_group(const AgeGroup age_group) -{ - if (std::find(m_ages.begin(), m_ages.end(), age_group) == m_ages.end()) { - m_ages.push_back(age_group); - } -} - -void TestingCriteria::remove_age_group(const AgeGroup age_group) -{ - auto last = std::remove(m_ages.begin(), m_ages.end(), age_group); - m_ages.erase(last, m_ages.end()); -} - -void TestingCriteria::add_location_type(const LocationType location_type) -{ - if (std::find(m_location_types.begin(), m_location_types.end(), location_type) == m_location_types.end()) { - m_location_types.push_back(location_type); - } -} -void TestingCriteria::remove_location_type(const LocationType location_type) -{ - auto last = std::remove(m_location_types.begin(), m_location_types.end(), location_type); - m_location_types.erase(last, m_location_types.end()); -} - -void TestingCriteria::add_infection_state(const InfectionState infection_state) -{ - if (std::find(m_infection_states.begin(), m_infection_states.end(), infection_state) == m_infection_states.end()) { - m_infection_states.push_back(infection_state); - } -} - -void TestingCriteria::remove_infection_state(const InfectionState infection_state) -{ - auto last = std::remove(m_infection_states.begin(), m_infection_states.end(), infection_state); - m_infection_states.erase(last, m_infection_states.end()); -} - -bool TestingCriteria::evaluate(const Person& p, const Location& l, TimePoint t) const -{ - return has_requested_age(p) && is_requested_location_type(l) && has_requested_infection_state(p, t); -} - -bool TestingCriteria::has_requested_age(const Person& p) const -{ - if (m_ages.empty()) { - return true; // no condition on the age - } - return std::find(m_ages.begin(), m_ages.end(), p.get_age()) != m_ages.end(); -} - -bool TestingCriteria::is_requested_location_type(const Location& l) const -{ - if (m_location_types.empty()) { - return true; // no condition on the location - } - return std::find(m_location_types.begin(), m_location_types.end(), l.get_type()) != m_location_types.end(); -} - -bool TestingCriteria::has_requested_infection_state(const Person& p, TimePoint t) const -{ - if (m_infection_states.empty()) { - return true; // no condition on infection state - } - return std::find(m_infection_states.begin(), m_infection_states.end(), p.get_infection_state(t)) != - m_infection_states.end(); -} - -TestingScheme::TestingScheme(const std::vector& testing_criteria, +TestingScheme::TestingScheme(const std::vector>& testing_criteria, TimeSpan minimal_time_since_last_test, TimePoint start_date, TimePoint end_date, const GenericTest& test_type, double probability) : m_testing_criteria(testing_criteria) @@ -143,14 +49,14 @@ bool TestingScheme::operator==(const TestingScheme& other) const //To be adjusted and also TestType should be static. } -void TestingScheme::add_testing_criteria(const TestingCriteria criteria) +void TestingScheme::add_testing_criteria(const TestingCriteria criteria) { if (std::find(m_testing_criteria.begin(), m_testing_criteria.end(), criteria) == m_testing_criteria.end()) { m_testing_criteria.push_back(criteria); } } -void TestingScheme::remove_testing_criteria(const TestingCriteria criteria) +void TestingScheme::remove_testing_criteria(const TestingCriteria criteria) { auto last = std::remove(m_testing_criteria.begin(), m_testing_criteria.end(), criteria); m_testing_criteria.erase(last, m_testing_criteria.end()); @@ -171,7 +77,7 @@ bool TestingScheme::run_scheme(Person& person, const Location& location, TimePoi double random = UniformDistribution::get_instance()(); if (random < m_probability) { if (std::any_of(m_testing_criteria.begin(), m_testing_criteria.end(), - [person, location, t](TestingCriteria tr) { + [person, location, t](TestingCriteria tr) { return tr.evaluate(person, location, t); })) { return !person.get_tested(t, m_test_type.get_default()); diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index 6ebb67f4b0..5a7cc140e8 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -1,8 +1,7 @@ /* -* Copyright (C) 2020-2021 German Aerospace Center (DLR-SC) -* & Helmholtz Centre for Infection Research (HZI) +* Copyright (C) 2020-2023 German Aerospace Center (DLR-SC) * -* Authors: Elisabeth Kluth, David Kerkmann, Sascha Korf, Martin J. Kuehn +* Authors: Elisabeth Kluth, David Kerkmann, Sascha Korf, Martin J. Kuehn, Khoa Nguyen * * Contact: Martin J. Kuehn * @@ -34,62 +33,120 @@ namespace abm /** * @brief TestingCriteria for TestingScheme. */ +template class TestingCriteria { public: /** * @brief Create a TestingCriteria. * @param[in] ages Vector of AgeGroup%s that are either allowed or required to be tested. - * @param[in] location_types Vector of #LocationType%s that are either allowed or required to be tested. + * @param[in] locations Vector of #Location%s or #LocationType%s that are either allowed or required to be tested. * @param[in] infection_states Vector of #InfectionState%s that are either allowed or required to be tested. * An empty vector of ages/#LocationType%s/#InfectionStates% means that no condition on the corresponding property * is set! */ TestingCriteria() = default; - TestingCriteria(const std::vector& ages, const std::vector& location_types, - const std::vector& infection_states); + + TestingCriteria(const std::vector& ages, const std::vector& locations, + const std::vector& infection_states) + : m_ages(ages) + , m_locations(locations) + , m_infection_states(infection_states) + { + } + + TestingCriteria(const std::vector& ages, const std::vector& locations, + const std::vector& infection_states) + : m_ages(ages) + , m_locations(locations) + , m_infection_states(infection_states) + { + } /** * @brief Compares two TestingCriteria for functional equality. */ - bool operator==(TestingCriteria other) const; + bool operator==(TestingCriteria other) const + { + auto to_compare_ages = this->m_ages; + auto to_compare_infection_states = this->m_infection_states; + auto to_compare_locations = this->m_locations; + + std::sort(to_compare_ages.begin(), to_compare_ages.end()); + std::sort(other.m_ages.begin(), other.m_ages.end()); + std::sort(to_compare_infection_states.begin(), to_compare_infection_states.end()); + std::sort(other.m_infection_states.begin(), other.m_infection_states.end()); + std::sort(to_compare_locations.begin(), to_compare_locations.end()); + std::sort(other.m_locations.begin(), other.m_locations.end()); + + return to_compare_ages == other.m_ages && to_compare_locations == other.m_locations && + to_compare_infection_states == other.m_infection_states; + } /** * @brief Add an AgeGroup to the set of AgeGroup%s that are either allowed or required to be tested. * @param[in] age_group AgeGroup to be added. */ - void add_age_group(const AgeGroup age_group); + void add_age_group(const AgeGroup age_group) + { + if (std::find(m_ages.begin(), m_ages.end(), age_group) == m_ages.end()) { + m_ages.push_back(age_group); + } + } /** * @brief Remove an AgeGroup from the set of AgeGroup%s that are either allowed or required to be tested. * @param[in] age_group AgeGroup to be removed. */ - void remove_age_group(const AgeGroup age_group); + void remove_age_group(const AgeGroup age_group) + { + auto last = std::remove(m_ages.begin(), m_ages.end(), age_group); + m_ages.erase(last, m_ages.end()); + } /** - * @brief Add a #LocationType to the set of #LocationType%s that are either allowed or required to be tested. - * @param[in] location_type #LocationType to be added. + * @brief Add a #Location or #LocationType to the set of #LocationType%s that are either allowed or required to be tested. + * @param[in] location_type #Location%s or #LocationType to be added. */ - void add_location_type(const LocationType location_type); + void add_location(const L location) + { + if (std::find(m_locations.begin(), m_locations.end(), location) == m_locations.end()) { + m_locations.push_back(location); + } + } /** - * @brief Remove a #LocationType from the set of #LocationType%s that are either allowed or required to be tested. - * @param[in] location_type #LocationType to be removed. + * @brief Remove a #Location or #LocationType from the set of #LocationType%s that are either allowed or required to be tested. + * @param[in] location_type #Location or #LocationType to be removed. */ - void remove_location_type(const LocationType location_type); + void remove_location(const L location) + { + auto last = std::remove(m_locations.begin(), m_locations.end(), location); + m_locations.erase(last, m_locations.end()); + } /** * @brief Add an #InfectionState to the set of #InfectionState%s that are either allowed or required to be tested. * @param[in] infection_state #InfectionState to be added. */ - void add_infection_state(const InfectionState infection_state); + void add_infection_state(const InfectionState infection_state) + { + if (std::find(m_infection_states.begin(), m_infection_states.end(), infection_state) == + m_infection_states.end()) { + m_infection_states.push_back(infection_state); + } + } /** * @brief Remove an #InfectionState from the set of #InfectionState%s that are either allowed or required to be * tested. * @param[in] infection_state #InfectionState to be removed. */ - void remove_infection_state(const InfectionState infection_state); + void remove_infection_state(const InfectionState infection_state) + { + auto last = std::remove(m_infection_states.begin(), m_infection_states.end(), infection_state); + m_infection_states.erase(last, m_infection_states.end()); + } /** * @brief Check if a Person and a Location meet all the required properties to get tested. @@ -97,30 +154,52 @@ class TestingCriteria * @param[in] l Location to be checked. * @param[in] t TimePoint when to evaluate the TestingCriteria. */ - bool evaluate(const Person& p, const Location& l, TimePoint t) const; + bool evaluate(const Person& p, const Location& l, TimePoint t) const + { + return has_requested_age(p) && is_requested_location_type(l) && has_requested_infection_state(p, t); + } private: /** * @brief Check if a Person has the required age to get tested. * @param[in] p Person to be checked. */ - bool has_requested_age(const Person& p) const; + bool has_requested_age(const Person& p) const + { + if (m_ages.empty()) { + return true; // no condition on the age + } + return std::find(m_ages.begin(), m_ages.end(), p.get_age()) != m_ages.end(); + } /** * @brief Check if a Location is in the set of Location%s that are allowed for testing. * @param[in] l Location to be checked. */ - bool is_requested_location_type(const Location& l) const; + bool is_requested_location_type(const Location& l) const + { + if (m_locations.empty()) { + return true; // no condition on the location + } + return std::find(m_locations.begin(), m_locations.end(), l.get_type()) != m_locations.end(); + } /** * @brief Check if a Person has the required InfectionState to get tested. * @param[in] p Person to be checked. * @param[in] t TimePoint when to check. */ - bool has_requested_infection_state(const Person& p, TimePoint t) const; + bool has_requested_infection_state(const Person& p, TimePoint t) const + { + if (m_infection_states.empty()) { + return true; // no condition on infection state + } + return std::find(m_infection_states.begin(), m_infection_states.end(), p.get_infection_state(t)) != + m_infection_states.end(); + } std::vector m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested. - std::vector m_location_types; /**< Set of #LocationState%s that are either allowed or required to be + std::vector m_locations; /**< Set of #Location%s or #LocationState%s that are either allowed or required to be tested.*/ std::vector m_infection_states; /**< Set of #InfectionState%s that are either allowed or required to be tested.*/ @@ -142,8 +221,9 @@ class TestingScheme * @param test_type The type of test to be performed. * @param probability Probability of the test to be performed if a testing rule applies. */ - TestingScheme(const std::vector& testing_criteria, TimeSpan minimal_time_since_last_test, - TimePoint start_date, TimePoint end_date, const GenericTest& test_type, ScalarType probability); + TestingScheme(const std::vector>& testing_criteria, + TimeSpan minimal_time_since_last_test, TimePoint start_date, TimePoint end_date, + const GenericTest& test_type, ScalarType probability); /** * @brief Compares two TestingScheme%s for functional equality. @@ -154,13 +234,13 @@ class TestingScheme * @brief Add a TestingCriteria to the set of TestingCriteria that are checked for testing. * @param[in] criteria TestingCriteria to be added. */ - void add_testing_criteria(const TestingCriteria criteria); + void add_testing_criteria(const TestingCriteria criteria); /** * @brief Remove a TestingCriteria from the set of TestingCriteria that are checked for testing. * @param[in] criteria TestingCriteria to be removed. */ - void remove_testing_criteria(const TestingCriteria criteria); + void remove_testing_criteria(const TestingCriteria criteria); /** * @brief Get the activity status of the scheme. @@ -184,7 +264,7 @@ class TestingScheme bool run_scheme(Person& person, const Location& location, TimePoint t) const; private: - std::vector m_testing_criteria; ///< Vector with all TestingCriteria of the scheme. + std::vector> m_testing_criteria; ///< Vector with all TestingCriteria of the scheme. TimeSpan m_minimal_time_since_last_test; ///< Shortest period of time between two tests. TimePoint m_start_date; ///< Starting date of the scheme. TimePoint m_end_date; ///< Ending date of the scheme. diff --git a/cpp/simulations/abm.cpp b/cpp/simulations/abm.cpp index fee91e6616..8e23c703ce 100644 --- a/cpp/simulations/abm.cpp +++ b/cpp/simulations/abm.cpp @@ -309,8 +309,8 @@ void create_assign_locations(mio::abm::World& world) world.get_individualized_location(event).set_capacity(100, 375); std::vector test_at_social_event = {mio::abm::LocationType::SocialEvent}; - auto testing_criteria = - std::vector{mio::abm::TestingCriteria({}, test_at_social_event, {})}; + auto testing_criteria = std::vector>{ + mio::abm::TestingCriteria({}, test_at_social_event, {})}; auto testing_min_time = mio::abm::days(2); auto start_date = mio::abm::TimePoint(0); auto end_date = mio::abm::TimePoint(0) + mio::abm::days(60); @@ -411,18 +411,18 @@ void create_assign_locations(mio::abm::World& world) } // add the testing schemes for school and work - auto test_at_school = std::vector{mio::abm::LocationType::School}; - auto testing_criteria_school = - std::vector{mio::abm::TestingCriteria({}, test_at_school, {})}; + auto test_at_school = std::vector{mio::abm::LocationType::School}; + auto testing_criteria_school = std::vector>{ + mio::abm::TestingCriteria({}, test_at_school, {})}; testing_min_time = mio::abm::days(7); auto testing_scheme_school = mio::abm::TestingScheme(testing_criteria_school, testing_min_time, start_date, end_date, test_type, probability.draw_sample()); world.get_testing_strategy().add_testing_scheme(testing_scheme_school); - auto test_at_work = std::vector{mio::abm::LocationType::Work}; - auto testing_criteria_work = - std::vector{mio::abm::TestingCriteria({}, test_at_work, {})}; + auto test_at_work = std::vector{mio::abm::LocationType::Work}; + auto testing_criteria_work = std::vector>{ + mio::abm::TestingCriteria({}, test_at_work, {})}; assign_uniform_distribution(probability, 0.1, 0.5); testing_min_time = mio::abm::days(1); diff --git a/cpp/tests/test_abm_testing_strategy.cpp b/cpp/tests/test_abm_testing_strategy.cpp index 6a83fd1a37..b3293cca6e 100644 --- a/cpp/tests/test_abm_testing_strategy.cpp +++ b/cpp/tests/test_abm_testing_strategy.cpp @@ -26,12 +26,12 @@ TEST(TestTestingCriteria, addRemoveAndEvaluateTestCriteria) auto person = make_test_person(home, mio::abm::AgeGroup::Age15to34, mio::abm::InfectionState::InfectedSymptoms); mio::abm::TimePoint t{0}; - auto testing_criteria = mio::abm::TestingCriteria(); + auto testing_criteria = mio::abm::TestingCriteria(); ASSERT_EQ(testing_criteria.evaluate(person, work, t), true); testing_criteria.add_infection_state(mio::abm::InfectionState::InfectedSymptoms); testing_criteria.add_infection_state(mio::abm::InfectionState::InfectedNoSymptoms); - testing_criteria.add_location_type(mio::abm::LocationType::Home); - testing_criteria.add_location_type(mio::abm::LocationType::Work); + testing_criteria.add_location(mio::abm::LocationType::Home); + testing_criteria.add_location(mio::abm::LocationType::Work); ASSERT_EQ(testing_criteria.evaluate(person, work, t), true); ASSERT_EQ(testing_criteria.evaluate(person, home, t), true); @@ -46,13 +46,13 @@ TEST(TestTestingCriteria, addRemoveAndEvaluateTestCriteria) ASSERT_EQ(testing_criteria.evaluate(person, home, t), false); testing_criteria.add_infection_state(mio::abm::InfectionState::InfectedSymptoms); - testing_criteria.remove_location_type(mio::abm::LocationType::Home); + testing_criteria.remove_location(mio::abm::LocationType::Home); ASSERT_EQ(testing_criteria.evaluate(person, home, t), false); - auto testing_criteria_manual = - mio::abm::TestingCriteria({}, std::vector({mio::abm::LocationType::Work}), - std::vector({mio::abm::InfectionState::InfectedNoSymptoms, - mio::abm::InfectionState::InfectedSymptoms})); + auto testing_criteria_manual = mio::abm::TestingCriteria( + {}, std::vector({mio::abm::LocationType::Work}), + std::vector( + {mio::abm::InfectionState::InfectedNoSymptoms, mio::abm::InfectionState::InfectedSymptoms})); ASSERT_EQ(testing_criteria == testing_criteria_manual, true); testing_criteria_manual.remove_infection_state(mio::abm::InfectionState::InfectedSymptoms); ASSERT_EQ(testing_criteria == testing_criteria_manual, false); @@ -65,8 +65,9 @@ TEST(TestTestingScheme, runScheme) std::vector test_location_types1 = {mio::abm::LocationType::Home, mio::abm::LocationType::Work}; - auto testing_criteria1 = mio::abm::TestingCriteria({}, test_location_types1, test_infection_states1); - std::vector testing_criterias = {testing_criteria1}; + auto testing_criteria1 = + mio::abm::TestingCriteria({}, test_location_types1, test_infection_states1); + std::vector> testing_criterias = {testing_criteria1}; const auto testing_min_time = mio::abm::days(1); const auto start_date = mio::abm::TimePoint(0); @@ -86,14 +87,15 @@ TEST(TestTestingScheme, runScheme) std::vector test_infection_states2 = {mio::abm::InfectionState::Recovered}; std::vector test_location_types2 = {mio::abm::LocationType::Home}; - auto testing_criteria2 = mio::abm::TestingCriteria({}, test_location_types2, test_infection_states2); + auto testing_criteria2 = + mio::abm::TestingCriteria({}, test_location_types2, test_infection_states2); testing_scheme.add_testing_criteria(testing_criteria2); auto loc_home = mio::abm::Location(mio::abm::LocationType::Home, 0); auto loc_work = mio::abm::Location(mio::abm::LocationType::Work, 0); - auto person1 = make_test_person(loc_home, mio::abm::AgeGroup::Age15to34, mio::abm::InfectionState::InfectedNoSymptoms); - auto person2 = - make_test_person(loc_home, mio::abm::AgeGroup::Age15to34, mio::abm::InfectionState::Recovered); + auto person1 = + make_test_person(loc_home, mio::abm::AgeGroup::Age15to34, mio::abm::InfectionState::InfectedNoSymptoms); + auto person2 = make_test_person(loc_home, mio::abm::AgeGroup::Age15to34, mio::abm::InfectionState::Recovered); ScopedMockDistribution>>> mock_uniform_dist; EXPECT_CALL(mock_uniform_dist.get_mock(), invoke) diff --git a/cpp/tests/test_abm_world.cpp b/cpp/tests/test_abm_world.cpp index f0e00ed235..bab8d577a1 100644 --- a/cpp/tests/test_abm_world.cpp +++ b/cpp/tests/test_abm_world.cpp @@ -369,11 +369,10 @@ TEST(TestWorldTestingCriteria, testAddingAndUpdatingAndRunningTestingSchemes) person.set_assigned_location(home); person.set_assigned_location(work); - auto testing_criteria = mio::abm::TestingCriteria({}, {}, {}); + auto testing_criteria = mio::abm::TestingCriteria({}, {mio::abm::LocationType::Home}, {}); testing_criteria.add_infection_state(mio::abm::InfectionState::InfectedSymptoms); testing_criteria.add_infection_state(mio::abm::InfectionState::InfectedNoSymptoms); - testing_criteria.add_location_type(mio::abm::LocationType::Home); - testing_criteria.add_location_type(mio::abm::LocationType::Work); + testing_criteria.add_location(mio::abm::LocationType::Work); const auto testing_frequency = mio::abm::days(1); const auto start_date = mio::abm::TimePoint(20); diff --git a/pycode/memilio-simulation/memilio/simulation/abm.cpp b/pycode/memilio-simulation/memilio/simulation/abm.cpp index ee7c323b53..6324cc71a5 100644 --- a/pycode/memilio-simulation/memilio/simulation/abm.cpp +++ b/pycode/memilio-simulation/memilio/simulation/abm.cpp @@ -152,7 +152,7 @@ PYBIND11_MODULE(_simulation_abm, m) py::class_(m, "PCRTest").def(py::init<>()); py::class_(m, "TestingScheme") - .def(py::init&, mio::abm::TimeSpan, mio::abm::TimePoint, + .def(py::init>&, mio::abm::TimeSpan, mio::abm::TimePoint, mio::abm::TimePoint, const mio::abm::GenericTest&, double>(), py::arg("testing_criteria"), py::arg("testing_min_time_since_last_test"), py::arg("start_date"), py::arg("end_date"), py::arg("test_type"), py::arg("probability")) From d0839040806414a05a5f066da8c3e6d269b241f8 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Sat, 26 Aug 2023 13:34:06 +0200 Subject: [PATCH 02/18] Fix error with python binding --- pycode/memilio-simulation/memilio/simulation/abm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pycode/memilio-simulation/memilio/simulation/abm.cpp b/pycode/memilio-simulation/memilio/simulation/abm.cpp index 6324cc71a5..a1df864215 100644 --- a/pycode/memilio-simulation/memilio/simulation/abm.cpp +++ b/pycode/memilio-simulation/memilio/simulation/abm.cpp @@ -142,7 +142,7 @@ PYBIND11_MODULE(_simulation_abm, m) .def_property_readonly("age", &mio::abm::Person::get_age) .def_property_readonly("is_in_quarantine", &mio::abm::Person::is_in_quarantine); - py::class_(m, "TestingCriteria") + py::class_>(m, "TestingCriteria") .def(py::init&, const std::vector&, const std::vector&>(), py::arg("age_groups"), py::arg("location_types"), py::arg("infection_states")); From 9f4da51a7041692ca1b9d982438fc1dc6d3bc579 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Thu, 7 Sep 2023 12:22:31 +0200 Subject: [PATCH 03/18] Change types of ages, infection states in TestingCriteria to bitset --- cpp/models/abm/testing_strategy.h | 50 +++++++++++-------------- cpp/tests/test_abm_testing_strategy.cpp | 20 +++++----- 2 files changed, 33 insertions(+), 37 deletions(-) diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index 5a7cc140e8..c16d8b3d2b 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -24,6 +24,7 @@ #include "abm/person.h" #include "abm/location.h" #include "abm/time.h" +#include namespace mio { @@ -47,16 +48,16 @@ class TestingCriteria */ TestingCriteria() = default; - TestingCriteria(const std::vector& ages, const std::vector& locations, - const std::vector& infection_states) + TestingCriteria(const std::bitset<(size_t)AgeGroup::Count>& ages, const std::vector& locations, + const std::bitset<(size_t)InfectionState::Count>& infection_states) : m_ages(ages) , m_locations(locations) , m_infection_states(infection_states) { } - TestingCriteria(const std::vector& ages, const std::vector& locations, - const std::vector& infection_states) + TestingCriteria(const std::bitset<(size_t)AgeGroup::Count>& ages, const std::vector& locations, + const std::bitset<(size_t)InfectionState::Count>& infection_states) : m_ages(ages) , m_locations(locations) , m_infection_states(infection_states) @@ -72,10 +73,6 @@ class TestingCriteria auto to_compare_infection_states = this->m_infection_states; auto to_compare_locations = this->m_locations; - std::sort(to_compare_ages.begin(), to_compare_ages.end()); - std::sort(other.m_ages.begin(), other.m_ages.end()); - std::sort(to_compare_infection_states.begin(), to_compare_infection_states.end()); - std::sort(other.m_infection_states.begin(), other.m_infection_states.end()); std::sort(to_compare_locations.begin(), to_compare_locations.end()); std::sort(other.m_locations.begin(), other.m_locations.end()); @@ -89,9 +86,7 @@ class TestingCriteria */ void add_age_group(const AgeGroup age_group) { - if (std::find(m_ages.begin(), m_ages.end(), age_group) == m_ages.end()) { - m_ages.push_back(age_group); - } + m_ages.set((size_t)age_group, true); } /** @@ -100,8 +95,7 @@ class TestingCriteria */ void remove_age_group(const AgeGroup age_group) { - auto last = std::remove(m_ages.begin(), m_ages.end(), age_group); - m_ages.erase(last, m_ages.end()); + m_ages.set((size_t)age_group, false); } /** @@ -131,10 +125,7 @@ class TestingCriteria */ void add_infection_state(const InfectionState infection_state) { - if (std::find(m_infection_states.begin(), m_infection_states.end(), infection_state) == - m_infection_states.end()) { - m_infection_states.push_back(infection_state); - } + m_infection_states.set((size_t)infection_state, true); } /** @@ -144,8 +135,7 @@ class TestingCriteria */ void remove_infection_state(const InfectionState infection_state) { - auto last = std::remove(m_infection_states.begin(), m_infection_states.end(), infection_state); - m_infection_states.erase(last, m_infection_states.end()); + m_infection_states.set((size_t)infection_state, false); } /** @@ -166,10 +156,11 @@ class TestingCriteria */ bool has_requested_age(const Person& p) const { - if (m_ages.empty()) { - return true; // no condition on the age + if (m_ages.none()) { + std::cout << "Here" << '\n'; + return true; // no condition on the AgeGroup } - return std::find(m_ages.begin(), m_ages.end(), p.get_age()) != m_ages.end(); + return m_ages[(size_t)p.get_age()]; } /** @@ -179,6 +170,7 @@ class TestingCriteria bool is_requested_location_type(const Location& l) const { if (m_locations.empty()) { + std::cout << "Here" << '\n'; return true; // no condition on the location } return std::find(m_locations.begin(), m_locations.end(), l.get_type()) != m_locations.end(); @@ -191,17 +183,19 @@ class TestingCriteria */ bool has_requested_infection_state(const Person& p, TimePoint t) const { - if (m_infection_states.empty()) { - return true; // no condition on infection state + if (m_infection_states.none()) { + std::cout << "Here" << '\n'; + return true; // no condition on the InfectionState } - return std::find(m_infection_states.begin(), m_infection_states.end(), p.get_infection_state(t)) != - m_infection_states.end(); + return m_infection_states[(size_t)p.get_infection_state(t)]; } - std::vector m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested. + std::bitset<(size_t)AgeGroup::Count> + m_ages; ///< BitSet of #AgeGroup%s that are either allowed or required to be tested. std::vector m_locations; /**< Set of #Location%s or #LocationState%s that are either allowed or required to be tested.*/ - std::vector m_infection_states; /**< Set of #InfectionState%s that are either allowed or required to + std::bitset<(size_t)InfectionState::Count> + m_infection_states; /**< BitSet of #InfectionState%s that are either allowed or required to be tested.*/ }; diff --git a/cpp/tests/test_abm_testing_strategy.cpp b/cpp/tests/test_abm_testing_strategy.cpp index b3293cca6e..ec59fea766 100644 --- a/cpp/tests/test_abm_testing_strategy.cpp +++ b/cpp/tests/test_abm_testing_strategy.cpp @@ -50,9 +50,9 @@ TEST(TestTestingCriteria, addRemoveAndEvaluateTestCriteria) ASSERT_EQ(testing_criteria.evaluate(person, home, t), false); auto testing_criteria_manual = mio::abm::TestingCriteria( - {}, std::vector({mio::abm::LocationType::Work}), - std::vector( - {mio::abm::InfectionState::InfectedNoSymptoms, mio::abm::InfectionState::InfectedSymptoms})); + {}, std::vector({mio::abm::LocationType::Work}), {}); + testing_criteria_manual.add_infection_state(mio::abm::InfectionState::InfectedNoSymptoms); + testing_criteria_manual.add_infection_state(mio::abm::InfectionState::InfectedSymptoms); ASSERT_EQ(testing_criteria == testing_criteria_manual, true); testing_criteria_manual.remove_infection_state(mio::abm::InfectionState::InfectedSymptoms); ASSERT_EQ(testing_criteria == testing_criteria_manual, false); @@ -60,10 +60,11 @@ TEST(TestTestingCriteria, addRemoveAndEvaluateTestCriteria) TEST(TestTestingScheme, runScheme) { - std::vector test_infection_states1 = {mio::abm::InfectionState::InfectedSymptoms, - mio::abm::InfectionState::InfectedNoSymptoms}; - std::vector test_location_types1 = {mio::abm::LocationType::Home, - mio::abm::LocationType::Work}; + std::bitset<(size_t)mio::abm::InfectionState::Count> test_infection_states1; + test_infection_states1.set((size_t)mio::abm::InfectionState::InfectedSymptoms, true); + test_infection_states1.set((size_t)mio::abm::InfectionState::InfectedNoSymptoms, true); + std::vector test_location_types1 = {mio::abm::LocationType::Home, + mio::abm::LocationType::Work}; auto testing_criteria1 = mio::abm::TestingCriteria({}, test_location_types1, test_infection_states1); @@ -85,8 +86,9 @@ TEST(TestTestingScheme, runScheme) ASSERT_EQ(testing_scheme.is_active(), false); testing_scheme.update_activity_status(mio::abm::TimePoint(0)); - std::vector test_infection_states2 = {mio::abm::InfectionState::Recovered}; - std::vector test_location_types2 = {mio::abm::LocationType::Home}; + std::bitset<(size_t)mio::abm::InfectionState::Count> test_infection_states2; + test_infection_states2.set((size_t)mio::abm::InfectionState::Recovered, true); + std::vector test_location_types2 = {mio::abm::LocationType::Home}; auto testing_criteria2 = mio::abm::TestingCriteria({}, test_location_types2, test_infection_states2); testing_scheme.add_testing_criteria(testing_criteria2); From 823ab08cf8561e9412ffd19e661c06237aff0be9 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Mon, 11 Sep 2023 15:56:43 +0200 Subject: [PATCH 04/18] Only one TestingCriteria per TestingScheme --- cpp/examples/abm_history_object.cpp | 3 +-- cpp/examples/abm_minimal.cpp | 3 +-- cpp/models/abm/testing_strategy.cpp | 20 ++-------------- cpp/models/abm/testing_strategy.h | 21 +++-------------- cpp/simulations/abm.cpp | 9 +++---- cpp/tests/test_abm_testing_strategy.cpp | 31 +++++++++++-------------- cpp/tests/test_abm_world.cpp | 2 +- 7 files changed, 25 insertions(+), 64 deletions(-) diff --git a/cpp/examples/abm_history_object.cpp b/cpp/examples/abm_history_object.cpp index 7b763af899..cda9666cf3 100644 --- a/cpp/examples/abm_history_object.cpp +++ b/cpp/examples/abm_history_object.cpp @@ -122,8 +122,7 @@ int main() auto end_date = mio::abm::TimePoint(0) + mio::abm::days(30); auto test_type = mio::abm::AntigenTest(); auto test_at_work = std::vector{mio::abm::LocationType::Work}; - auto testing_criteria_work = std::vector>{ - mio::abm::TestingCriteria({}, test_at_work, {})}; + auto testing_criteria_work = mio::abm::TestingCriteria({}, test_at_work, {}); auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, testing_min_time, start_date, end_date, test_type, probability); world.get_testing_strategy().add_testing_scheme(testing_scheme_work); diff --git a/cpp/examples/abm_minimal.cpp b/cpp/examples/abm_minimal.cpp index a075c38f45..85d3bdddcf 100644 --- a/cpp/examples/abm_minimal.cpp +++ b/cpp/examples/abm_minimal.cpp @@ -117,8 +117,7 @@ int main() auto end_date = mio::abm::TimePoint(0) + mio::abm::days(30); auto test_type = mio::abm::AntigenTest(); auto test_at_work = std::vector{mio::abm::LocationType::Work}; - auto testing_criteria_work = std::vector>{ - mio::abm::TestingCriteria({}, test_at_work, {})}; + auto testing_criteria_work = mio::abm::TestingCriteria({}, test_at_work, {}); auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, testing_min_time, start_date, end_date, test_type, probability); world.get_testing_strategy().add_testing_scheme(testing_scheme_work); diff --git a/cpp/models/abm/testing_strategy.cpp b/cpp/models/abm/testing_strategy.cpp index 18cdc8b6d6..94928f99eb 100644 --- a/cpp/models/abm/testing_strategy.cpp +++ b/cpp/models/abm/testing_strategy.cpp @@ -26,7 +26,7 @@ namespace mio namespace abm { -TestingScheme::TestingScheme(const std::vector>& testing_criteria, +TestingScheme::TestingScheme(const TestingCriteria& testing_criteria, TimeSpan minimal_time_since_last_test, TimePoint start_date, TimePoint end_date, const GenericTest& test_type, double probability) : m_testing_criteria(testing_criteria) @@ -49,19 +49,6 @@ bool TestingScheme::operator==(const TestingScheme& other) const //To be adjusted and also TestType should be static. } -void TestingScheme::add_testing_criteria(const TestingCriteria criteria) -{ - if (std::find(m_testing_criteria.begin(), m_testing_criteria.end(), criteria) == m_testing_criteria.end()) { - m_testing_criteria.push_back(criteria); - } -} - -void TestingScheme::remove_testing_criteria(const TestingCriteria criteria) -{ - auto last = std::remove(m_testing_criteria.begin(), m_testing_criteria.end(), criteria); - m_testing_criteria.erase(last, m_testing_criteria.end()); -} - bool TestingScheme::is_active() const { return m_is_active; @@ -76,10 +63,7 @@ bool TestingScheme::run_scheme(Person& person, const Location& location, TimePoi if (person.get_time_since_negative_test() > m_minimal_time_since_last_test) { double random = UniformDistribution::get_instance()(); if (random < m_probability) { - if (std::any_of(m_testing_criteria.begin(), m_testing_criteria.end(), - [person, location, t](TestingCriteria tr) { - return tr.evaluate(person, location, t); - })) { + if (m_testing_criteria.evaluate(person, location, t)) { return !person.get_tested(t, m_test_type.get_default()); } } diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index c16d8b3d2b..f321ddf354 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -24,7 +24,7 @@ #include "abm/person.h" #include "abm/location.h" #include "abm/time.h" -#include +#include namespace mio { @@ -157,7 +157,6 @@ class TestingCriteria bool has_requested_age(const Person& p) const { if (m_ages.none()) { - std::cout << "Here" << '\n'; return true; // no condition on the AgeGroup } return m_ages[(size_t)p.get_age()]; @@ -170,7 +169,6 @@ class TestingCriteria bool is_requested_location_type(const Location& l) const { if (m_locations.empty()) { - std::cout << "Here" << '\n'; return true; // no condition on the location } return std::find(m_locations.begin(), m_locations.end(), l.get_type()) != m_locations.end(); @@ -184,7 +182,6 @@ class TestingCriteria bool has_requested_infection_state(const Person& p, TimePoint t) const { if (m_infection_states.none()) { - std::cout << "Here" << '\n'; return true; // no condition on the InfectionState } return m_infection_states[(size_t)p.get_infection_state(t)]; @@ -215,7 +212,7 @@ class TestingScheme * @param test_type The type of test to be performed. * @param probability Probability of the test to be performed if a testing rule applies. */ - TestingScheme(const std::vector>& testing_criteria, + TestingScheme(const TestingCriteria& testing_criteria, TimeSpan minimal_time_since_last_test, TimePoint start_date, TimePoint end_date, const GenericTest& test_type, ScalarType probability); @@ -224,18 +221,6 @@ class TestingScheme */ bool operator==(const TestingScheme& other) const; - /** - * @brief Add a TestingCriteria to the set of TestingCriteria that are checked for testing. - * @param[in] criteria TestingCriteria to be added. - */ - void add_testing_criteria(const TestingCriteria criteria); - - /** - * @brief Remove a TestingCriteria from the set of TestingCriteria that are checked for testing. - * @param[in] criteria TestingCriteria to be removed. - */ - void remove_testing_criteria(const TestingCriteria criteria); - /** * @brief Get the activity status of the scheme. * @return Whether the TestingScheme is currently active. @@ -258,7 +243,7 @@ class TestingScheme bool run_scheme(Person& person, const Location& location, TimePoint t) const; private: - std::vector> m_testing_criteria; ///< Vector with all TestingCriteria of the scheme. + TestingCriteria m_testing_criteria; ///< Vector with all TestingCriteria of the scheme. TimeSpan m_minimal_time_since_last_test; ///< Shortest period of time between two tests. TimePoint m_start_date; ///< Starting date of the scheme. TimePoint m_end_date; ///< Ending date of the scheme. diff --git a/cpp/simulations/abm.cpp b/cpp/simulations/abm.cpp index 8e23c703ce..9be95c0bbb 100644 --- a/cpp/simulations/abm.cpp +++ b/cpp/simulations/abm.cpp @@ -309,8 +309,7 @@ void create_assign_locations(mio::abm::World& world) world.get_individualized_location(event).set_capacity(100, 375); std::vector test_at_social_event = {mio::abm::LocationType::SocialEvent}; - auto testing_criteria = std::vector>{ - mio::abm::TestingCriteria({}, test_at_social_event, {})}; + auto testing_criteria = mio::abm::TestingCriteria({}, test_at_social_event, {}); auto testing_min_time = mio::abm::days(2); auto start_date = mio::abm::TimePoint(0); auto end_date = mio::abm::TimePoint(0) + mio::abm::days(60); @@ -412,8 +411,7 @@ void create_assign_locations(mio::abm::World& world) // add the testing schemes for school and work auto test_at_school = std::vector{mio::abm::LocationType::School}; - auto testing_criteria_school = std::vector>{ - mio::abm::TestingCriteria({}, test_at_school, {})}; + auto testing_criteria_school = mio::abm::TestingCriteria({}, test_at_school, {}); testing_min_time = mio::abm::days(7); auto testing_scheme_school = mio::abm::TestingScheme(testing_criteria_school, testing_min_time, start_date, @@ -421,8 +419,7 @@ void create_assign_locations(mio::abm::World& world) world.get_testing_strategy().add_testing_scheme(testing_scheme_school); auto test_at_work = std::vector{mio::abm::LocationType::Work}; - auto testing_criteria_work = std::vector>{ - mio::abm::TestingCriteria({}, test_at_work, {})}; + auto testing_criteria_work = mio::abm::TestingCriteria({}, test_at_work, {}); assign_uniform_distribution(probability, 0.1, 0.5); testing_min_time = mio::abm::days(1); diff --git a/cpp/tests/test_abm_testing_strategy.cpp b/cpp/tests/test_abm_testing_strategy.cpp index ec59fea766..1d62e31549 100644 --- a/cpp/tests/test_abm_testing_strategy.cpp +++ b/cpp/tests/test_abm_testing_strategy.cpp @@ -76,22 +76,23 @@ TEST(TestTestingScheme, runScheme) const auto probability = 0.8; const auto test_type = mio::abm::PCRTest(); - auto testing_scheme = - mio::abm::TestingScheme(testing_criterias, testing_min_time, start_date, end_date, test_type, probability); + auto testing_scheme1 = + mio::abm::TestingScheme(testing_criteria1, testing_min_time, start_date, end_date, test_type, probability); - ASSERT_EQ(testing_scheme.is_active(), false); - testing_scheme.update_activity_status(mio::abm::TimePoint(10)); - ASSERT_EQ(testing_scheme.is_active(), true); - testing_scheme.update_activity_status(mio::abm::TimePoint(60 * 60 * 24 * 3 + 200)); - ASSERT_EQ(testing_scheme.is_active(), false); - testing_scheme.update_activity_status(mio::abm::TimePoint(0)); + ASSERT_EQ(testing_scheme1.is_active(), false); + testing_scheme1.update_activity_status(mio::abm::TimePoint(10)); + ASSERT_EQ(testing_scheme1.is_active(), true); + testing_scheme1.update_activity_status(mio::abm::TimePoint(60 * 60 * 24 * 3 + 200)); + ASSERT_EQ(testing_scheme1.is_active(), false); + testing_scheme1.update_activity_status(mio::abm::TimePoint(0)); std::bitset<(size_t)mio::abm::InfectionState::Count> test_infection_states2; test_infection_states2.set((size_t)mio::abm::InfectionState::Recovered, true); std::vector test_location_types2 = {mio::abm::LocationType::Home}; auto testing_criteria2 = mio::abm::TestingCriteria({}, test_location_types2, test_infection_states2); - testing_scheme.add_testing_criteria(testing_criteria2); + auto testing_scheme2 = + mio::abm::TestingScheme(testing_criteria2, testing_min_time, start_date, end_date, test_type, probability); auto loc_home = mio::abm::Location(mio::abm::LocationType::Home, 0); auto loc_work = mio::abm::Location(mio::abm::LocationType::Work, 0); @@ -107,13 +108,9 @@ TEST(TestTestingScheme, runScheme) .WillOnce(testing::Return(0.7)) .WillOnce(testing::Return(0.5)) .WillOnce(testing::Return(0.9)); - ASSERT_EQ(testing_scheme.run_scheme(person1, loc_home, start_date), false); // Person tests and tests positive - ASSERT_EQ(testing_scheme.run_scheme(person2, loc_work, start_date), true); // Person tests and tests negative - ASSERT_EQ(testing_scheme.run_scheme(person1, loc_home, start_date), + ASSERT_EQ(testing_scheme1.run_scheme(person1, loc_home, start_date), false); // Person tests and tests positive + ASSERT_EQ(testing_scheme1.run_scheme(person2, loc_work, start_date), true); // Person tests and tests negative + ASSERT_EQ(testing_scheme2.run_scheme(person1, loc_home, start_date), true); // Person is in quarantine and wants to go home -> can do so - ASSERT_EQ(testing_scheme.run_scheme(person1, loc_work, start_date), true); // Person doesn't test - - testing_scheme.add_testing_criteria(testing_criteria1); - testing_scheme.remove_testing_criteria(testing_criteria1); - ASSERT_EQ(testing_scheme.run_scheme(person1, loc_home, start_date), true); + ASSERT_EQ(testing_scheme1.run_scheme(person1, loc_work, start_date), true); // Person doesn't test } diff --git a/cpp/tests/test_abm_world.cpp b/cpp/tests/test_abm_world.cpp index bab8d577a1..eeef5e2319 100644 --- a/cpp/tests/test_abm_world.cpp +++ b/cpp/tests/test_abm_world.cpp @@ -381,7 +381,7 @@ TEST(TestWorldTestingCriteria, testAddingAndUpdatingAndRunningTestingSchemes) const auto test_type = mio::abm::PCRTest(); auto testing_scheme = - mio::abm::TestingScheme({testing_criteria}, testing_frequency, start_date, end_date, test_type, probability); + mio::abm::TestingScheme(testing_criteria, testing_frequency, start_date, end_date, test_type, probability); world.get_testing_strategy().add_testing_scheme(testing_scheme); ASSERT_EQ(world.get_testing_strategy().run_strategy(person, work, current_time), From 65028d0c60ff4814e7d60aa124cebdd5e264d34d Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Mon, 11 Sep 2023 16:00:10 +0200 Subject: [PATCH 05/18] Fix error in pycode for TestingScheme --- pycode/memilio-simulation/memilio/simulation/abm.cpp | 2 +- pycode/memilio-simulation/memilio/simulation_test/test_abm.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pycode/memilio-simulation/memilio/simulation/abm.cpp b/pycode/memilio-simulation/memilio/simulation/abm.cpp index a1df864215..d4c85344fe 100644 --- a/pycode/memilio-simulation/memilio/simulation/abm.cpp +++ b/pycode/memilio-simulation/memilio/simulation/abm.cpp @@ -152,7 +152,7 @@ PYBIND11_MODULE(_simulation_abm, m) py::class_(m, "PCRTest").def(py::init<>()); py::class_(m, "TestingScheme") - .def(py::init>&, mio::abm::TimeSpan, mio::abm::TimePoint, + .def(py::init&, mio::abm::TimeSpan, mio::abm::TimePoint, mio::abm::TimePoint, const mio::abm::GenericTest&, double>(), py::arg("testing_criteria"), py::arg("testing_min_time_since_last_test"), py::arg("start_date"), py::arg("end_date"), py::arg("test_type"), py::arg("probability")) diff --git a/pycode/memilio-simulation/memilio/simulation_test/test_abm.py b/pycode/memilio-simulation/memilio/simulation_test/test_abm.py index 80d3f9c3bc..6881fdbcb2 100644 --- a/pycode/memilio-simulation/memilio/simulation_test/test_abm.py +++ b/pycode/memilio-simulation/memilio/simulation_test/test_abm.py @@ -54,8 +54,7 @@ def test_locations(self): testing_locations = [abm.LocationType.Home] testing_inf_states = [] - testing_crit = [abm.TestingCriteria( - testing_ages, testing_locations, testing_inf_states)] + testing_crit = abm.TestingCriteria(testing_ages, testing_locations, testing_inf_states) testing_scheme = abm.TestingScheme(testing_crit, abm.days( 1), t0, t0 + abm.days(1), abm.AntigenTest(), 1.0) # initially false, will only active once simulation starts From 04fb3eda0e665376906e8ca6136bc994e9fc1057 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Mon, 11 Sep 2023 16:00:57 +0200 Subject: [PATCH 06/18] Fix error in pycode for TestingScheme --- pycode/memilio-simulation/memilio/simulation_test/test_abm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pycode/memilio-simulation/memilio/simulation_test/test_abm.py b/pycode/memilio-simulation/memilio/simulation_test/test_abm.py index 6881fdbcb2..61b4c5f844 100644 --- a/pycode/memilio-simulation/memilio/simulation_test/test_abm.py +++ b/pycode/memilio-simulation/memilio/simulation_test/test_abm.py @@ -54,7 +54,8 @@ def test_locations(self): testing_locations = [abm.LocationType.Home] testing_inf_states = [] - testing_crit = abm.TestingCriteria(testing_ages, testing_locations, testing_inf_states) + testing_crit = abm.TestingCriteria( + testing_ages, testing_locations, testing_inf_states) testing_scheme = abm.TestingScheme(testing_crit, abm.days( 1), t0, t0 + abm.days(1), abm.AntigenTest(), 1.0) # initially false, will only active once simulation starts From 7ac27581c6665919b20adffa6018aa1075da099c Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Mon, 11 Sep 2023 16:15:35 +0200 Subject: [PATCH 07/18] Fix error in pycode for TestingCriteria --- pycode/memilio-simulation/memilio/simulation/abm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pycode/memilio-simulation/memilio/simulation/abm.cpp b/pycode/memilio-simulation/memilio/simulation/abm.cpp index d4c85344fe..999bf3fc9a 100644 --- a/pycode/memilio-simulation/memilio/simulation/abm.cpp +++ b/pycode/memilio-simulation/memilio/simulation/abm.cpp @@ -143,8 +143,8 @@ PYBIND11_MODULE(_simulation_abm, m) .def_property_readonly("is_in_quarantine", &mio::abm::Person::is_in_quarantine); py::class_>(m, "TestingCriteria") - .def(py::init&, const std::vector&, - const std::vector&>(), + .def(py::init&, const std::vector&, + const std::bitset<(size_t)mio::abm::InfectionState::Count>&>(), py::arg("age_groups"), py::arg("location_types"), py::arg("infection_states")); py::class_(m, "GenericTest").def(py::init<>()); From 050e9e47e6118f4a437d84e2cc0be67e9104affb Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Tue, 12 Sep 2023 09:13:32 +0200 Subject: [PATCH 08/18] Revert TestingCriteria to accept vectors of AgeGroup and InfectionState --- cpp/examples/abm_history_object.cpp | 6 ++-- cpp/examples/abm_minimal.cpp | 2 +- cpp/models/abm/testing_strategy.h | 33 ++++++++++++------- cpp/tests/test_abm_testing_strategy.cpp | 21 +++++------- cpp/tests/test_abm_world.cpp | 2 +- .../memilio/simulation/abm.cpp | 4 +-- 6 files changed, 38 insertions(+), 30 deletions(-) diff --git a/cpp/examples/abm_history_object.cpp b/cpp/examples/abm_history_object.cpp index 07947c3ef3..cda9666cf3 100644 --- a/cpp/examples/abm_history_object.cpp +++ b/cpp/examples/abm_history_object.cpp @@ -134,9 +134,9 @@ int main() mio::abm::InfectionState infection_state = (mio::abm::InfectionState)(rand() % ((uint32_t)mio::abm::InfectionState::Count - 1)); if (infection_state != mio::abm::InfectionState::Susceptible) - person.add_new_infection(mio::abm::Infection( - mio::abm::VirusVariant::Wildtype, person.get_age(), world.get_global_infection_parameters(), start_date, - infection_state)); + person.add_new_infection(mio::abm::Infection(mio::abm::VirusVariant::Wildtype, person.get_age(), + world.get_global_infection_parameters(), start_date, + infection_state)); } // Assign locations to the people diff --git a/cpp/examples/abm_minimal.cpp b/cpp/examples/abm_minimal.cpp index d19113048a..60328fc9a1 100644 --- a/cpp/examples/abm_minimal.cpp +++ b/cpp/examples/abm_minimal.cpp @@ -163,4 +163,4 @@ int main() sim.advance(tmax); write_results_to_file(sim); -} +} \ No newline at end of file diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index f321ddf354..99283ac8d7 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -48,20 +48,32 @@ class TestingCriteria */ TestingCriteria() = default; - TestingCriteria(const std::bitset<(size_t)AgeGroup::Count>& ages, const std::vector& locations, - const std::bitset<(size_t)InfectionState::Count>& infection_states) - : m_ages(ages) + TestingCriteria(const std::vector& ages, const std::vector& locations, + const std::vector& infection_states) + : m_ages() , m_locations(locations) - , m_infection_states(infection_states) + , m_infection_states() { + for (auto age : ages) { + m_ages.set((size_t)age, true); + } + for (auto infection_state : infection_states) { + m_infection_states.set((size_t)infection_state, true); + } } - TestingCriteria(const std::bitset<(size_t)AgeGroup::Count>& ages, const std::vector& locations, - const std::bitset<(size_t)InfectionState::Count>& infection_states) - : m_ages(ages) + TestingCriteria(const std::vector& ages, const std::vector& locations, + const std::vector& infection_states) + : m_ages() , m_locations(locations) - , m_infection_states(infection_states) + , m_infection_states() { + for (auto age : ages) { + m_ages.set((size_t)age, true); + } + for (auto infection_state : infection_states) { + m_infection_states.set((size_t)infection_state, true); + } } /** @@ -212,9 +224,8 @@ class TestingScheme * @param test_type The type of test to be performed. * @param probability Probability of the test to be performed if a testing rule applies. */ - TestingScheme(const TestingCriteria& testing_criteria, - TimeSpan minimal_time_since_last_test, TimePoint start_date, TimePoint end_date, - const GenericTest& test_type, ScalarType probability); + TestingScheme(const TestingCriteria& testing_criteria, TimeSpan minimal_time_since_last_test, + TimePoint start_date, TimePoint end_date, const GenericTest& test_type, ScalarType probability); /** * @brief Compares two TestingScheme%s for functional equality. diff --git a/cpp/tests/test_abm_testing_strategy.cpp b/cpp/tests/test_abm_testing_strategy.cpp index 1d62e31549..1975214583 100644 --- a/cpp/tests/test_abm_testing_strategy.cpp +++ b/cpp/tests/test_abm_testing_strategy.cpp @@ -50,9 +50,9 @@ TEST(TestTestingCriteria, addRemoveAndEvaluateTestCriteria) ASSERT_EQ(testing_criteria.evaluate(person, home, t), false); auto testing_criteria_manual = mio::abm::TestingCriteria( - {}, std::vector({mio::abm::LocationType::Work}), {}); - testing_criteria_manual.add_infection_state(mio::abm::InfectionState::InfectedNoSymptoms); - testing_criteria_manual.add_infection_state(mio::abm::InfectionState::InfectedSymptoms); + {}, std::vector({mio::abm::LocationType::Work}), + std::vector( + {mio::abm::InfectionState::InfectedNoSymptoms, mio::abm::InfectionState::InfectedSymptoms})); ASSERT_EQ(testing_criteria == testing_criteria_manual, true); testing_criteria_manual.remove_infection_state(mio::abm::InfectionState::InfectedSymptoms); ASSERT_EQ(testing_criteria == testing_criteria_manual, false); @@ -60,15 +60,13 @@ TEST(TestTestingCriteria, addRemoveAndEvaluateTestCriteria) TEST(TestTestingScheme, runScheme) { - std::bitset<(size_t)mio::abm::InfectionState::Count> test_infection_states1; - test_infection_states1.set((size_t)mio::abm::InfectionState::InfectedSymptoms, true); - test_infection_states1.set((size_t)mio::abm::InfectionState::InfectedNoSymptoms, true); - std::vector test_location_types1 = {mio::abm::LocationType::Home, - mio::abm::LocationType::Work}; + std::vector test_infection_states1 = {mio::abm::InfectionState::InfectedSymptoms, + mio::abm::InfectionState::InfectedNoSymptoms}; + std::vector test_location_types1 = {mio::abm::LocationType::Home, + mio::abm::LocationType::Work}; auto testing_criteria1 = mio::abm::TestingCriteria({}, test_location_types1, test_infection_states1); - std::vector> testing_criterias = {testing_criteria1}; const auto testing_min_time = mio::abm::days(1); const auto start_date = mio::abm::TimePoint(0); @@ -86,9 +84,8 @@ TEST(TestTestingScheme, runScheme) ASSERT_EQ(testing_scheme1.is_active(), false); testing_scheme1.update_activity_status(mio::abm::TimePoint(0)); - std::bitset<(size_t)mio::abm::InfectionState::Count> test_infection_states2; - test_infection_states2.set((size_t)mio::abm::InfectionState::Recovered, true); - std::vector test_location_types2 = {mio::abm::LocationType::Home}; + std::vector test_infection_states2 = {mio::abm::InfectionState::Recovered}; + std::vector test_location_types2 = {mio::abm::LocationType::Home}; auto testing_criteria2 = mio::abm::TestingCriteria({}, test_location_types2, test_infection_states2); auto testing_scheme2 = diff --git a/cpp/tests/test_abm_world.cpp b/cpp/tests/test_abm_world.cpp index 62048b61ce..2fc6370d21 100644 --- a/cpp/tests/test_abm_world.cpp +++ b/cpp/tests/test_abm_world.cpp @@ -373,7 +373,7 @@ TEST(TestWorldTestingCriteria, testAddingAndUpdatingAndRunningTestingSchemes) const auto test_type = mio::abm::PCRTest(); auto testing_scheme = - mio::abm::TestingScheme(testing_criteria, testing_frequency, start_date, end_date, test_type, probability); + mio::abm::TestingScheme({testing_criteria}, testing_frequency, start_date, end_date, test_type, probability); world.get_testing_strategy().add_testing_scheme(testing_scheme); ASSERT_EQ(world.get_testing_strategy().run_strategy(person, work, current_time), diff --git a/pycode/memilio-simulation/memilio/simulation/abm.cpp b/pycode/memilio-simulation/memilio/simulation/abm.cpp index e15b04b2a1..0cab84dd40 100644 --- a/pycode/memilio-simulation/memilio/simulation/abm.cpp +++ b/pycode/memilio-simulation/memilio/simulation/abm.cpp @@ -144,8 +144,8 @@ PYBIND11_MODULE(_simulation_abm, m) .def_property_readonly("is_in_quarantine", &mio::abm::Person::is_in_quarantine); py::class_>(m, "TestingCriteria") - .def(py::init&, const std::vector&, - const std::bitset<(size_t)mio::abm::InfectionState::Count>&>(), + .def(py::init&, const std::vector&, + const std::vector&>(), py::arg("age_groups"), py::arg("location_types"), py::arg("infection_states")); py::class_(m, "GenericTest").def(py::init<>()); From 83ef677a78f1501ba79772bd998b88cb3823eda3 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Wed, 13 Sep 2023 12:30:19 +0200 Subject: [PATCH 09/18] Only one TestingCriteria per TestingScheme --- cpp/models/abm/testing_strategy.cpp | 9 +++++---- cpp/models/abm/testing_strategy.h | 20 ++------------------ cpp/models/abm/world.cpp | 15 ++++++++------- cpp/tests/test_abm_testing_strategy.cpp | 13 ++++++------- 4 files changed, 21 insertions(+), 36 deletions(-) diff --git a/cpp/models/abm/testing_strategy.cpp b/cpp/models/abm/testing_strategy.cpp index 94928f99eb..c15208657b 100644 --- a/cpp/models/abm/testing_strategy.cpp +++ b/cpp/models/abm/testing_strategy.cpp @@ -53,6 +53,7 @@ bool TestingScheme::is_active() const { return m_is_active; } + void TestingScheme::update_activity_status(TimePoint t) { m_is_active = (m_start_date <= t && t <= m_end_date); @@ -60,10 +61,10 @@ void TestingScheme::update_activity_status(TimePoint t) bool TestingScheme::run_scheme(Person& person, const Location& location, TimePoint t) const { - if (person.get_time_since_negative_test() > m_minimal_time_since_last_test) { - double random = UniformDistribution::get_instance()(); - if (random < m_probability) { - if (m_testing_criteria.evaluate(person, location, t)) { + if (m_testing_criteria.evaluate(person, location, t)) { + if (person.get_time_since_negative_test() > m_minimal_time_since_last_test) { + double random = UniformDistribution::get_instance()(); + if (random < m_probability) { return !person.get_tested(t, m_test_type.get_default()); } } diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index 99283ac8d7..be9f2303b2 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -48,25 +48,9 @@ class TestingCriteria */ TestingCriteria() = default; - TestingCriteria(const std::vector& ages, const std::vector& locations, + TestingCriteria(const std::vector& ages, const std::vector& locations, const std::vector& infection_states) - : m_ages() - , m_locations(locations) - , m_infection_states() - { - for (auto age : ages) { - m_ages.set((size_t)age, true); - } - for (auto infection_state : infection_states) { - m_infection_states.set((size_t)infection_state, true); - } - } - - TestingCriteria(const std::vector& ages, const std::vector& locations, - const std::vector& infection_states) - : m_ages() - , m_locations(locations) - , m_infection_states() + : m_locations(locations) { for (auto age : ages) { m_ages.set((size_t)age, true); diff --git a/cpp/models/abm/world.cpp b/cpp/models/abm/world.cpp index c85e6f9bcd..bac8fb58ec 100755 --- a/cpp/models/abm/world.cpp +++ b/cpp/models/abm/world.cpp @@ -91,14 +91,15 @@ void World::migration(TimePoint t, TimeSpan dt) auto target_type = rule.first(*person, t, dt, m_migration_parameters); auto& target_location = find_location(target_type, *person); auto current_location = person->get_location(); - if (m_testing_strategy.run_strategy(*person, target_location, t)) { - if (target_location != current_location && - target_location.get_number_persons() < target_location.get_capacity().persons) { - bool wears_mask = person->apply_mask_intervention(target_location); - if (wears_mask) { - person->migrate_to(target_location); + if (target_location != current_location) { + if (m_testing_strategy.run_strategy(*person, target_location, t)) { + if (target_location.get_number_persons() < target_location.get_capacity().persons) { + bool wears_mask = person->apply_mask_intervention(target_location); + if (wears_mask) { + person->migrate_to(target_location); + } + break; } - break; } } } diff --git a/cpp/tests/test_abm_testing_strategy.cpp b/cpp/tests/test_abm_testing_strategy.cpp index 1975214583..87bb4b9f32 100644 --- a/cpp/tests/test_abm_testing_strategy.cpp +++ b/cpp/tests/test_abm_testing_strategy.cpp @@ -99,15 +99,14 @@ TEST(TestTestingScheme, runScheme) ScopedMockDistribution>>> mock_uniform_dist; EXPECT_CALL(mock_uniform_dist.get_mock(), invoke) - .Times(testing::Exactly(5)) + .Times(testing::Exactly(4)) .WillOnce(testing::Return(0.7)) .WillOnce(testing::Return(0.5)) .WillOnce(testing::Return(0.7)) - .WillOnce(testing::Return(0.5)) - .WillOnce(testing::Return(0.9)); + .WillOnce(testing::Return(0.5)); ASSERT_EQ(testing_scheme1.run_scheme(person1, loc_home, start_date), false); // Person tests and tests positive - ASSERT_EQ(testing_scheme1.run_scheme(person2, loc_work, start_date), true); // Person tests and tests negative - ASSERT_EQ(testing_scheme2.run_scheme(person1, loc_home, start_date), - true); // Person is in quarantine and wants to go home -> can do so - ASSERT_EQ(testing_scheme1.run_scheme(person1, loc_work, start_date), true); // Person doesn't test + ASSERT_EQ(testing_scheme2.run_scheme(person2, loc_work, start_date), true); // Person tests and tests negative + ASSERT_EQ(testing_scheme1.run_scheme(person1, loc_home, start_date), + false); // Person is in quarantine and wants to go home -> can do so + ASSERT_EQ(testing_scheme2.run_scheme(person1, loc_work, start_date), true); // Person doesn't test } From 515a368d3d83691d9a65f3784b40363d23675576 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Thu, 21 Sep 2023 18:32:50 +0200 Subject: [PATCH 10/18] Add map of Location-TestingScheme to TestingStrategy --- cpp/examples/abm_history_object.cpp | 5 +- cpp/examples/abm_minimal.cpp | 5 +- cpp/models/abm/location_type.h | 8 ++ cpp/models/abm/testing_strategy.cpp | 66 +++++++++++------ cpp/models/abm/testing_strategy.h | 99 ++++++++++--------------- cpp/simulations/abm.cpp | 14 ++-- cpp/tests/test_abm_testing_strategy.cpp | 46 +++++------- cpp/tests/test_abm_world.cpp | 12 +-- 8 files changed, 124 insertions(+), 131 deletions(-) diff --git a/cpp/examples/abm_history_object.cpp b/cpp/examples/abm_history_object.cpp index cda9666cf3..19f5de2843 100644 --- a/cpp/examples/abm_history_object.cpp +++ b/cpp/examples/abm_history_object.cpp @@ -121,11 +121,10 @@ int main() auto start_date = mio::abm::TimePoint(0); auto end_date = mio::abm::TimePoint(0) + mio::abm::days(30); auto test_type = mio::abm::AntigenTest(); - auto test_at_work = std::vector{mio::abm::LocationType::Work}; - auto testing_criteria_work = mio::abm::TestingCriteria({}, test_at_work, {}); + auto testing_criteria_work = mio::abm::TestingCriteria(); auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, testing_min_time, start_date, end_date, test_type, probability); - world.get_testing_strategy().add_testing_scheme(testing_scheme_work); + world.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme_work); // Assign infection state to each person. // The infection states are chosen randomly. diff --git a/cpp/examples/abm_minimal.cpp b/cpp/examples/abm_minimal.cpp index 60328fc9a1..1a3863e6cc 100644 --- a/cpp/examples/abm_minimal.cpp +++ b/cpp/examples/abm_minimal.cpp @@ -116,11 +116,10 @@ int main() auto start_date = mio::abm::TimePoint(0); auto end_date = mio::abm::TimePoint(0) + mio::abm::days(30); auto test_type = mio::abm::AntigenTest(); - auto test_at_work = std::vector{mio::abm::LocationType::Work}; - auto testing_criteria_work = mio::abm::TestingCriteria({}, test_at_work, {}); + auto testing_criteria_work = mio::abm::TestingCriteria(); auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, testing_min_time, start_date, end_date, test_type, probability); - world.get_testing_strategy().add_testing_scheme(testing_scheme_work); + world.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme_work); // Assign infection state to each person. // The infection states are chosen randomly. diff --git a/cpp/models/abm/location_type.h b/cpp/models/abm/location_type.h index ffebb29cc8..fa18cc6e9c 100644 --- a/cpp/models/abm/location_type.h +++ b/cpp/models/abm/location_type.h @@ -67,6 +67,14 @@ struct LocationId { { return !(index == rhs.index && type == rhs.type); } + + bool operator<(const LocationId& rhs) const + { + if (type == rhs.type) { + return index < rhs.index; + } + return (type < rhs.type); + } }; } // namespace abm diff --git a/cpp/models/abm/testing_strategy.cpp b/cpp/models/abm/testing_strategy.cpp index c15208657b..563c904864 100644 --- a/cpp/models/abm/testing_strategy.cpp +++ b/cpp/models/abm/testing_strategy.cpp @@ -20,15 +20,15 @@ #include "abm/testing_strategy.h" #include "memilio/utils/random_number_generator.h" +#include namespace mio { namespace abm { -TestingScheme::TestingScheme(const TestingCriteria& testing_criteria, - TimeSpan minimal_time_since_last_test, TimePoint start_date, TimePoint end_date, - const GenericTest& test_type, double probability) +TestingScheme::TestingScheme(const TestingCriteria& testing_criteria, TimeSpan minimal_time_since_last_test, + TimePoint start_date, TimePoint end_date, const GenericTest& test_type, double probability) : m_testing_criteria(testing_criteria) , m_minimal_time_since_last_test(minimal_time_since_last_test) , m_start_date(start_date) @@ -59,12 +59,12 @@ void TestingScheme::update_activity_status(TimePoint t) m_is_active = (m_start_date <= t && t <= m_end_date); } -bool TestingScheme::run_scheme(Person& person, const Location& location, TimePoint t) const +bool TestingScheme::run_scheme(Person& person, TimePoint t) const { - if (m_testing_criteria.evaluate(person, location, t)) { - if (person.get_time_since_negative_test() > m_minimal_time_since_last_test) { - double random = UniformDistribution::get_instance()(); - if (random < m_probability) { + if (person.get_time_since_negative_test() > m_minimal_time_since_last_test) { + double random = UniformDistribution::get_instance()(); + if (random < m_probability) { + if (m_testing_criteria.evaluate(person, t)) { return !person.get_tested(t, m_test_type.get_default()); } } @@ -72,40 +72,64 @@ bool TestingScheme::run_scheme(Person& person, const Location& location, TimePoi return true; } -TestingStrategy::TestingStrategy(const std::vector& testing_schemes) - : m_testing_schemes(testing_schemes) +TestingStrategy::TestingStrategy(const std::map>& location_to_schemes_map) + : m_location_to_schemes_map(location_to_schemes_map) { } -void TestingStrategy::add_testing_scheme(const TestingScheme& scheme) +void TestingStrategy::add_testing_scheme(const LocationId& loc_id, const TestingScheme& scheme) { - if (std::find(m_testing_schemes.begin(), m_testing_schemes.end(), scheme) == m_testing_schemes.end()) { - m_testing_schemes.push_back(scheme); + auto &schemes_vector = m_location_to_schemes_map[loc_id]; + if (std::find(schemes_vector.begin(), schemes_vector.end(), scheme) == schemes_vector.end()) { + schemes_vector.push_back(scheme); } } -void TestingStrategy::remove_testing_scheme(const TestingScheme& scheme) +void TestingStrategy::add_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme) { - auto last = std::remove(m_testing_schemes.begin(), m_testing_schemes.end(), scheme); - m_testing_schemes.erase(last, m_testing_schemes.end()); + auto loc_id = LocationId{INVALID_LOCATION_INDEX, loc_type}; + add_testing_scheme(loc_id, scheme); +} + +void TestingStrategy::remove_testing_scheme(const LocationId& loc_id, const TestingScheme& scheme) +{ + auto &schemes_vector = m_location_to_schemes_map[loc_id]; + auto last = std::remove(schemes_vector.begin(), schemes_vector.end(), scheme); + schemes_vector.erase(last, schemes_vector.end()); +} + +void TestingStrategy::remove_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme) +{ + auto loc_id = LocationId{INVALID_LOCATION_INDEX, loc_type}; + remove_testing_scheme(loc_id, scheme); } void TestingStrategy::update_activity_status(TimePoint t) { - for (auto& ts : m_testing_schemes) { - ts.update_activity_status(t); + for (auto& [_, testing_schemes] : m_location_to_schemes_map) { + for (auto &scheme : testing_schemes) { + scheme.update_activity_status(t); + } } } -bool TestingStrategy::run_strategy(Person& person, const Location& location, TimePoint t) const +bool TestingStrategy::run_strategy(Person& person, const Location& location, TimePoint t) { // Person who is in quarantine but not yet home should go home. Otherwise they can't because they test positive. if (location.get_type() == mio::abm::LocationType::Home && person.is_in_quarantine()) { return true; } - return std::all_of(m_testing_schemes.begin(), m_testing_schemes.end(), [&person, location, t](TestingScheme ts) { + + auto schemes_location_vector = m_location_to_schemes_map[LocationId{location.get_index(), location.get_type()}]; + auto schemes_location_type_vector = + m_location_to_schemes_map[LocationId{INVALID_LOCATION_INDEX, location.get_type()}]; + schemes_location_vector.insert(schemes_location_vector.end(), schemes_location_type_vector.begin(), + schemes_location_type_vector.end()); + std::cout << schemes_location_vector.size() << '\n'; + return std::all_of(schemes_location_vector.begin(), schemes_location_vector.end(), [&person, t](TestingScheme ts) { if (ts.is_active()) { - return ts.run_scheme(person, location, t); + std::cout << "run scheme" << '\n'; + return ts.run_scheme(person, t); } return true; }); diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index be9f2303b2..41d7a60fcb 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -25,6 +25,7 @@ #include "abm/location.h" #include "abm/time.h" #include +#include namespace mio { @@ -34,7 +35,6 @@ namespace abm /** * @brief TestingCriteria for TestingScheme. */ -template class TestingCriteria { public: @@ -48,9 +48,8 @@ class TestingCriteria */ TestingCriteria() = default; - TestingCriteria(const std::vector& ages, const std::vector& locations, - const std::vector& infection_states) - : m_locations(locations) + TestingCriteria(const std::vector& ages, const std::vector& infection_states) + { for (auto age : ages) { m_ages.set((size_t)age, true); @@ -67,13 +66,7 @@ class TestingCriteria { auto to_compare_ages = this->m_ages; auto to_compare_infection_states = this->m_infection_states; - auto to_compare_locations = this->m_locations; - - std::sort(to_compare_locations.begin(), to_compare_locations.end()); - std::sort(other.m_locations.begin(), other.m_locations.end()); - - return to_compare_ages == other.m_ages && to_compare_locations == other.m_locations && - to_compare_infection_states == other.m_infection_states; + return to_compare_ages == other.m_ages && to_compare_infection_states == other.m_infection_states; } /** @@ -94,27 +87,6 @@ class TestingCriteria m_ages.set((size_t)age_group, false); } - /** - * @brief Add a #Location or #LocationType to the set of #LocationType%s that are either allowed or required to be tested. - * @param[in] location_type #Location%s or #LocationType to be added. - */ - void add_location(const L location) - { - if (std::find(m_locations.begin(), m_locations.end(), location) == m_locations.end()) { - m_locations.push_back(location); - } - } - - /** - * @brief Remove a #Location or #LocationType from the set of #LocationType%s that are either allowed or required to be tested. - * @param[in] location_type #Location or #LocationType to be removed. - */ - void remove_location(const L location) - { - auto last = std::remove(m_locations.begin(), m_locations.end(), location); - m_locations.erase(last, m_locations.end()); - } - /** * @brief Add an #InfectionState to the set of #InfectionState%s that are either allowed or required to be tested. * @param[in] infection_state #InfectionState to be added. @@ -137,12 +109,11 @@ class TestingCriteria /** * @brief Check if a Person and a Location meet all the required properties to get tested. * @param[in] p Person to be checked. - * @param[in] l Location to be checked. * @param[in] t TimePoint when to evaluate the TestingCriteria. */ - bool evaluate(const Person& p, const Location& l, TimePoint t) const + bool evaluate(const Person& p, TimePoint t) const { - return has_requested_age(p) && is_requested_location_type(l) && has_requested_infection_state(p, t); + return has_requested_age(p) && has_requested_infection_state(p, t); } private: @@ -158,18 +129,6 @@ class TestingCriteria return m_ages[(size_t)p.get_age()]; } - /** - * @brief Check if a Location is in the set of Location%s that are allowed for testing. - * @param[in] l Location to be checked. - */ - bool is_requested_location_type(const Location& l) const - { - if (m_locations.empty()) { - return true; // no condition on the location - } - return std::find(m_locations.begin(), m_locations.end(), l.get_type()) != m_locations.end(); - } - /** * @brief Check if a Person has the required InfectionState to get tested. * @param[in] p Person to be checked. @@ -185,8 +144,6 @@ class TestingCriteria std::bitset<(size_t)AgeGroup::Count> m_ages; ///< BitSet of #AgeGroup%s that are either allowed or required to be tested. - std::vector m_locations; /**< Set of #Location%s or #LocationState%s that are either allowed or required to be - tested.*/ std::bitset<(size_t)InfectionState::Count> m_infection_states; /**< BitSet of #InfectionState%s that are either allowed or required to be tested.*/ @@ -208,8 +165,8 @@ class TestingScheme * @param test_type The type of test to be performed. * @param probability Probability of the test to be performed if a testing rule applies. */ - TestingScheme(const TestingCriteria& testing_criteria, TimeSpan minimal_time_since_last_test, - TimePoint start_date, TimePoint end_date, const GenericTest& test_type, ScalarType probability); + TestingScheme(const TestingCriteria& testing_criteria, TimeSpan minimal_time_since_last_test, TimePoint start_date, + TimePoint end_date, const GenericTest& test_type, ScalarType probability); /** * @brief Compares two TestingScheme%s for functional equality. @@ -231,14 +188,13 @@ class TestingScheme /** * @brief Runs the TestingScheme and potentially tests a Person. * @param[in] person Person to check. - * @param[in] location Location to check. * @param[in] t TimePoint when to run the scheme. * @return If the person is allowed to enter the Location by the scheme. */ - bool run_scheme(Person& person, const Location& location, TimePoint t) const; + bool run_scheme(Person& person, TimePoint t) const; private: - TestingCriteria m_testing_criteria; ///< Vector with all TestingCriteria of the scheme. + TestingCriteria m_testing_criteria; ///< Vector with all TestingCriteria of the scheme. TimeSpan m_minimal_time_since_last_test; ///< Shortest period of time between two tests. TimePoint m_start_date; ///< Starting date of the scheme. TimePoint m_end_date; ///< Ending date of the scheme. @@ -258,19 +214,39 @@ class TestingStrategy * @param[in] testing_schemes Vector of TestingSchemes that are checked for testing. */ TestingStrategy() = default; - explicit TestingStrategy(const std::vector& testing_schemes); + explicit TestingStrategy(const std::map>& location_to_schemes_map); + + /** + * @brief Add a TestingScheme to the set of schemes that are checked for testing at a certain Location. + * @param[in] loc_id LocationId key for TestingScheme to be added. + * @param[in] scheme TestingScheme to be added. + */ + void add_testing_scheme(const LocationId& loc_id, const TestingScheme& scheme); /** - * @brief Add a TestingScheme to the set of schemes that are checked for testing. + * @brief Add a TestingScheme to the set of schemes that are checked for testing at a certain LocationType. + * A TestingScheme applies to all Location of the same type is store in + * LocationId{INVALID_LOCATION_INDEX, location_type} of m_location_to_schemes_map. + * @param[in] loc_type LocationId key for TestingScheme to be added. * @param[in] scheme TestingScheme to be added. */ - void add_testing_scheme(const TestingScheme& scheme); + void add_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme); + + /** + * @brief Remove a TestingScheme from the set of schemes that are checked for testing at a certain Location. + * @param[in] loc_id LocationId key for TestingScheme to be remove. + * @param[in] scheme TestingScheme to be removed. + */ + void remove_testing_scheme(const LocationId& loc_id, const TestingScheme& scheme); /** - * @brief Remove a TestingScheme from the set of schemes that are checked for testing. + * @brief Remove a TestingScheme from the set of schemes that are checked for testing at a certain Location. + * A TestingScheme applies to all Location of the same type is store in + * LocationId{INVALID_LOCATION_INDEX, location_type} of m_location_to_schemes_map. + * @param[in] loc_type LocationType key for TestingScheme to be remove. * @param[in] scheme TestingScheme to be removed. */ - void remove_testing_scheme(const TestingScheme& scheme); + void remove_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme); /** * @brief Checks if the given TimePoint is within the interval of start and end date of each TestingScheme and then @@ -286,10 +262,11 @@ class TestingStrategy * @param[in] t TimePoint when to run the strategy. * @return If the Person is allowed to enter the Location. */ - bool run_strategy(Person& person, const Location& location, TimePoint t) const; + bool run_strategy(Person& person, const Location& location, TimePoint t); private: - std::vector m_testing_schemes; ///< Set of schemes that are checked for testing. + std::map> + m_location_to_schemes_map; ///< Set of schemes that are checked for testing. }; } // namespace abm diff --git a/cpp/simulations/abm.cpp b/cpp/simulations/abm.cpp index f6414ac78f..c9812f43ab 100644 --- a/cpp/simulations/abm.cpp +++ b/cpp/simulations/abm.cpp @@ -308,8 +308,7 @@ void create_assign_locations(mio::abm::World& world) world.get_individualized_location(event).get_infection_parameters().set(100); world.get_individualized_location(event).set_capacity(100, 375); - std::vector test_at_social_event = {mio::abm::LocationType::SocialEvent}; - auto testing_criteria = mio::abm::TestingCriteria({}, test_at_social_event, {}); + auto testing_criteria = mio::abm::TestingCriteria(); auto testing_min_time = mio::abm::days(2); auto start_date = mio::abm::TimePoint(0); auto end_date = mio::abm::TimePoint(0) + mio::abm::days(60); @@ -321,7 +320,7 @@ void create_assign_locations(mio::abm::World& world) auto testing_scheme = mio::abm::TestingScheme(testing_criteria, testing_min_time, start_date, end_date, test_type, probability.draw_sample()); - world.get_testing_strategy().add_testing_scheme(testing_scheme); + world.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::SocialEvent, testing_scheme); // Add hospital and ICU with 5 maximum contacs. // For the number of agents in this example we assume a capacity of 584 persons (80 beds per 10000 residents in @@ -410,22 +409,21 @@ void create_assign_locations(mio::abm::World& world) } // add the testing schemes for school and work - auto test_at_school = std::vector{mio::abm::LocationType::School}; - auto testing_criteria_school = mio::abm::TestingCriteria({}, test_at_school, {}); + auto testing_criteria_school = mio::abm::TestingCriteria(); testing_min_time = mio::abm::days(7); auto testing_scheme_school = mio::abm::TestingScheme(testing_criteria_school, testing_min_time, start_date, end_date, test_type, probability.draw_sample()); - world.get_testing_strategy().add_testing_scheme(testing_scheme_school); + world.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::School, testing_scheme_school); auto test_at_work = std::vector{mio::abm::LocationType::Work}; - auto testing_criteria_work = mio::abm::TestingCriteria({}, test_at_work, {}); + auto testing_criteria_work = mio::abm::TestingCriteria(); assign_uniform_distribution(probability, 0.1, 0.5); testing_min_time = mio::abm::days(1); auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, testing_min_time, start_date, end_date, test_type, probability.draw_sample()); - world.get_testing_strategy().add_testing_scheme(testing_scheme_work); + world.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme_work); } /** diff --git a/cpp/tests/test_abm_testing_strategy.cpp b/cpp/tests/test_abm_testing_strategy.cpp index 87bb4b9f32..84dbb85b9d 100644 --- a/cpp/tests/test_abm_testing_strategy.cpp +++ b/cpp/tests/test_abm_testing_strategy.cpp @@ -26,33 +26,24 @@ TEST(TestTestingCriteria, addRemoveAndEvaluateTestCriteria) auto person = make_test_person(home, mio::abm::AgeGroup::Age15to34, mio::abm::InfectionState::InfectedSymptoms); mio::abm::TimePoint t{0}; - auto testing_criteria = mio::abm::TestingCriteria(); - ASSERT_EQ(testing_criteria.evaluate(person, work, t), true); + auto testing_criteria = mio::abm::TestingCriteria(); + ASSERT_EQ(testing_criteria.evaluate(person, t), true); testing_criteria.add_infection_state(mio::abm::InfectionState::InfectedSymptoms); testing_criteria.add_infection_state(mio::abm::InfectionState::InfectedNoSymptoms); - testing_criteria.add_location(mio::abm::LocationType::Home); - testing_criteria.add_location(mio::abm::LocationType::Work); - - ASSERT_EQ(testing_criteria.evaluate(person, work, t), true); - ASSERT_EQ(testing_criteria.evaluate(person, home, t), true); testing_criteria.add_age_group(mio::abm::AgeGroup::Age35to59); - ASSERT_EQ(testing_criteria.evaluate(person, home, t), + ASSERT_EQ(testing_criteria.evaluate(person, t), false); // now it isn't empty and get's evaluated against age group testing_criteria.remove_age_group(mio::abm::AgeGroup::Age35to59); - ASSERT_EQ(testing_criteria.evaluate(person, home, t), true); + ASSERT_EQ(testing_criteria.evaluate(person, t), true); testing_criteria.remove_infection_state(mio::abm::InfectionState::InfectedSymptoms); - ASSERT_EQ(testing_criteria.evaluate(person, home, t), false); - + ASSERT_EQ(testing_criteria.evaluate(person, t), false); testing_criteria.add_infection_state(mio::abm::InfectionState::InfectedSymptoms); - testing_criteria.remove_location(mio::abm::LocationType::Home); - ASSERT_EQ(testing_criteria.evaluate(person, home, t), false); - auto testing_criteria_manual = mio::abm::TestingCriteria( - {}, std::vector({mio::abm::LocationType::Work}), - std::vector( - {mio::abm::InfectionState::InfectedNoSymptoms, mio::abm::InfectionState::InfectedSymptoms})); + auto testing_criteria_manual = mio::abm::TestingCriteria( + {}, std::vector( + {mio::abm::InfectionState::InfectedNoSymptoms, mio::abm::InfectionState::InfectedSymptoms})); ASSERT_EQ(testing_criteria == testing_criteria_manual, true); testing_criteria_manual.remove_infection_state(mio::abm::InfectionState::InfectedSymptoms); ASSERT_EQ(testing_criteria == testing_criteria_manual, false); @@ -65,8 +56,7 @@ TEST(TestTestingScheme, runScheme) std::vector test_location_types1 = {mio::abm::LocationType::Home, mio::abm::LocationType::Work}; - auto testing_criteria1 = - mio::abm::TestingCriteria({}, test_location_types1, test_infection_states1); + auto testing_criteria1 = mio::abm::TestingCriteria({}, test_infection_states1); const auto testing_min_time = mio::abm::days(1); const auto start_date = mio::abm::TimePoint(0); @@ -85,9 +75,7 @@ TEST(TestTestingScheme, runScheme) testing_scheme1.update_activity_status(mio::abm::TimePoint(0)); std::vector test_infection_states2 = {mio::abm::InfectionState::Recovered}; - std::vector test_location_types2 = {mio::abm::LocationType::Home}; - auto testing_criteria2 = - mio::abm::TestingCriteria({}, test_location_types2, test_infection_states2); + auto testing_criteria2 = mio::abm::TestingCriteria({}, test_infection_states2); auto testing_scheme2 = mio::abm::TestingScheme(testing_criteria2, testing_min_time, start_date, end_date, test_type, probability); @@ -99,14 +87,14 @@ TEST(TestTestingScheme, runScheme) ScopedMockDistribution>>> mock_uniform_dist; EXPECT_CALL(mock_uniform_dist.get_mock(), invoke) - .Times(testing::Exactly(4)) + .Times(testing::Exactly(5)) .WillOnce(testing::Return(0.7)) .WillOnce(testing::Return(0.5)) .WillOnce(testing::Return(0.7)) - .WillOnce(testing::Return(0.5)); - ASSERT_EQ(testing_scheme1.run_scheme(person1, loc_home, start_date), false); // Person tests and tests positive - ASSERT_EQ(testing_scheme2.run_scheme(person2, loc_work, start_date), true); // Person tests and tests negative - ASSERT_EQ(testing_scheme1.run_scheme(person1, loc_home, start_date), - false); // Person is in quarantine and wants to go home -> can do so - ASSERT_EQ(testing_scheme2.run_scheme(person1, loc_work, start_date), true); // Person doesn't test + .WillOnce(testing::Return(0.5)) + .WillOnce(testing::Return(0.9)); + ASSERT_EQ(testing_scheme1.run_scheme(person1, start_date), false); // Person tests and tests positive + ASSERT_EQ(testing_scheme2.run_scheme(person2, start_date), true); // Person tests and tests negative + ASSERT_EQ(testing_scheme1.run_scheme(person1, start_date), + true); // Person doesn't test } diff --git a/cpp/tests/test_abm_world.cpp b/cpp/tests/test_abm_world.cpp index 2fc6370d21..e1f25b7ed6 100644 --- a/cpp/tests/test_abm_world.cpp +++ b/cpp/tests/test_abm_world.cpp @@ -361,10 +361,9 @@ TEST(TestWorldTestingCriteria, testAddingAndUpdatingAndRunningTestingSchemes) person.set_assigned_location(home); person.set_assigned_location(work); - auto testing_criteria = mio::abm::TestingCriteria({}, {mio::abm::LocationType::Home}, {}); + auto testing_criteria = mio::abm::TestingCriteria(); testing_criteria.add_infection_state(mio::abm::InfectionState::InfectedSymptoms); testing_criteria.add_infection_state(mio::abm::InfectionState::InfectedNoSymptoms); - testing_criteria.add_location(mio::abm::LocationType::Work); const auto testing_frequency = mio::abm::days(1); const auto start_date = mio::abm::TimePoint(20); @@ -373,9 +372,9 @@ TEST(TestWorldTestingCriteria, testAddingAndUpdatingAndRunningTestingSchemes) const auto test_type = mio::abm::PCRTest(); auto testing_scheme = - mio::abm::TestingScheme({testing_criteria}, testing_frequency, start_date, end_date, test_type, probability); + mio::abm::TestingScheme(testing_criteria, testing_frequency, start_date, end_date, test_type, probability); - world.get_testing_strategy().add_testing_scheme(testing_scheme); + world.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme); ASSERT_EQ(world.get_testing_strategy().run_strategy(person, work, current_time), true); // no active testing scheme -> person can enter current_time = mio::abm::TimePoint(30); @@ -387,7 +386,8 @@ TEST(TestWorldTestingCriteria, testAddingAndUpdatingAndRunningTestingSchemes) .WillOnce(testing::Return(0.4)); ASSERT_EQ(world.get_testing_strategy().run_strategy(person, work, current_time), false); - world.get_testing_strategy().add_testing_scheme(testing_scheme); //doesn't get added because of == operator - world.get_testing_strategy().remove_testing_scheme(testing_scheme); + world.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, + testing_scheme); //doesn't get added because of == operator + world.get_testing_strategy().remove_testing_scheme(mio::abm::LocationType::Work, testing_scheme); ASSERT_EQ(world.get_testing_strategy().run_strategy(person, work, current_time), true); // no more testing_schemes } From 120232e53582b33aa1cdd28674ef8df3ad21fb44 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Fri, 22 Sep 2023 14:27:19 +0200 Subject: [PATCH 11/18] Fix errors in pycode --- cpp/models/abm/testing_strategy.cpp | 66 +++++++++++++++++-- cpp/models/abm/testing_strategy.h | 59 +++-------------- .../memilio/simulation/abm.cpp | 13 ++-- .../memilio/simulation_test/test_abm.py | 3 +- 4 files changed, 78 insertions(+), 63 deletions(-) diff --git a/cpp/models/abm/testing_strategy.cpp b/cpp/models/abm/testing_strategy.cpp index 563c904864..85d54d9c79 100644 --- a/cpp/models/abm/testing_strategy.cpp +++ b/cpp/models/abm/testing_strategy.cpp @@ -27,6 +27,64 @@ namespace mio namespace abm { +TestingCriteria::TestingCriteria(const std::vector& ages, const std::vector& infection_states) +{ + for (auto age : ages) { + m_ages.set((size_t)age, true); + } + for (auto infection_state : infection_states) { + m_infection_states.set((size_t)infection_state, true); + } +} + +bool TestingCriteria::operator==(TestingCriteria other) const +{ + auto to_compare_ages = this->m_ages; + auto to_compare_infection_states = this->m_infection_states; + return to_compare_ages == other.m_ages && to_compare_infection_states == other.m_infection_states; +} + +void TestingCriteria::add_age_group(const AgeGroup age_group) +{ + m_ages.set((size_t)age_group, true); +} + +void TestingCriteria::remove_age_group(const AgeGroup age_group) +{ + m_ages.set((size_t)age_group, false); +} + +void TestingCriteria::add_infection_state(const InfectionState infection_state) +{ + m_infection_states.set((size_t)infection_state, true); +} + +void TestingCriteria::remove_infection_state(const InfectionState infection_state) +{ + m_infection_states.set((size_t)infection_state, false); +} + +bool TestingCriteria::evaluate(const Person& p, TimePoint t) const +{ + return has_requested_age(p) && has_requested_infection_state(p, t); +} + +bool TestingCriteria::has_requested_age(const Person& p) const +{ + if (m_ages.none()) { + return true; // no condition on the AgeGroup + } + return m_ages[(size_t)p.get_age()]; +} + +bool TestingCriteria::has_requested_infection_state(const Person& p, TimePoint t) const +{ + if (m_infection_states.none()) { + return true; // no condition on the InfectionState + } + return m_infection_states[(size_t)p.get_infection_state(t)]; +} + TestingScheme::TestingScheme(const TestingCriteria& testing_criteria, TimeSpan minimal_time_since_last_test, TimePoint start_date, TimePoint end_date, const GenericTest& test_type, double probability) : m_testing_criteria(testing_criteria) @@ -79,7 +137,7 @@ TestingStrategy::TestingStrategy(const std::map& ages, const std::vector& infection_states) - - { - for (auto age : ages) { - m_ages.set((size_t)age, true); - } - for (auto infection_state : infection_states) { - m_infection_states.set((size_t)infection_state, true); - } - } + TestingCriteria(const std::vector& ages, const std::vector& infection_states); /** * @brief Compares two TestingCriteria for functional equality. */ - bool operator==(TestingCriteria other) const - { - auto to_compare_ages = this->m_ages; - auto to_compare_infection_states = this->m_infection_states; - return to_compare_ages == other.m_ages && to_compare_infection_states == other.m_infection_states; - } + bool operator==(TestingCriteria other) const; /** * @brief Add an AgeGroup to the set of AgeGroup%s that are either allowed or required to be tested. * @param[in] age_group AgeGroup to be added. */ - void add_age_group(const AgeGroup age_group) - { - m_ages.set((size_t)age_group, true); - } + void add_age_group(const AgeGroup age_group); /** * @brief Remove an AgeGroup from the set of AgeGroup%s that are either allowed or required to be tested. * @param[in] age_group AgeGroup to be removed. */ - void remove_age_group(const AgeGroup age_group) - { - m_ages.set((size_t)age_group, false); - } + void remove_age_group(const AgeGroup age_group); /** * @brief Add an #InfectionState to the set of #InfectionState%s that are either allowed or required to be tested. * @param[in] infection_state #InfectionState to be added. */ - void add_infection_state(const InfectionState infection_state) - { - m_infection_states.set((size_t)infection_state, true); - } + void add_infection_state(const InfectionState infection_state); /** * @brief Remove an #InfectionState from the set of #InfectionState%s that are either allowed or required to be * tested. * @param[in] infection_state #InfectionState to be removed. */ - void remove_infection_state(const InfectionState infection_state) - { - m_infection_states.set((size_t)infection_state, false); - } + void remove_infection_state(const InfectionState infection_state); /** * @brief Check if a Person and a Location meet all the required properties to get tested. * @param[in] p Person to be checked. * @param[in] t TimePoint when to evaluate the TestingCriteria. */ - bool evaluate(const Person& p, TimePoint t) const - { - return has_requested_age(p) && has_requested_infection_state(p, t); - } + bool evaluate(const Person& p, TimePoint t) const; private: /** * @brief Check if a Person has the required age to get tested. * @param[in] p Person to be checked. */ - bool has_requested_age(const Person& p) const - { - if (m_ages.none()) { - return true; // no condition on the AgeGroup - } - return m_ages[(size_t)p.get_age()]; - } + bool has_requested_age(const Person& p) const; /** * @brief Check if a Person has the required InfectionState to get tested. * @param[in] p Person to be checked. * @param[in] t TimePoint when to check. */ - bool has_requested_infection_state(const Person& p, TimePoint t) const - { - if (m_infection_states.none()) { - return true; // no condition on the InfectionState - } - return m_infection_states[(size_t)p.get_infection_state(t)]; - } + bool has_requested_infection_state(const Person& p, TimePoint t) const; std::bitset<(size_t)AgeGroup::Count> m_ages; ///< BitSet of #AgeGroup%s that are either allowed or required to be tested. diff --git a/pycode/memilio-simulation/memilio/simulation/abm.cpp b/pycode/memilio-simulation/memilio/simulation/abm.cpp index 0cab84dd40..d98e10268b 100644 --- a/pycode/memilio-simulation/memilio/simulation/abm.cpp +++ b/pycode/memilio-simulation/memilio/simulation/abm.cpp @@ -143,18 +143,17 @@ PYBIND11_MODULE(_simulation_abm, m) .def_property_readonly("age", &mio::abm::Person::get_age) .def_property_readonly("is_in_quarantine", &mio::abm::Person::is_in_quarantine); - py::class_>(m, "TestingCriteria") - .def(py::init&, const std::vector&, - const std::vector&>(), - py::arg("age_groups"), py::arg("location_types"), py::arg("infection_states")); + py::class_(m, "TestingCriteria") + .def(py::init&, const std::vector&>(), + py::arg("age_groups"), py::arg("infection_states")); py::class_(m, "GenericTest").def(py::init<>()); py::class_(m, "AntigenTest").def(py::init<>()); py::class_(m, "PCRTest").def(py::init<>()); py::class_(m, "TestingScheme") - .def(py::init&, mio::abm::TimeSpan, mio::abm::TimePoint, - mio::abm::TimePoint, const mio::abm::GenericTest&, double>(), + .def(py::init(), py::arg("testing_criteria"), py::arg("testing_min_time_since_last_test"), py::arg("start_date"), py::arg("end_date"), py::arg("test_type"), py::arg("probability")) .def_property_readonly("active", &mio::abm::TestingScheme::is_active); @@ -165,7 +164,7 @@ PYBIND11_MODULE(_simulation_abm, m) .def_readwrite("time", &mio::abm::Vaccination::time); py::class_(m, "TestingStrategy") - .def(py::init&>()); + .def(py::init>&>()); py::class_(m, "Location") .def_property_readonly("type", &mio::abm::Location::get_type) diff --git a/pycode/memilio-simulation/memilio/simulation_test/test_abm.py b/pycode/memilio-simulation/memilio/simulation_test/test_abm.py index fd792d4b69..005ad7dd73 100644 --- a/pycode/memilio-simulation/memilio/simulation_test/test_abm.py +++ b/pycode/memilio-simulation/memilio/simulation_test/test_abm.py @@ -52,10 +52,9 @@ def test_locations(self): home.infection_parameters.MaximumContacts = 10 self.assertEqual(home.infection_parameters.MaximumContacts, 10) - testing_locations = [abm.LocationType.Home] testing_inf_states = [] testing_crit = abm.TestingCriteria( - testing_ages, testing_locations, testing_inf_states) + testing_ages, testing_inf_states) testing_scheme = abm.TestingScheme(testing_crit, abm.days( 1), t0, t0 + abm.days(1), abm.AntigenTest(), 1.0) # initially false, will only active once simulation starts From 898dee9d5a0ed74d28a2c2514942036e02305388 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Wed, 27 Sep 2023 16:13:11 +0200 Subject: [PATCH 12/18] Use unordered_map in TestingStrategy and avoid copy in run_strategy --- cpp/models/abm/location_type.h | 12 ++++++++++++ cpp/models/abm/testing_strategy.cpp | 27 ++++++++++++++------------- cpp/models/abm/testing_strategy.h | 4 ++-- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/cpp/models/abm/location_type.h b/cpp/models/abm/location_type.h index fa18cc6e9c..463ef96fbf 100644 --- a/cpp/models/abm/location_type.h +++ b/cpp/models/abm/location_type.h @@ -22,6 +22,7 @@ #include #include +#include namespace mio { @@ -80,4 +81,15 @@ struct LocationId { } // namespace abm } // namespace mio +namespace std +{ +template <> +struct hash { + std::size_t operator()(const mio::abm::LocationId& loc_id) const + { + return (hash()(loc_id.index)) ^ (hash()(static_cast(loc_id.type))); + } +}; +} // namespace std + #endif diff --git a/cpp/models/abm/testing_strategy.cpp b/cpp/models/abm/testing_strategy.cpp index 85d54d9c79..9bfe6848fe 100644 --- a/cpp/models/abm/testing_strategy.cpp +++ b/cpp/models/abm/testing_strategy.cpp @@ -130,7 +130,8 @@ bool TestingScheme::run_scheme(Person& person, TimePoint t) const return true; } -TestingStrategy::TestingStrategy(const std::map>& location_to_schemes_map) +TestingStrategy::TestingStrategy( + const std::unordered_map>& location_to_schemes_map) : m_location_to_schemes_map(location_to_schemes_map) { } @@ -178,19 +179,19 @@ bool TestingStrategy::run_strategy(Person& person, const Location& location, Tim return true; } - auto schemes_location_vector = m_location_to_schemes_map[LocationId{location.get_index(), location.get_type()}]; - auto schemes_location_type_vector = - m_location_to_schemes_map[LocationId{INVALID_LOCATION_INDEX, location.get_type()}]; - schemes_location_vector.insert(schemes_location_vector.end(), schemes_location_type_vector.begin(), - schemes_location_type_vector.end()); - std::cout << schemes_location_vector.size() << '\n'; - return std::all_of(schemes_location_vector.begin(), schemes_location_vector.end(), [&person, t](TestingScheme ts) { - if (ts.is_active()) { - std::cout << "run scheme" << '\n'; - return ts.run_scheme(person, t); + // Combine two vectors of schemes at corresponding location and location stype + std::vector*> schemes_vector = { + &m_location_to_schemes_map[LocationId{location.get_index(), location.get_type()}], + &m_location_to_schemes_map[LocationId{INVALID_LOCATION_INDEX, location.get_type()}]}; + + for (auto vec_ptr : schemes_vector) { + if (!std::all_of(vec_ptr->begin(), vec_ptr->end(), [&person, t](TestingScheme& ts) { + return !ts.is_active() || ts.run_scheme(person, t); + })) { + return false; } - return true; - }); + } + return true; } } // namespace abm diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index 2f7979f728..3a6664992a 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -173,7 +173,7 @@ class TestingStrategy * @param[in] testing_schemes Vector of TestingSchemes that are checked for testing. */ TestingStrategy() = default; - explicit TestingStrategy(const std::map>& location_to_schemes_map); + explicit TestingStrategy(const std::unordered_map>& location_to_schemes_map); /** * @brief Add a TestingScheme to the set of schemes that are checked for testing at a certain Location. @@ -224,7 +224,7 @@ class TestingStrategy bool run_strategy(Person& person, const Location& location, TimePoint t); private: - std::map> + std::unordered_map> m_location_to_schemes_map; ///< Set of schemes that are checked for testing. }; From 6df42b322e2841ef3000aa84cc9937259d5fa271 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:33:15 +0200 Subject: [PATCH 13/18] Optimise testing_strategy --- cpp/models/abm/testing_strategy.cpp | 40 +++++++------------------ cpp/models/abm/testing_strategy.h | 15 +--------- cpp/models/abm/world.h | 7 +++-- cpp/tests/test_abm_testing_strategy.cpp | 9 +++--- 4 files changed, 21 insertions(+), 50 deletions(-) diff --git a/cpp/models/abm/testing_strategy.cpp b/cpp/models/abm/testing_strategy.cpp index 5e09ac6bc8..c1f0c279f3 100644 --- a/cpp/models/abm/testing_strategy.cpp +++ b/cpp/models/abm/testing_strategy.cpp @@ -20,7 +20,6 @@ #include "abm/testing_strategy.h" #include "memilio/utils/random_number_generator.h" -#include namespace mio { @@ -30,59 +29,42 @@ namespace abm TestingCriteria::TestingCriteria(const std::vector& ages, const std::vector& infection_states) { for (auto age : ages) { - m_ages.set((size_t)age, true); + m_ages.set(static_cast(age), true); } for (auto infection_state : infection_states) { - m_infection_states.set((size_t)infection_state, true); + m_infection_states.set(static_cast(infection_state), true); } } -bool TestingCriteria::operator==(TestingCriteria other) const +bool TestingCriteria::operator==(const TestingCriteria& other) const { - auto to_compare_ages = this->m_ages; - auto to_compare_infection_states = this->m_infection_states; - return to_compare_ages == other.m_ages && to_compare_infection_states == other.m_infection_states; + return m_ages == other.m_ages && m_infection_states == other.m_infection_states; } void TestingCriteria::add_age_group(const AgeGroup age_group) { - m_ages.set((size_t)age_group, true); + m_ages.set(static_cast(age_group), true); } void TestingCriteria::remove_age_group(const AgeGroup age_group) { - m_ages.set((size_t)age_group, false); + m_ages.set(static_cast(age_group), false); } void TestingCriteria::add_infection_state(const InfectionState infection_state) { - m_infection_states.set((size_t)infection_state, true); + m_infection_states.set(static_cast(infection_state), true); } void TestingCriteria::remove_infection_state(const InfectionState infection_state) { - m_infection_states.set((size_t)infection_state, false); + m_infection_states.set(static_cast(infection_state), false); } bool TestingCriteria::evaluate(const Person& p, TimePoint t) const { - return has_requested_age(p) && has_requested_infection_state(p, t); -} - -bool TestingCriteria::has_requested_age(const Person& p) const -{ - if (m_ages.none()) { - return true; // no condition on the AgeGroup - } - return m_ages[(size_t)p.get_age()]; -} - -bool TestingCriteria::has_requested_infection_state(const Person& p, TimePoint t) const -{ - if (m_infection_states.none()) { - return true; // no condition on the InfectionState - } - return m_infection_states[(size_t)p.get_infection_state(t)]; + return (m_ages.none() || m_ages[static_cast(p.get_age())]) && + (m_infection_states.none() || m_infection_states[static_cast(p.get_infection_state(t))]); } TestingScheme::TestingScheme(const TestingCriteria& testing_criteria, TimeSpan minimal_time_since_last_test, @@ -140,7 +122,7 @@ void TestingStrategy::add_testing_scheme(const LocationId& loc_id, const Testing { auto& schemes_vector = m_location_to_schemes_map[loc_id]; if (std::find(schemes_vector.begin(), schemes_vector.end(), scheme) == schemes_vector.end()) { - schemes_vector.push_back(scheme); + schemes_vector.emplace_back(scheme); } } diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index ce76b011c2..fadf2cd201 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -54,7 +54,7 @@ class TestingCriteria /** * @brief Compares two TestingCriteria for functional equality. */ - bool operator==(TestingCriteria other) const; + bool operator==(const TestingCriteria& other) const; /** * @brief Add an AgeGroup to the set of AgeGroup%s that are either allowed or required to be tested. @@ -89,19 +89,6 @@ class TestingCriteria bool evaluate(const Person& p, TimePoint t) const; private: - /** - * @brief Check if a Person has the required age to get tested. - * @param[in] p Person to be checked. - */ - bool has_requested_age(const Person& p) const; - - /** - * @brief Check if a Person has the required InfectionState to get tested. - * @param[in] p Person to be checked. - * @param[in] t TimePoint when to check. - */ - bool has_requested_infection_state(const Person& p, TimePoint t) const; - std::bitset<(size_t)AgeGroup::Count> m_ages; ///< BitSet of #AgeGroup%s that are either allowed or required to be tested. std::bitset<(size_t)InfectionState::Count> diff --git a/cpp/models/abm/world.h b/cpp/models/abm/world.h index 901582acde..a10a78b671 100644 --- a/cpp/models/abm/world.h +++ b/cpp/models/abm/world.h @@ -55,7 +55,8 @@ class World * @param[in] params Parameters of the Infection that are the same everywhere in the World. */ World(const GlobalInfectionParameters& params = {}) - : m_infection_parameters(params) + : m_testing_strategy() + , m_infection_parameters(params) , m_migration_parameters() , m_trip_list() , m_cemetery_id(add_location(LocationType::Cemetery)) @@ -64,10 +65,10 @@ class World } //type is move-only for stable references of persons/locations - World(World&& other) = default; + World(World&& other) = default; World& operator=(World&& other) = default; World(const World&) = delete; - World& operator=(const World&) = delete; + World& operator=(const World&) = delete; /** * @brief Prepare the World for the next Simulation step. diff --git a/cpp/tests/test_abm_testing_strategy.cpp b/cpp/tests/test_abm_testing_strategy.cpp index 361dbf451f..9de93aa931 100644 --- a/cpp/tests/test_abm_testing_strategy.cpp +++ b/cpp/tests/test_abm_testing_strategy.cpp @@ -44,11 +44,12 @@ TEST(TestTestingCriteria, addRemoveAndEvaluateTestCriteria) testing_criteria.add_infection_state(mio::abm::InfectionState::InfectedSymptoms); auto testing_criteria_manual = mio::abm::TestingCriteria( - {}, std::vector( - {mio::abm::InfectionState::InfectedNoSymptoms, mio::abm::InfectionState::InfectedSymptoms})); - ASSERT_EQ(testing_criteria == testing_criteria_manual, true); - testing_criteria_manual.remove_infection_state(mio::abm::InfectionState::InfectedSymptoms); + std::vector({mio::abm::AgeGroup::Age15to34}), + std::vector({mio::abm::InfectionState::InfectedNoSymptoms})); ASSERT_EQ(testing_criteria == testing_criteria_manual, false); + testing_criteria_manual.add_infection_state(mio::abm::InfectionState::InfectedSymptoms); + testing_criteria_manual.remove_age_group(mio::abm::AgeGroup::Age15to34); + ASSERT_EQ(testing_criteria == testing_criteria_manual, true); } TEST(TestTestingScheme, runScheme) From 668fb6bc965df41470f1f06d3e7b597896b48bc5 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Thu, 19 Oct 2023 12:06:08 +0200 Subject: [PATCH 14/18] Add test for initialising TestStrategy --- cpp/models/abm/testing_strategy.cpp | 2 +- cpp/models/abm/world.h | 3 +- cpp/tests/test_abm_testing_strategy.cpp | 60 ++++++++++++++++++++----- 3 files changed, 52 insertions(+), 13 deletions(-) diff --git a/cpp/models/abm/testing_strategy.cpp b/cpp/models/abm/testing_strategy.cpp index c1f0c279f3..8e4826f7ca 100644 --- a/cpp/models/abm/testing_strategy.cpp +++ b/cpp/models/abm/testing_strategy.cpp @@ -163,7 +163,7 @@ bool TestingStrategy::run_strategy(Person::RandomNumberGenerator& rng, Person& p } // Combine two vectors of schemes at corresponding location and location stype - std::vector*> schemes_vector = { + std::vector* schemes_vector[] = { &m_location_to_schemes_map[LocationId{location.get_index(), location.get_type()}], &m_location_to_schemes_map[LocationId{INVALID_LOCATION_INDEX, location.get_type()}]}; diff --git a/cpp/models/abm/world.h b/cpp/models/abm/world.h index a10a78b671..e1e52366d3 100644 --- a/cpp/models/abm/world.h +++ b/cpp/models/abm/world.h @@ -55,8 +55,7 @@ class World * @param[in] params Parameters of the Infection that are the same everywhere in the World. */ World(const GlobalInfectionParameters& params = {}) - : m_testing_strategy() - , m_infection_parameters(params) + : m_infection_parameters(params) , m_migration_parameters() , m_trip_list() , m_cemetery_id(add_location(LocationType::Cemetery)) diff --git a/cpp/tests/test_abm_testing_strategy.cpp b/cpp/tests/test_abm_testing_strategy.cpp index 9de93aa931..a07d276f14 100644 --- a/cpp/tests/test_abm_testing_strategy.cpp +++ b/cpp/tests/test_abm_testing_strategy.cpp @@ -54,21 +54,17 @@ TEST(TestTestingCriteria, addRemoveAndEvaluateTestCriteria) TEST(TestTestingScheme, runScheme) { - auto rng = mio::RandomNumberGenerator(); - - std::vector test_infection_states1 = {mio::abm::InfectionState::InfectedSymptoms, - mio::abm::InfectionState::InfectedNoSymptoms}; - std::vector test_location_types1 = {mio::abm::LocationType::Home, - mio::abm::LocationType::Work}; - - auto testing_criteria1 = mio::abm::TestingCriteria({}, test_infection_states1); - + auto rng = mio::RandomNumberGenerator(); const auto testing_min_time = mio::abm::days(1); const auto start_date = mio::abm::TimePoint(0); const auto end_date = mio::abm::TimePoint(60 * 60 * 24 * 3); const auto probability = 0.8; const auto test_type = mio::abm::PCRTest(); + std::vector test_infection_states = {mio::abm::InfectionState::InfectedSymptoms, + mio::abm::InfectionState::InfectedNoSymptoms}; + + auto testing_criteria1 = mio::abm::TestingCriteria({}, test_infection_states); auto testing_scheme1 = mio::abm::TestingScheme(testing_criteria1, testing_min_time, start_date, end_date, test_type, probability); @@ -85,7 +81,6 @@ TEST(TestTestingScheme, runScheme) mio::abm::TestingScheme(testing_criteria2, testing_min_time, start_date, end_date, test_type, probability); mio::abm::Location loc_home(mio::abm::LocationType::Home, 0); - mio::abm::Location loc_work(mio::abm::LocationType::Work, 0); auto person1 = make_test_person(loc_home, mio::abm::AgeGroup::Age15to34, mio::abm::InfectionState::InfectedNoSymptoms); auto rng_person1 = mio::abm::Person::RandomNumberGenerator(rng, person1); @@ -105,3 +100,48 @@ TEST(TestTestingScheme, runScheme) ASSERT_EQ(testing_scheme1.run_scheme(rng_person1, person1, start_date), true); // Person doesn't test } + +TEST(TestTestingScheme, initAndRunTestingStrategy) +{ + auto rng = mio::RandomNumberGenerator(); + const auto testing_min_time = mio::abm::days(1); + const auto start_date = mio::abm::TimePoint(0); + const auto end_date = mio::abm::TimePoint(60 * 60 * 24 * 3); + const auto probability = 0.8; + const auto test_type = mio::abm::PCRTest(); + + std::vector test_infection_states = {mio::abm::InfectionState::InfectedSymptoms, + mio::abm::InfectionState::InfectedNoSymptoms}; + auto testing_criteria1 = mio::abm::TestingCriteria({}, test_infection_states); + auto testing_scheme1 = + mio::abm::TestingScheme(testing_criteria1, testing_min_time, start_date, end_date, test_type, probability); + testing_scheme1.update_activity_status(mio::abm::TimePoint(0)); + std::vector test_infection_states2 = {mio::abm::InfectionState::Recovered}; + auto testing_criteria2 = mio::abm::TestingCriteria({}, test_infection_states2); + auto testing_scheme2 = + mio::abm::TestingScheme(testing_criteria2, testing_min_time, start_date, end_date, test_type, probability); + + mio::abm::Location loc_home(mio::abm::LocationType::Home, 0); + auto person1 = + make_test_person(loc_home, mio::abm::AgeGroup::Age15to34, mio::abm::InfectionState::InfectedNoSymptoms); + auto rng_person1 = mio::abm::Person::RandomNumberGenerator(rng, person1); + auto person2 = make_test_person(loc_home, mio::abm::AgeGroup::Age15to34, mio::abm::InfectionState::Recovered); + auto rng_person2 = mio::abm::Person::RandomNumberGenerator(rng, person2); + + ScopedMockDistribution>>> mock_uniform_dist; + EXPECT_CALL(mock_uniform_dist.get_mock(), invoke) + .Times(testing::Exactly(3)) + .WillOnce(testing::Return(0.7)) + .WillOnce(testing::Return(0.5)) + .WillOnce(testing::Return(0.9)); + + mio::abm::TestingStrategy test_strategy = + mio::abm::TestingStrategy(std::unordered_map>()); + test_strategy.add_testing_scheme(mio::abm::LocationType::Home, testing_scheme1); + test_strategy.add_testing_scheme(mio::abm::LocationType::Home, testing_scheme2); + ASSERT_EQ(test_strategy.run_strategy(rng_person1, person1, loc_home, start_date), + false); // Person tests and tests positive + ASSERT_EQ(test_strategy.run_strategy(rng_person2, person2, loc_home, start_date), + true); // Person tests and tests negative + ASSERT_EQ(test_strategy.run_strategy(rng_person1, person1, loc_home, start_date), true); // Person doesn't test +} From fd5b5777c2118309f38c6651c1ac9d64beb52630 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Mon, 23 Oct 2023 11:51:34 +0200 Subject: [PATCH 15/18] Fix errors in using bitset in TestingStrategy --- cpp/models/abm/analyze_result 2.h | 206 ++++++++++++++++++++++++ cpp/models/abm/testing_strategy.cpp | 8 +- cpp/models/abm/testing_strategy.h | 4 +- cpp/tests/test_abm_testing_strategy.cpp | 12 +- 4 files changed, 217 insertions(+), 13 deletions(-) create mode 100644 cpp/models/abm/analyze_result 2.h diff --git a/cpp/models/abm/analyze_result 2.h b/cpp/models/abm/analyze_result 2.h new file mode 100644 index 0000000000..0e93be84a9 --- /dev/null +++ b/cpp/models/abm/analyze_result 2.h @@ -0,0 +1,206 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: Daniel Abele, Khoa Nguyen +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef ABM_ANALYZE_RESULT_H +#define ABM_ANALYZE_RESULT_H + +#include "abm/simulation.h" +#include "abm/parameters.h" +#include "memilio/data/analyze_result.h" + +#include +#include + +namespace mio +{ +namespace abm +{ +/** + * @brief computes the p percentile of the parameters for each node. + * @param ensemble_result graph of multiple simulation runs + * @param p percentile value in open interval (0, 1) + * @return p percentile of the parameters over all runs + */ +template +std::vector ensemble_params_percentile(const std::vector>& ensemble_params, double p) +{ + assert(p > 0.0 && p < 1.0 && "Invalid percentile value."); + + auto num_runs = ensemble_params.size(); + auto num_nodes = ensemble_params[0].size(); + std::vector single_element_ensemble(num_runs); + auto num_groups = (size_t)ensemble_params[0][0].parameters.get_num_groups(); + + // Lambda function that calculates the percentile of a single parameter + std::vector percentile(num_nodes, Model((int)num_groups)); + auto param_percentil = [&ensemble_params, p, num_runs, &percentile](auto n, auto get_param) mutable { + std::vector single_element(num_runs); + for (size_t run = 0; run < num_runs; run++) { + auto const& params = ensemble_params[run][n]; + single_element[run] = get_param(params); + } + std::sort(single_element.begin(), single_element.end()); + auto& new_params = get_param(percentile[n]); + new_params = single_element[static_cast(num_runs * p)]; + }; + + for (size_t node = 0; node < num_nodes; node++) { + for (auto age_group = AgeGroup(0); age_group < AgeGroup(num_groups); age_group++) { + for (auto virus_variant = VirusVariant(0); virus_variant < VirusVariant::Count; + virus_variant = static_cast((uint32_t)virus_variant + 1)) { + // Global infection parameters + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + static auto result = + model.parameters.template get()[{virus_variant, age_group}] + .viral_load_incline.params.a(); + return result; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + static auto result = + model.parameters.template get()[{virus_variant, age_group}] + .viral_load_incline.params.b(); + return result; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + static auto result = + model.parameters.template get()[{virus_variant, age_group}] + .viral_load_decline.params.a(); + return result; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + static auto result = + model.parameters.template get()[{virus_variant, age_group}] + .viral_load_decline.params.b(); + return result; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + static auto result = + model.parameters.template get()[{virus_variant, age_group}] + .viral_load_peak.params.a(); + return result; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + static auto result = + model.parameters.template get()[{virus_variant, age_group}] + .viral_load_peak.params.b(); + return result; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + static auto result = + model.parameters.template get()[{virus_variant, age_group}] + .infectivity_alpha.params.a(); + return result; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + static auto result = + model.parameters.template get()[{virus_variant, age_group}] + .infectivity_alpha.params.b(); + return result; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + static auto result = + model.parameters.template get()[{virus_variant, age_group}] + .infectivity_beta.params.a(); + return result; + }); + param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + static auto result = + model.parameters.template get()[{virus_variant, age_group}] + .infectivity_beta.params.b(); + return result; + }); + param_percentil(node, [virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant}]; + }); + } + param_percentil(node, [age_group](auto&& model) -> auto& { + return model.parameters.template get()[{age_group}]; + }); + param_percentil(node, [age_group](auto&& model) -> auto& { + static auto result = model.parameters.template get()[{age_group}].hours(); + return result; + }); + param_percentil(node, [age_group](auto&& model) -> auto& { + static auto result = model.parameters.template get()[{age_group}].hours(); + return result; + }); + param_percentil(node, [age_group](auto&& model) -> auto& { + static auto result = model.parameters.template get()[{age_group}].hours(); + return result; + }); + param_percentil(node, [age_group](auto&& model) -> auto& { + static auto result = model.parameters.template get()[{age_group}].hours(); + return result; + }); + } + param_percentil(node, [](auto&& model) -> auto& { + return model.parameters.template get()[MaskType::Community]; + }); + param_percentil(node, [](auto&& model) -> auto& { + return model.parameters.template get()[MaskType::FFP2]; + }); + param_percentil(node, [](auto&& model) -> auto& { + return model.parameters.template get()[MaskType::Surgical]; + }); + param_percentil(node, [](auto&& model) -> auto& { + static auto result = model.parameters.template get().days(); + return result; + }); + } + + return percentile; +} + +} // namespace abm +} // namespace mio + +#endif diff --git a/cpp/models/abm/testing_strategy.cpp b/cpp/models/abm/testing_strategy.cpp index 031859208a..5eb40beb4a 100644 --- a/cpp/models/abm/testing_strategy.cpp +++ b/cpp/models/abm/testing_strategy.cpp @@ -29,7 +29,7 @@ namespace abm TestingCriteria::TestingCriteria(const std::vector& ages, const std::vector& infection_states) { for (auto age : ages) { - m_ages.set(static_cast(age), true); + m_ages.insert(static_cast(age)); } for (auto infection_state : infection_states) { m_infection_states.set(static_cast(infection_state), true); @@ -43,12 +43,12 @@ bool TestingCriteria::operator==(const TestingCriteria& other) const void TestingCriteria::add_age_group(const AgeGroup age_group) { - m_ages.set(static_cast(age_group), true); + m_ages.insert(static_cast(age_group)); } void TestingCriteria::remove_age_group(const AgeGroup age_group) { - m_ages.set(static_cast(age_group), false); + m_ages.erase(static_cast(age_group)); } void TestingCriteria::add_infection_state(const InfectionState infection_state) @@ -63,7 +63,7 @@ void TestingCriteria::remove_infection_state(const InfectionState infection_stat bool TestingCriteria::evaluate(const Person& p, TimePoint t) const { - return (m_ages.none() || m_ages[static_cast(p.get_age())]) && + return (m_ages.empty() || m_ages.count(static_cast(p.get_age()))) && (m_infection_states.none() || m_infection_states[static_cast(p.get_infection_state(t))]); } diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index 67aa96b891..d78f2029b1 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -26,6 +26,7 @@ #include "abm/time.h" #include "memilio/utils/random_number_generator.h" #include +#include #include namespace mio @@ -89,8 +90,7 @@ class TestingCriteria bool evaluate(const Person& p, TimePoint t) const; private: - std::bitset<(size_t)AgeGroup::Count> - m_ages; ///< BitSet of #AgeGroup%s that are either allowed or required to be tested. + std::unordered_set m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested. std::bitset<(size_t)InfectionState::Count> m_infection_states; /**< BitSet of #InfectionState%s that are either allowed or required to be tested.*/ diff --git a/cpp/tests/test_abm_testing_strategy.cpp b/cpp/tests/test_abm_testing_strategy.cpp index 7a9e30437d..0760e36b62 100644 --- a/cpp/tests/test_abm_testing_strategy.cpp +++ b/cpp/tests/test_abm_testing_strategy.cpp @@ -44,11 +44,11 @@ TEST(TestTestingCriteria, addRemoveAndEvaluateTestCriteria) testing_criteria.add_infection_state(mio::abm::InfectionState::InfectedSymptoms); auto testing_criteria_manual = mio::abm::TestingCriteria( - std::vector({mio::abm::AgeGroup::Age15to34}), + std::vector({AGE_GROUP_15_TO_34}), std::vector({mio::abm::InfectionState::InfectedNoSymptoms})); ASSERT_EQ(testing_criteria == testing_criteria_manual, false); testing_criteria_manual.add_infection_state(mio::abm::InfectionState::InfectedSymptoms); - testing_criteria_manual.remove_age_group(mio::abm::AgeGroup::Age15to34); + testing_criteria_manual.remove_age_group(AGE_GROUP_15_TO_34); ASSERT_EQ(testing_criteria == testing_criteria_manual, true); } @@ -61,7 +61,7 @@ TEST(TestTestingScheme, runScheme) std::vector test_location_types1 = {mio::abm::LocationType::Home, mio::abm::LocationType::Work}; - auto testing_criteria1 = mio::abm::TestingCriteria({}, test_location_types1, test_infection_states1); + auto testing_criteria1 = mio::abm::TestingCriteria({}, test_infection_states1); std::vector testing_criterias = {testing_criteria1}; const auto testing_min_time = mio::abm::days(1); @@ -73,7 +73,6 @@ TEST(TestTestingScheme, runScheme) std::vector test_infection_states = {mio::abm::InfectionState::InfectedSymptoms, mio::abm::InfectionState::InfectedNoSymptoms}; - auto testing_criteria1 = mio::abm::TestingCriteria({}, test_infection_states); auto testing_scheme1 = mio::abm::TestingScheme(testing_criteria1, testing_min_time, start_date, end_date, test_type, probability); @@ -131,10 +130,9 @@ TEST(TestTestingScheme, initAndRunTestingStrategy) mio::abm::TestingScheme(testing_criteria2, testing_min_time, start_date, end_date, test_type, probability); mio::abm::Location loc_home(mio::abm::LocationType::Home, 0); - auto person1 = - make_test_person(loc_home, mio::abm::AgeGroup::Age15to34, mio::abm::InfectionState::InfectedNoSymptoms); + auto person1 = make_test_person(loc_home, AGE_GROUP_15_TO_34, mio::abm::InfectionState::InfectedNoSymptoms); auto rng_person1 = mio::abm::Person::RandomNumberGenerator(rng, person1); - auto person2 = make_test_person(loc_home, mio::abm::AgeGroup::Age15to34, mio::abm::InfectionState::Recovered); + auto person2 = make_test_person(loc_home, AGE_GROUP_15_TO_34, mio::abm::InfectionState::Recovered); auto rng_person2 = mio::abm::Person::RandomNumberGenerator(rng, person2); ScopedMockDistribution>>> mock_uniform_dist; From 50f626f3951aba5de73c72ce23bd32f9980465c5 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Mon, 23 Oct 2023 16:10:38 +0200 Subject: [PATCH 16/18] Remove analyze_result 2.h --- cpp/models/abm/analyze_result 2.h | 206 ------------------------------ cpp/models/abm/testing_strategy.h | 1 - 2 files changed, 207 deletions(-) delete mode 100644 cpp/models/abm/analyze_result 2.h diff --git a/cpp/models/abm/analyze_result 2.h b/cpp/models/abm/analyze_result 2.h deleted file mode 100644 index 0e93be84a9..0000000000 --- a/cpp/models/abm/analyze_result 2.h +++ /dev/null @@ -1,206 +0,0 @@ -/* -* Copyright (C) 2020-2024 MEmilio -* -* Authors: Daniel Abele, Khoa Nguyen -* -* Contact: Martin J. Kuehn -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -#ifndef ABM_ANALYZE_RESULT_H -#define ABM_ANALYZE_RESULT_H - -#include "abm/simulation.h" -#include "abm/parameters.h" -#include "memilio/data/analyze_result.h" - -#include -#include - -namespace mio -{ -namespace abm -{ -/** - * @brief computes the p percentile of the parameters for each node. - * @param ensemble_result graph of multiple simulation runs - * @param p percentile value in open interval (0, 1) - * @return p percentile of the parameters over all runs - */ -template -std::vector ensemble_params_percentile(const std::vector>& ensemble_params, double p) -{ - assert(p > 0.0 && p < 1.0 && "Invalid percentile value."); - - auto num_runs = ensemble_params.size(); - auto num_nodes = ensemble_params[0].size(); - std::vector single_element_ensemble(num_runs); - auto num_groups = (size_t)ensemble_params[0][0].parameters.get_num_groups(); - - // Lambda function that calculates the percentile of a single parameter - std::vector percentile(num_nodes, Model((int)num_groups)); - auto param_percentil = [&ensemble_params, p, num_runs, &percentile](auto n, auto get_param) mutable { - std::vector single_element(num_runs); - for (size_t run = 0; run < num_runs; run++) { - auto const& params = ensemble_params[run][n]; - single_element[run] = get_param(params); - } - std::sort(single_element.begin(), single_element.end()); - auto& new_params = get_param(percentile[n]); - new_params = single_element[static_cast(num_runs * p)]; - }; - - for (size_t node = 0; node < num_nodes; node++) { - for (auto age_group = AgeGroup(0); age_group < AgeGroup(num_groups); age_group++) { - for (auto virus_variant = VirusVariant(0); virus_variant < VirusVariant::Count; - virus_variant = static_cast((uint32_t)virus_variant + 1)) { - // Global infection parameters - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .viral_load_incline.params.a(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .viral_load_incline.params.b(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .viral_load_decline.params.a(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .viral_load_decline.params.b(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .viral_load_peak.params.a(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .viral_load_peak.params.b(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .infectivity_alpha.params.a(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .infectivity_alpha.params.b(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .infectivity_beta.params.a(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .infectivity_beta.params.b(); - return result; - }); - param_percentil(node, [virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant}]; - }); - } - param_percentil(node, [age_group](auto&& model) -> auto& { - return model.parameters.template get()[{age_group}]; - }); - param_percentil(node, [age_group](auto&& model) -> auto& { - static auto result = model.parameters.template get()[{age_group}].hours(); - return result; - }); - param_percentil(node, [age_group](auto&& model) -> auto& { - static auto result = model.parameters.template get()[{age_group}].hours(); - return result; - }); - param_percentil(node, [age_group](auto&& model) -> auto& { - static auto result = model.parameters.template get()[{age_group}].hours(); - return result; - }); - param_percentil(node, [age_group](auto&& model) -> auto& { - static auto result = model.parameters.template get()[{age_group}].hours(); - return result; - }); - } - param_percentil(node, [](auto&& model) -> auto& { - return model.parameters.template get()[MaskType::Community]; - }); - param_percentil(node, [](auto&& model) -> auto& { - return model.parameters.template get()[MaskType::FFP2]; - }); - param_percentil(node, [](auto&& model) -> auto& { - return model.parameters.template get()[MaskType::Surgical]; - }); - param_percentil(node, [](auto&& model) -> auto& { - static auto result = model.parameters.template get().days(); - return result; - }); - } - - return percentile; -} - -} // namespace abm -} // namespace mio - -#endif diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index d78f2029b1..bf338523f3 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -27,7 +27,6 @@ #include "memilio/utils/random_number_generator.h" #include #include -#include namespace mio { From 2d935f885128e56fca5f4db03303bb9c72a14dba Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Thu, 9 Nov 2023 21:37:43 +0100 Subject: [PATCH 17/18] Small fixes according to Sascha's comments --- cpp/models/abm/location_type.h | 7 ++----- cpp/models/abm/testing_strategy.cpp | 13 +------------ cpp/models/abm/testing_strategy.h | 21 +++++++++++++++------ 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/cpp/models/abm/location_type.h b/cpp/models/abm/location_type.h index 8fb36c36d3..9b685a9a7a 100644 --- a/cpp/models/abm/location_type.h +++ b/cpp/models/abm/location_type.h @@ -99,15 +99,12 @@ struct GeographicalLocation { } // namespace abm } // namespace mio -namespace std -{ template <> -struct hash { +struct std::hash { std::size_t operator()(const mio::abm::LocationId& loc_id) const { - return (hash()(loc_id.index)) ^ (hash()(static_cast(loc_id.type))); + return (std::hash()(loc_id.index)) ^ (std::hash()(static_cast(loc_id.type))); } }; -} // namespace std #endif diff --git a/cpp/models/abm/testing_strategy.cpp b/cpp/models/abm/testing_strategy.cpp index 5eb40beb4a..1f82bdadc7 100644 --- a/cpp/models/abm/testing_strategy.cpp +++ b/cpp/models/abm/testing_strategy.cpp @@ -63,6 +63,7 @@ void TestingCriteria::remove_infection_state(const InfectionState infection_stat bool TestingCriteria::evaluate(const Person& p, TimePoint t) const { + // An empty vector of ages or none bitset of #InfectionStates% means that no condition on the corresponding property is set. return (m_ages.empty() || m_ages.count(static_cast(p.get_age()))) && (m_infection_states.none() || m_infection_states[static_cast(p.get_infection_state(t))]); } @@ -126,12 +127,6 @@ void TestingStrategy::add_testing_scheme(const LocationId& loc_id, const Testing } } -void TestingStrategy::add_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme) -{ - auto loc_id = LocationId{INVALID_LOCATION_INDEX, loc_type}; - add_testing_scheme(loc_id, scheme); -} - void TestingStrategy::remove_testing_scheme(const LocationId& loc_id, const TestingScheme& scheme) { auto& schemes_vector = m_location_to_schemes_map[loc_id]; @@ -139,12 +134,6 @@ void TestingStrategy::remove_testing_scheme(const LocationId& loc_id, const Test schemes_vector.erase(last, schemes_vector.end()); } -void TestingStrategy::remove_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme) -{ - auto loc_id = LocationId{INVALID_LOCATION_INDEX, loc_type}; - remove_testing_scheme(loc_id, scheme); -} - void TestingStrategy::update_activity_status(TimePoint t) { for (auto& [_, testing_schemes] : m_location_to_schemes_map) { diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index bf338523f3..80e88523a0 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -39,16 +39,19 @@ namespace abm class TestingCriteria { public: + /** + * @brief Create a TestingCriteria where everyone is tested. + */ + TestingCriteria() = default; + /** * @brief Create a TestingCriteria. * @param[in] ages Vector of AgeGroup%s that are either allowed or required to be tested. * @param[in] locations Vector of #Location%s or #LocationType%s that are either allowed or required to be tested. * @param[in] infection_states Vector of #InfectionState%s that are either allowed or required to be tested. - * An empty vector of ages/#LocationType%s/#InfectionStates% means that no condition on the corresponding property + * An empty vector of ages or none bitset of #InfectionStates% means that no condition on the corresponding property * is set! */ - TestingCriteria() = default; - TestingCriteria(const std::vector& ages, const std::vector& infection_states); /** @@ -141,7 +144,7 @@ class TestingScheme bool run_scheme(Person::RandomNumberGenerator& rng, Person& person, TimePoint t) const; private: - TestingCriteria m_testing_criteria; ///< Vector with all TestingCriteria of the scheme. + TestingCriteria m_testing_criteria; ///< TestingCriteria of the scheme. TimeSpan m_minimal_time_since_last_test; ///< Shortest period of time between two tests. TimePoint m_start_date; ///< Starting date of the scheme. TimePoint m_end_date; ///< Ending date of the scheme. @@ -177,7 +180,10 @@ class TestingStrategy * @param[in] loc_type LocationId key for TestingScheme to be added. * @param[in] scheme TestingScheme to be added. */ - void add_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme); + void add_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme) + { + add_testing_scheme(LocationId{INVALID_LOCATION_INDEX, loc_type}, scheme); + } /** * @brief Remove a TestingScheme from the set of schemes that are checked for testing at a certain Location. @@ -193,7 +199,10 @@ class TestingStrategy * @param[in] loc_type LocationType key for TestingScheme to be remove. * @param[in] scheme TestingScheme to be removed. */ - void remove_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme); + void remove_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme) + { + remove_testing_scheme(LocationId{INVALID_LOCATION_INDEX, loc_type}, scheme); + } /** * @brief Checks if the given TimePoint is within the interval of start and end date of each TestingScheme and then From d3eb492fadbec2dfc64b6949a99d6e2f60002ad7 Mon Sep 17 00:00:00 2001 From: Khoa Nguyen <4763945+khoanguyen-dev@users.noreply.github.com> Date: Thu, 9 Nov 2023 22:26:21 +0100 Subject: [PATCH 18/18] Remove a comment in initialisation of TestingCriteria --- cpp/models/abm/testing_strategy.h | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index 80e88523a0..0591e3c931 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -47,7 +47,6 @@ class TestingCriteria /** * @brief Create a TestingCriteria. * @param[in] ages Vector of AgeGroup%s that are either allowed or required to be tested. - * @param[in] locations Vector of #Location%s or #LocationType%s that are either allowed or required to be tested. * @param[in] infection_states Vector of #InfectionState%s that are either allowed or required to be tested. * An empty vector of ages or none bitset of #InfectionStates% means that no condition on the corresponding property * is set!