From cf52fa1bc8a37c024eafa770defc99188234bfae Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Fri, 19 Jul 2024 19:14:28 +0200 Subject: [PATCH 01/42] manually move branch to new main --- cpp/memilio/CMakeLists.txt | 3 + cpp/memilio/io/README.md | 22 + cpp/memilio/io/auto_serialize.cpp | 20 + cpp/memilio/io/auto_serialize.h | 195 +++++++ .../math/time_dependent_parameter_functor.h | 106 ++++ cpp/memilio/utils/random_number_generator.h | 8 + cpp/models/abm/infection.h | 21 +- cpp/models/abm/location.h | 61 +-- cpp/models/abm/mask.h | 17 + cpp/models/abm/parameters.h | 118 ++++- cpp/models/abm/person.h | 67 +-- cpp/models/abm/testing_strategy.h | 88 ++++ cpp/models/abm/time.h | 14 + cpp/models/abm/trip_list.h | 36 ++ cpp/models/abm/vaccine.h | 19 +- cpp/models/abm/world.h | 46 +- cpp/simulations/abm.cpp | 7 +- cpp/simulations/abm_braunschweig.cpp | 492 ++++++++---------- cpp/tests/test_abm_infection.cpp | 71 ++- cpp/tests/test_abm_location.cpp | 44 ++ cpp/tests/test_abm_person.cpp | 62 +++ cpp/tests/test_abm_world.cpp | 66 +++ cpp/tests/test_json_serializer.cpp | 105 +--- 23 files changed, 1172 insertions(+), 516 deletions(-) create mode 100644 cpp/memilio/io/auto_serialize.cpp create mode 100644 cpp/memilio/io/auto_serialize.h create mode 100644 cpp/memilio/math/time_dependent_parameter_functor.h diff --git a/cpp/memilio/CMakeLists.txt b/cpp/memilio/CMakeLists.txt index fd24f11e8e..aef7abb775 100644 --- a/cpp/memilio/CMakeLists.txt +++ b/cpp/memilio/CMakeLists.txt @@ -26,6 +26,8 @@ add_library(memilio compartments/simulation.h compartments/flow_simulation.h compartments/parameter_studies.h + io/auto_serialize.h + io/auto_serialize.cpp io/io.h io/io.cpp io/hdf5_cpp.h @@ -56,6 +58,7 @@ add_library(memilio math/matrix_shape.cpp math/interpolation.h math/interpolation.cpp + math/time_dependent_parameter_functor.h mobility/metapopulation_mobility_instant.h mobility/metapopulation_mobility_instant.cpp mobility/metapopulation_mobility_stochastic.h diff --git a/cpp/memilio/io/README.md b/cpp/memilio/io/README.md index 2ec7d91547..a5bc18e97c 100644 --- a/cpp/memilio/io/README.md +++ b/cpp/memilio/io/README.md @@ -121,6 +121,28 @@ more efficiently than the provided general free functions. - HDF5 support classes for C++ - Reading of mobility matrix files +## Auto-serialization + +This feature provides an easy and convenient method to serialize and deserialize classes, but with additional requirements and a reduced feature set. To give an example: + +```cpp +struct Foo { + int i; + auto auto_serialize() { + return make_auto_serialization("Foo", NVP("i", i)); + } +}; +``` + +The auto-serialization effectively only supports the `add_element` and `expect_element` operations defined in the Concepts section, where the function arguments are provided by the name-value pairs (NVPs). Note that the value part of an NVP is also used to assign a value during deserialization, hence the class members must be used directly in the NVP constructor (i.e. as a non-const lvalue reference). + +The requirements for auto-serialization are: +- The class must be trivially constructible. + - Alternatively, you may provide a spezialisation of the struct `AutoSerializableFactory`. +- There is exactly one NVP for every class member (but the names and their order is arbitrary). + - Values must be passed directly. +- Every class member itself is both (auto-)serializable and assignable. + ## The command line interface We provide a function `mio::command_line_interface` in the header `memilio/io/cli.h`, that can be used to write to or read from a parameter set. It can take parameters from command line arguments (i.e. the content of `argv` in the main function), and assign them to or get them from a `mio::ParameterSet`. A small example can be seen in `cpp/examples/cli.cpp`. diff --git a/cpp/memilio/io/auto_serialize.cpp b/cpp/memilio/io/auto_serialize.cpp new file mode 100644 index 0000000000..054a77c1b3 --- /dev/null +++ b/cpp/memilio/io/auto_serialize.cpp @@ -0,0 +1,20 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: Rene Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#include "memilio/io/auto_serialize.h" diff --git a/cpp/memilio/io/auto_serialize.h b/cpp/memilio/io/auto_serialize.h new file mode 100644 index 0000000000..fa823c7e02 --- /dev/null +++ b/cpp/memilio/io/auto_serialize.h @@ -0,0 +1,195 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: Rene Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef MIO_IO_AUTO_SERIALIZE_H_ +#define MIO_IO_AUTO_SERIALIZE_H_ + +#include "memilio/io/io.h" +#include "memilio/utils/metaprogramming.h" + +#include +#include +#include +#include +#include + +namespace mio +{ + +/** + * @brief Pair of name and value used for auto-(de)serialization. + * + * This object holds a view of a name and reference of a value. Mind their lifetime! + * @tparam ValueType The (non-cv, non-reference) type of the value. + */ +template +struct NVP { + using Type = ValueType&; + /** + * @brief Create a (name, value) pair. + * + * @param n A view of the name. + * @param v A non-const lvalue reference to the value. + */ + explicit NVP(const std::string_view n, Type v) + : name(n) + , value(v) + { + } + const std::string_view name; + Type value; + + NVP() = delete; + NVP(const NVP&) = delete; + NVP(NVP&&) = default; + NVP& operator=(const NVP&) = delete; + NVP& operator=(NVP&&) = delete; +}; + +/** + * @brief Provide names and values for auto-(de)serialization. + * + * This function packages the class name and a name-value pair for each class member together to define both a + * serialize and deserialize function (with limited features). + * + * Note that auto-serialization requires that all class members participate in serialization, and that + * each class member is (auto-)serializable and assignable. + * + * @tparam Targets List of each class member's type. + * @param class_name The name of the class to auto-serialize. + * @param class_members A name-value pair (NVP) for each class member. + * @return Collection of all name views and value references used for auto-(de)serialization. + */ +template +[[nodiscard]] inline auto make_auto_serialization(const std::string_view class_name, NVP&&... class_members) +{ + return std::make_pair(class_name, std::make_tuple(std::move(class_members)...)); +} + +/** + * @brief Creates an instance of AutoSerializable for auto-deserialization. + * + * The default implementation uses the default constructor of AutoSerializable, if available. If there is no default + * constructor, this class must be spezialized to provide the method `static AutoSerializable create()`. If there is + * a default constructor, but it is private, AutoSerializableFactory can be marked as friend instead. + * + * The state of the object retured by `create()` is completely arbitrary, as it is expected that auto-deserialization + * will overwrite the value of each class member. + * + * @tparam AutoSerializable A type with an auto_serialize member. + */ +template +struct AutoSerializableFactory { + static AutoSerializable create() + { + return AutoSerializable{}; + } +}; + +namespace details +{ + +/** + * @brief Helper type to detect whether T has a auto_serialize member function. + * @tparam T Any type. + */ +template +using auto_serialize_expr_t = decltype(std::declval().auto_serialize()); + +template +void add_nvp(IOObject& obj, NVP const&& nvp) +{ + obj.add_element(std::string{nvp.name}, nvp.value); +} + +template +void auto_serialize_impl(IOContext& io, const std::string_view name, std::tuple...> const&& targets) +{ + auto obj = io.create_object(std::string{name}); + + std::apply( + [&obj](NVP const&&... nvps) { + (add_nvp(obj, std::move(nvps)), ...); + }, + std::move(targets)); +} + +template +IOResult expect_nvp(IOObject& obj, NVP&& nvp) +{ + return obj.expect_element(std::string{nvp.name}, Tag{}); +} + +template +IOResult auto_deserialize_impl(IOContext& io, AutoSerializable& a, std::string_view name, + std::tuple...>&& targets) +{ + auto obj = io.expect_object(std::string{name}); + + auto unpacked_apply = [&io, &a, &obj](NVP... nvps) { + return apply( + io, + [&a, &nvps...](const Targets&... values) { + ((nvps.value = values), ...); + return a; + }, + expect_nvp(obj, std::move(nvps))...); + }; + + return std::apply(unpacked_apply, std::move(targets)); +} + +} // namespace details + +/** + * @brief Detect whether T has a auto_serialize member function. + * @tparam T Any type. + */ +template +using has_auto_serialize = is_expression_valid; + +// disables itself if a deserialize member is present or if there is no auto_serialize member +// generates serialize method depending on NVPs given by auto_serialize +template ::value && + not has_serialize::value, + AutoSerializable*> = nullptr> +void serialize_internal(IOContext& io, const AutoSerializable& t) +{ + // Note that this cast is only safe if we do not modify targets. + const auto targets = const_cast(&t)->auto_serialize(); + details::auto_serialize_impl(io, targets.first, std::move(targets.second)); +} + +// disables itself if a deserialize member is present or if there is no auto_serialize member +// generates deserialize method depending on NVPs given by auto_serialize +template ::value && + not has_deserialize::value, + AutoSerializable*> = nullptr> +IOResult deserialize_internal(IOContext& io, Tag) +{ + AutoSerializable a = AutoSerializableFactory::create(); + auto targets = a.auto_serialize(); + return details::auto_deserialize_impl(io, a, targets.first, std::move(targets.second)); +} + +} // namespace mio + +#endif // MIO_IO_AUTO_SERIALIZE_H_ diff --git a/cpp/memilio/math/time_dependent_parameter_functor.h b/cpp/memilio/math/time_dependent_parameter_functor.h new file mode 100644 index 0000000000..3b442df8fa --- /dev/null +++ b/cpp/memilio/math/time_dependent_parameter_functor.h @@ -0,0 +1,106 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: Rene Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef MIO_MATH_TIME_DEPENDENT_PARAMETER_FUNCTOR_H +#define MIO_MATH_TIME_DEPENDENT_PARAMETER_FUNCTOR_H + +#include "memilio/config.h" +#include "memilio/io/auto_serialize.h" +#include "memilio/math/interpolation.h" + +#include +#include + +namespace mio +{ + +class TimeDependentParameterFunctor +{ +public: + enum class Type + { + Zero, + LinearInterpolation, + }; + + using DataType = std::vector>; + TimeDependentParameterFunctor(Type type, DataType data) + : m_type(type) + , m_data(data) + { + // data preprocessing + switch (m_type) { + case Type::Zero: + // no preprocessing needed + break; + case Type::LinearInterpolation: + // make sure data has the correct shape, i.e. a list of (time, value) pairs + assert(m_data.size() > 0); + assert(std::all_of(m_data.begin(), m_data.end(), [](auto&& a) { + return a.size() == 2; + })); + // sort by time + std::sort(m_data.begin(), m_data.end(), [](auto&& a, auto&& b) { + return a[0] < b[0]; + }); + } + } + + TimeDependentParameterFunctor() + : TimeDependentParameterFunctor(Type::Zero, {}) + { + } + + ScalarType operator()(ScalarType time) const + { + switch (m_type) { + case Type::Zero: + return 0.0; + case Type::LinearInterpolation: + // find next time point in m_data (strictly) after time + const auto next_tp = std::upper_bound(m_data.begin(), m_data.end(), time, [](auto&& t, auto&& tp) { + return t < tp[0]; + }); + if (next_tp == m_data.begin()) { // time is before first data point + return m_data.front()[1]; + } + if (next_tp == m_data.end()) { // time is past last data point + return m_data.back()[1]; + } + const auto tp = next_tp - 1; + return linear_interpolation(time, (*tp)[0], (*next_tp)[0], (*tp)[1], (*next_tp)[1]); + } + + return 0.0; // should be unreachable, but without this the compiler may complain about a missing return. + } + + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("TimeDependentParameterFunctor", NVP("type", m_type), NVP("data", m_data)); + } + +private: + Type m_type; + DataType m_data; +}; + +} // namespace mio + +#endif diff --git a/cpp/memilio/utils/random_number_generator.h b/cpp/memilio/utils/random_number_generator.h index 2ab4cadd48..bc1b2bbc77 100644 --- a/cpp/memilio/utils/random_number_generator.h +++ b/cpp/memilio/utils/random_number_generator.h @@ -21,6 +21,7 @@ #ifndef MIO_RANDOM_NUMBER_GENERATOR_H #define MIO_RANDOM_NUMBER_GENERATOR_H +#include "memilio/io/auto_serialize.h" #include "memilio/utils/compiler_diagnostics.h" #include "memilio/utils/logging.h" #include "memilio/utils/miompi.h" @@ -357,6 +358,13 @@ class RandomNumberGenerator : public RandomNumberGeneratorBase m_key; Counter m_counter; diff --git a/cpp/models/abm/infection.h b/cpp/models/abm/infection.h index ca9dae9bdd..4c2ed0499b 100644 --- a/cpp/models/abm/infection.h +++ b/cpp/models/abm/infection.h @@ -21,6 +21,7 @@ #define MIO_ABM_INFECTION_H #include "abm/personal_rng.h" +#include "memilio/io/auto_serialize.h" #include "abm/time.h" #include "abm/infection_state.h" #include "abm/virus_variant.h" @@ -44,11 +45,17 @@ struct ViralLoad { ScalarType peak; ///< Peak amplitude of the ViralLoad. ScalarType incline; ///< Incline of the ViralLoad during incline phase in log_10 scale per day (always positive). ScalarType decline; ///< Decline of the ViralLoad during decline phase in log_10 scale per day (always negative). + + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("ViralLoad", NVP("start_date", start_date), NVP("end_date", end_date), + NVP("peak", peak), NVP("incline", incline), NVP("decline", decline)); + } }; class Infection { - public: /** * @brief Create an Infection for a single Person. @@ -114,7 +121,19 @@ class Infection */ TimePoint get_start_date() const; + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("Infection", NVP("infection_course", m_infection_course), + NVP("virus_variant", m_virus_variant), NVP("viral_load", m_viral_load), + NVP("log_norm_alpha", m_log_norm_alpha), NVP("log_norm_beta", m_log_norm_beta), + NVP("detected", m_detected)); + } + private: + friend AutoSerializableFactory; + Infection() = default; + /** * @brief Determine ViralLoad course and Infection course based on init_state. * Calls draw_infection_course_backward for all #InfectionState%s prior and draw_infection_course_forward for all diff --git a/cpp/models/abm/location.h b/cpp/models/abm/location.h index 921cca01a0..87327a41c1 100644 --- a/cpp/models/abm/location.h +++ b/cpp/models/abm/location.h @@ -25,6 +25,7 @@ #include "abm/parameters.h" #include "abm/location_type.h" +#include "memilio/io/auto_serialize.h" #include "boost/atomic/atomic.hpp" namespace mio @@ -48,6 +49,12 @@ struct GeographicalLocation { { return !(latitude == other.latitude && longitude == other.longitude); } + + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("GraphicalLocation", NVP("latitude", latitude), NVP("longitude", longitude)); + } }; struct CellIndex : public mio::Index { @@ -73,6 +80,12 @@ struct CellCapacity { } uint32_t volume; ///< Volume of the Cell. uint32_t persons; ///< Maximal number of Person%s at the Cell. + + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("CellCapacity", NVP("volume", volume), NVP("persons", persons)); + } }; /** @@ -87,6 +100,12 @@ struct Cell { * @return The relative cell size for the Cell. */ ScalarType compute_space_per_person_relative() const; + + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("Cell", NVP("capacity", m_capacity)); + } }; // namespace mio /** @@ -232,36 +251,6 @@ class Location m_npi_active = new_status; } - /** - * serialize this. - * @see mio::serialize - */ - template - void serialize(IOContext& io) const - { - auto obj = io.create_object("Location"); - obj.add_element("index", m_id); - obj.add_element("type", m_type); - } - - /** - * deserialize an object of this class. - * @see mio::deserialize - */ - template - static IOResult deserialize(IOContext& io) - { - auto obj = io.expect_object("Location"); - auto index = obj.expect_element("index", Tag{}); - auto type = obj.expect_element("type", Tag{}); - return apply( - io, - [](auto&& index_, auto&& type_) { - return Location{type_, index_}; - }, - index, type); - } - /** * @brief Get the geographical location of the Location. * @return The geographical location of the Location. @@ -280,7 +269,19 @@ class Location m_geographical_location = location; } + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("Location", NVP("id", m_id), NVP("parameters", m_parameters), + NVP("cells", m_cells), NVP("required_mask", m_required_mask), + NVP("npi_active", m_npi_active), + NVP("geographical_location", m_geographical_location)); + } + private: + friend AutoSerializableFactory; + Location() = default; + LocationType m_type; ///< Type of the Location. LocationId m_id; ///< Unique identifier for the Location in the World owning it. LocalInfectionParameters m_parameters; ///< Infection parameters for the Location. diff --git a/cpp/models/abm/mask.h b/cpp/models/abm/mask.h index 2b9048925b..f3697d7425 100644 --- a/cpp/models/abm/mask.h +++ b/cpp/models/abm/mask.h @@ -23,6 +23,7 @@ #include "abm/mask_type.h" #include "abm/time.h" +#include "memilio/io/auto_serialize.h" namespace mio { @@ -72,11 +73,27 @@ class Mask */ void change_mask(MaskType new_mask_type); + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("Mask", NVP("mask_type", m_type), NVP("time_used", m_time_used)); + } + private: MaskType m_type; ///< Type of the Mask. TimeSpan m_time_used; ///< Length of time the Mask has been used. }; } // namespace abm + +/// @brief Creates an instance of abm::Mask for auto-deserialization. +template <> +struct AutoSerializableFactory { + static abm::Mask create() + { + return abm::Mask(abm::MaskType::Count); + } +}; + } // namespace mio #endif diff --git a/cpp/models/abm/parameters.h b/cpp/models/abm/parameters.h index 7ca6dd4578..b4a66a5d0d 100644 --- a/cpp/models/abm/parameters.h +++ b/cpp/models/abm/parameters.h @@ -24,16 +24,45 @@ #include "abm/time.h" #include "abm/virus_variant.h" #include "abm/vaccine.h" +#include "memilio/config.h" +#include "memilio/io/auto_serialize.h" +#include "memilio/io/io.h" +#include "memilio/math/time_dependent_parameter_functor.h" #include "memilio/utils/custom_index_array.h" #include "memilio/utils/uncertain_value.h" #include "memilio/utils/parameter_set.h" #include "memilio/epidemiology/age_group.h" #include "memilio/epidemiology/damping.h" #include "memilio/epidemiology/contact_matrix.h" +#include #include namespace mio { + +template +void serialize_internal(IOContext& io, const UniformDistribution::ParamType& p) +{ + auto obj = io.create_object("UniformDistributionParams"); + obj.add_element("a", p.params.a()); + obj.add_element("b", p.params.b()); +} + +template +IOResult::ParamType> deserialize_internal(IOContext& io, + Tag::ParamType>) +{ + auto obj = io.expect_object("UniformDistributionParams"); + auto a = obj.expect_element("a", Tag{}); + auto b = obj.expect_element("b", Tag{}); + return apply( + io, + [](auto&& a_, auto&& b_) { + return UniformDistribution::ParamType{a_, b_}; + }, + a, b); +} + namespace abm { @@ -41,7 +70,7 @@ namespace abm * @brief Time that a Person is infected but not yet infectious. */ struct IncubationPeriod { - using Type = CustomIndexArray< UncertainValue<>, VirusVariant, AgeGroup>; + using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { return Type({VirusVariant::Count, size}, 1.); @@ -53,7 +82,7 @@ struct IncubationPeriod { }; struct InfectedNoSymptomsToSymptoms { - using Type = CustomIndexArray< UncertainValue<>, VirusVariant, AgeGroup>; + using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { return Type({VirusVariant::Count, size}, 1.); @@ -65,7 +94,7 @@ struct InfectedNoSymptomsToSymptoms { }; struct InfectedNoSymptomsToRecovered { - using Type = CustomIndexArray< UncertainValue<>, VirusVariant, AgeGroup>; + using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { return Type({VirusVariant::Count, size}, 1.); @@ -77,7 +106,7 @@ struct InfectedNoSymptomsToRecovered { }; struct InfectedSymptomsToRecovered { - using Type = CustomIndexArray< UncertainValue<>, VirusVariant, AgeGroup>; + using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { return Type({VirusVariant::Count, size}, 1.); @@ -89,7 +118,7 @@ struct InfectedSymptomsToRecovered { }; struct InfectedSymptomsToSevere { - using Type = CustomIndexArray< UncertainValue<>, VirusVariant, AgeGroup>; + using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { return Type({VirusVariant::Count, size}, 1.); @@ -101,7 +130,7 @@ struct InfectedSymptomsToSevere { }; struct SevereToCritical { - using Type = CustomIndexArray< UncertainValue<>, VirusVariant, AgeGroup>; + using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { return Type({VirusVariant::Count, size}, 1.); @@ -113,7 +142,7 @@ struct SevereToCritical { }; struct SevereToRecovered { - using Type = CustomIndexArray< UncertainValue<>, VirusVariant, AgeGroup>; + using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { return Type({VirusVariant::Count, size}, 1.); @@ -125,7 +154,7 @@ struct SevereToRecovered { }; struct CriticalToRecovered { - using Type = CustomIndexArray< UncertainValue<>, VirusVariant, AgeGroup>; + using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { return Type({VirusVariant::Count, size}, 1.); @@ -137,7 +166,7 @@ struct CriticalToRecovered { }; struct CriticalToDead { - using Type = CustomIndexArray< UncertainValue<>, VirusVariant, AgeGroup>; + using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { return Type({VirusVariant::Count, size}, 1.); @@ -149,7 +178,7 @@ struct CriticalToDead { }; struct RecoveredToSusceptible { - using Type = CustomIndexArray< UncertainValue<>, VirusVariant, AgeGroup>; + using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { return Type({VirusVariant::Count, size}, 1.); @@ -168,6 +197,14 @@ struct ViralLoadDistributionsParameters { UniformDistribution::ParamType viral_load_peak; UniformDistribution::ParamType viral_load_incline; UniformDistribution::ParamType viral_load_decline; + + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("ViralLoadDistributionsParameters", NVP("viral_load_peak", viral_load_peak), + NVP("viral_load_incline", viral_load_incline), + NVP("viral_load_decline", viral_load_decline)); + } }; struct ViralLoadDistributions { @@ -191,6 +228,14 @@ struct ViralLoadDistributions { struct InfectivityDistributionsParameters { UniformDistribution::ParamType infectivity_alpha; UniformDistribution::ParamType infectivity_beta; + + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("InfectivityDistributionsParameters", + NVP("infectivity_alpha", infectivity_alpha), + NVP("infectivity_beta", infectivity_beta)); + } }; struct InfectivityDistributions { @@ -210,7 +255,7 @@ struct InfectivityDistributions { * @brief Probability that an Infection is detected. */ struct DetectInfection { - using Type = CustomIndexArray< UncertainValue<>, VirusVariant, AgeGroup>; + using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { return Type({VirusVariant::Count, size}, 1.); @@ -225,7 +270,7 @@ struct DetectInfection { * @brief Effectiveness of a Mask of a certain MaskType% against an Infection%. */ struct MaskProtection { - using Type = CustomIndexArray< UncertainValue<>, MaskType>; + using Type = CustomIndexArray, MaskType>; static Type get_default(AgeGroup /*size*/) { return Type({MaskType::Count}, 1.); @@ -251,7 +296,8 @@ struct AerosolTransmissionRates { } }; -using InputFunctionForProtectionLevel = std::function; +// using InputFunctionForProtectionLevel = std::function; +using InputFunctionForProtectionLevel = TimeDependentParameterFunctor; /** * @brief Personal protection factor against #Infection% after #Infection and #Vaccination, which depends on #ExposureType, @@ -261,9 +307,8 @@ struct InfectionProtectionFactor { using Type = CustomIndexArray; static auto get_default(AgeGroup size) { - return Type({ExposureType::Count, size, VirusVariant::Count}, [](ScalarType /*days*/) -> ScalarType { - return 0; - }); + return Type({ExposureType::Count, size, VirusVariant::Count}, + Type::value_type(TimeDependentParameterFunctor::Type::Zero, {})); } static std::string name() { @@ -279,9 +324,8 @@ struct SeverityProtectionFactor { using Type = CustomIndexArray; static auto get_default(AgeGroup size) { - return Type({ExposureType::Count, size, VirusVariant::Count}, [](ScalarType /*days*/) -> ScalarType { - return 0; - }); + return Type({ExposureType::Count, size, VirusVariant::Count}, + Type::value_type(TimeDependentParameterFunctor::Type::Zero, {})); } static std::string name() { @@ -296,9 +340,7 @@ struct HighViralLoadProtectionFactor { using Type = InputFunctionForProtectionLevel; static auto get_default() { - return Type([](ScalarType /*days*/) -> ScalarType { - return 0; - }); + return Type(TimeDependentParameterFunctor::Type::Zero, {}); } static std::string name() { @@ -310,8 +352,15 @@ struct HighViralLoadProtectionFactor { * @brief Parameters that describe the reliability of a test. */ struct TestParameters { - UncertainValue<> sensitivity; - UncertainValue<> specificity; + UncertainValue<> sensitivity; + UncertainValue<> specificity; + + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("TestParameters", NVP("sensitivity", sensitivity), + NVP("specificity", specificity)); + } }; struct GenericTest { @@ -390,7 +439,7 @@ struct QuarantineDuration { * @brief Parameter for the exponential distribution to decide if a Person goes shopping. */ struct BasicShoppingRate { - using Type = CustomIndexArray< UncertainValue<>, AgeGroup>; + using Type = CustomIndexArray, AgeGroup>; static auto get_default(AgeGroup size) { return Type({size}, 1.0); @@ -606,6 +655,14 @@ class Parameters : public ParametersBase { } +private: + Parameters(ParametersBase&& base) + : ParametersBase(std::move(base)) + , m_num_groups(this->get().size().get()) + { + } + +public: /** * @brief Get the number of the age groups. */ @@ -756,6 +813,17 @@ class Parameters : public ParametersBase return false; } + /** + * deserialize an object of this class. + * @see epi::deserialize + */ + template + static IOResult deserialize(IOContext& io) + { + BOOST_OUTCOME_TRY(auto&& base, ParametersBase::deserialize(io)); + return success(Parameters(std::move(base))); + } + private: size_t m_num_groups; }; diff --git a/cpp/models/abm/person.h b/cpp/models/abm/person.h index 7810eede7f..c81010c995 100755 --- a/cpp/models/abm/person.h +++ b/cpp/models/abm/person.h @@ -28,6 +28,7 @@ #include "abm/parameters.h" #include "abm/person_id.h" #include "abm/personal_rng.h" +#include "memilio/io/auto_serialize.h" #include "abm/time.h" #include "abm/vaccine.h" #include "abm/mask.h" @@ -393,36 +394,33 @@ class Person */ std::pair get_latest_protection() const; - /** - * serialize this. - * @see mio::serialize - */ - template - void serialize(IOContext& io) const - { - auto obj = io.create_object("Person"); - obj.add_element("Location", m_location); - obj.add_element("age", m_age); - obj.add_element("id", m_person_id); - } - - /** - * deserialize an object of this class. - * @see mio::deserialize - */ - template - static IOResult deserialize(IOContext& io) + /// This method is used by the auto-serialization feature. + auto auto_serialize() { - auto obj = io.expect_object("Person"); - auto loc = obj.expect_element("Location", mio::Tag{}); - auto age = obj.expect_element("age", Tag{}); - auto id = obj.expect_element("id", Tag{}); - return apply( - io, - [](auto&& loc_, auto&& age_, auto&& id_) { - return Person{mio::RandomNumberGenerator(), loc_, AgeGroup(age_), id_}; - }, - loc, age, id); + // clang-format off + return make_auto_serialization( + "Person", + NVP("location", m_location), + NVP("assigned_locations", m_assigned_locations), + NVP("vaccinations", m_vaccinations), + NVP("infections", m_infections), + NVP("quarantine_start",m_quarantine_start), + NVP("age_group", m_age), + NVP("time_at_location", m_time_at_location), + NVP("rnd_workgroup", m_random_workgroup), + NVP("rnd_schoolgroup", m_random_schoolgroup), + NVP("rnd_go_to_work_hour", m_random_goto_work_hour), + NVP("rnd_go_to_school_hour", m_random_goto_school_hour), + NVP("time_of_last_test", m_time_of_last_test), + NVP("mask", m_mask), + NVP("wears_mask", m_wears_mask), + NVP("mask_compliance", m_mask_compliance), + NVP("id", m_person_id), + NVP("cells", m_cells), + NVP("last_transport_mode", m_last_transport_mode), + NVP("rng_counter", m_rng_counter) + ); + // clang-format on } private: @@ -450,6 +448,17 @@ class Person }; } // namespace abm + +/// @brief Creates an instance of abm::Person for auto-deserialization. +template <> +struct AutoSerializableFactory { + static abm::Person create() + { + return abm::Person(thread_local_rng(), abm::LocationType::Count, abm::LocationId(), AgeGroup(0), + abm::PersonId()); + } +}; + } // namespace mio #endif diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index ee0c4a3ba0..e3f4e29251 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -27,6 +27,8 @@ #include "abm/person.h" #include "abm/location.h" #include "abm/time.h" +#include "memilio/io/auto_serialize.h" + #include #include @@ -92,6 +94,36 @@ class TestingCriteria */ bool evaluate(const Person& p, TimePoint t) const; + /** + * serialize this. + * @see mio::serialize + */ + template + void serialize(IOContext& io) const + { + auto obj = io.create_object("TestingCriteria"); + obj.add_element("ages", m_ages.to_ulong()); + obj.add_element("infection_states", m_infection_states.to_ulong()); + } + + /** + * deserialize an object of this class. + * @see mio::deserialize + */ + template + static IOResult deserialize(IOContext& io) + { + auto obj = io.expect_object("TestingCriteria"); + auto ages = obj.expect_element("ages", Tag{}); + auto infection_states = obj.expect_element("infection_states", Tag{}); + return apply( + io, + [](auto&& ages_, auto&& infection_states_) { + return TestingCriteria{ages_, infection_states_}; + }, + ages, infection_states); + } + private: std::bitset m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested. std::bitset<(size_t)InfectionState::Count> @@ -144,6 +176,50 @@ class TestingScheme */ bool run_scheme(PersonalRandomNumberGenerator& rng, Person& person, TimePoint t) const; + /** + * serialize this. + * @see mio::serialize + */ + template + void serialize(IOContext& io) const + { + auto obj = io.create_object("TestingScheme"); + obj.add_element("criteria", m_testing_criteria); + obj.add_element("min_time_since_last_test", m_minimal_time_since_last_test); + obj.add_element("start_date", m_start_date); + obj.add_element("end_date", m_end_date); + obj.add_element("test_type", + m_test_type.get_default()); // FIXME: m_test_type should contain TestParameters directly + obj.add_element("probability", m_probability); + obj.add_element("is_active", m_is_active); + } + + /** + * deserialize an object of this class. + * @see mio::deserialize + */ + template + static IOResult deserialize(IOContext& io) + { + auto obj = io.expect_object("TestingScheme"); + auto criteria = obj.expect_element("criteria", Tag{}); + auto min_time_since_last_test = obj.expect_element("min_time_since_last_test", Tag{}); + auto start_date = obj.expect_element("start_date", Tag{}); + auto end_date = obj.expect_element("end_date", Tag{}); + auto test_type = obj.expect_element( + "test_type", Tag{}); // FIXME: m_test_type should contain TestParameters directly + auto probability = obj.expect_element("probability", Tag{}); + auto is_active = obj.expect_element("is_active", Tag{}); + return apply( + io, + [](auto&& criteria_, auto&& min_time_since_last_test_, auto&& start_date_, auto&& end_date_, + auto&& test_type_, auto&& probability_, auto&& is_active_) { + return TestingScheme{ + criteria_, min_time_since_last_test_, start_date_, end_date_, test_type_, probability_, is_active_}; + }, + criteria, min_time_since_last_test, start_date, end_date, test_type, probability, is_active); + } + private: TestingCriteria m_testing_criteria; ///< TestingCriteria of the scheme. TimeSpan m_minimal_time_since_last_test; ///< Shortest period of time between two tests. @@ -168,6 +244,12 @@ class TestingStrategy LocationType type; LocationId id; std::vector schemes; + + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("LocalStrategy", NVP("type", type), NVP("id", id), NVP("schemes", schemes)); + } }; /** @@ -233,6 +315,12 @@ class TestingStrategy */ bool run_strategy(PersonalRandomNumberGenerator& rng, Person& person, const Location& location, TimePoint t); + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("TestingStrategy", NVP("schemes", m_location_to_schemes_map)); + } + private: std::vector m_location_to_schemes_map; ///< Set of schemes that are checked for testing. }; diff --git a/cpp/models/abm/time.h b/cpp/models/abm/time.h index 45ad5495d4..1f3bb2ca62 100644 --- a/cpp/models/abm/time.h +++ b/cpp/models/abm/time.h @@ -20,6 +20,8 @@ #ifndef MIO_ABM_TIME_H #define MIO_ABM_TIME_H +#include "memilio/io/auto_serialize.h" + namespace mio { namespace abm @@ -143,6 +145,12 @@ class TimeSpan } /**@}*/ + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("TimeSpan", NVP("seconds", m_seconds)); + } + private: int m_seconds; ///< The duration of time in seconds. }; @@ -284,6 +292,12 @@ class TimePoint return TimeSpan{m_seconds - p2.seconds()}; } + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("TimePoint", NVP("seconds", m_seconds)); + } + private: int m_seconds; ///< The number of seconds after the epoch. }; diff --git a/cpp/models/abm/trip_list.h b/cpp/models/abm/trip_list.h index 627afef7f1..f7182dd166 100644 --- a/cpp/models/abm/trip_list.h +++ b/cpp/models/abm/trip_list.h @@ -191,6 +191,42 @@ class TripList return m_current_index; } + /** + * serialize this. + * @see mio::serialize + */ + template + void serialize(IOContext& io) const + { + auto obj = io.create_object("TripList"); + obj.add_list("trips_weekday", m_trips_weekday.cbegin(), m_trips_weekday.cend()); + obj.add_list("trips_weekend", m_trips_weekend.cbegin(), m_trips_weekend.cend()); + obj.add_element("index", m_current_index); + } + + /** + * deserialize an object of this class. + * @see mio::deserialize + */ + template + static IOResult deserialize(IOContext& io) + { + auto obj = io.expect_object("TripList"); + auto trips_wd = obj.expect_list("trips_weekday", Tag{}); + auto trips_we = obj.expect_list("trips_weekend", Tag{}); + auto index = obj.expect_element("index", Tag{}); + return apply( + io, + [](auto&& trips_wd_, auto&& trips_we_, auto&& index_) { + TripList tl; + tl.m_trips_weekday = trips_wd_; + tl.m_trips_weekend = trips_we_; + tl.m_current_index = index_; + return tl; + }, + trips_wd, trips_we, index); + } + private: std::vector m_trips_weekday; ///< The list of Trip%s a Person makes on a weekday. std::vector m_trips_weekend; ///< The list of Trip%s a Person makes on a weekend day. diff --git a/cpp/models/abm/vaccine.h b/cpp/models/abm/vaccine.h index 72da133516..88638a1c57 100644 --- a/cpp/models/abm/vaccine.h +++ b/cpp/models/abm/vaccine.h @@ -20,6 +20,7 @@ #ifndef MIO_ABM_VACCINE_H #define MIO_ABM_VACCINE_H +#include "memilio/io/auto_serialize.h" #include "abm/time.h" #include @@ -44,7 +45,7 @@ enum class ExposureType : std::uint32_t /** * @brief A tuple of #TimePoint and #ExposureType (i.e. type of the Vaccine). * The #TimePoint describes the time of administration of the Vaccine. -*/ + */ struct Vaccination { Vaccination(ExposureType exposure, TimePoint t) : exposure_type(exposure) @@ -52,11 +53,27 @@ struct Vaccination { { } + /// This method is used by the auto-serialization feature. + auto auto_serialize() + { + return make_auto_serialization("Vaccination", NVP("exposure_type", exposure_type), NVP("time", time)); + } + ExposureType exposure_type; TimePoint time; }; } // namespace abm + +/// @brief Creates an instance of abm::Vaccination for auto-deserialization. +template <> +struct AutoSerializableFactory { + static abm::Vaccination create() + { + return abm::Vaccination(abm::ExposureType::Count, abm::TimePoint()); + } +}; + } // namespace mio #endif diff --git a/cpp/models/abm/world.h b/cpp/models/abm/world.h index 853dff4960..44b7724153 100644 --- a/cpp/models/abm/world.h +++ b/cpp/models/abm/world.h @@ -112,22 +112,16 @@ class World void serialize(IOContext& io) const { auto obj = io.create_object("World"); - obj.add_element("num_agegroups", parameters.get_num_groups()); - std::vector trips; - TripList trip_list = get_trip_list(); - for (size_t i = 0; i < trip_list.num_trips(false); i++) { - trips.push_back(trip_list.get_next_trip(false)); - trip_list.increase_index(); - } - trip_list.reset_index(); - for (size_t i = 0; i < trip_list.num_trips(true); i++) { - trips.push_back(trip_list.get_next_trip(true)); - trip_list.increase_index(); - } - obj.add_list("trips", trips.begin(), trips.end()); - obj.add_list("locations", get_locations().begin(), get_locations().end()); + obj.add_element("parameters", parameters); + // skip caches, they are rebuild by the deserialized world obj.add_list("persons", get_persons().begin(), get_persons().end()); + obj.add_list("locations", get_locations().begin(), get_locations().end()); + obj.add_element("location_types", m_has_locations.to_ulong()); + obj.add_element("testing_strategy", m_testing_strategy); + obj.add_element("trip_list", m_trip_list); obj.add_element("use_migration_rules", m_use_migration_rules); + obj.add_element("cemetery_id", m_cemetery_id); + obj.add_element("rng", m_rng); } /** @@ -138,17 +132,29 @@ class World static IOResult deserialize(IOContext& io) { auto obj = io.expect_object("World"); - auto size = obj.expect_element("num_agegroups", Tag{}); - auto locations = obj.expect_list("locations", Tag{}); - auto trip_list = obj.expect_list("trips", Tag{}); + auto params = obj.expect_element("parameters", Tag{}); auto persons = obj.expect_list("persons", Tag{}); + auto locations = obj.expect_list("locations", Tag{}); + auto location_types = obj.expect_element("location_types", Tag{}); + auto trip_list = obj.expect_element("trip_list", Tag{}); auto use_migration_rules = obj.expect_element("use_migration_rules", Tag{}); + auto cemetery_id = obj.expect_element("cemetery_id", Tag{}); + auto rng = obj.expect_element("rng", Tag{}); return apply( io, - [](auto&& size_, auto&& locations_, auto&& trip_list_, auto&& persons_, auto&& use_migration_rule_) { - return World{size_, locations_, trip_list_, persons_, use_migration_rule_}; + [](auto&& params_, auto&& persons_, auto&& locations_, auto&& location_types_, auto&& trip_list_, + auto&& use_migration_rule_, auto&& cemetery_id_, auto&& rng_) { + World world{params_}; + world.m_persons.assign(persons_.cbegin(), persons_.cend()); + world.m_locations.assign(locations_.cbegin(), locations_.cend()); + world.m_has_locations = location_types_; + world.m_trip_list = trip_list_; + world.m_use_migration_rules = use_migration_rule_; + world.m_cemetery_id = cemetery_id_; + world.m_rng = rng_; + return world; }, - size, locations, trip_list, persons, use_migration_rules); + params, persons, locations, location_types, trip_list, use_migration_rules, cemetery_id, rng); } /** diff --git a/cpp/simulations/abm.cpp b/cpp/simulations/abm.cpp index 78ce0f2a40..e5833fe381 100644 --- a/cpp/simulations/abm.cpp +++ b/cpp/simulations/abm.cpp @@ -471,10 +471,9 @@ void set_parameters(mio::abm::Parameters params) params.set({{mio::abm::VirusVariant::Count, mio::AgeGroup(num_age_groups)}, 4.}); // Set protection level from high viral load. Information based on: https://doi.org/10.1093/cid/ciaa886 - params.get() = [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.863}, {1, 0.969}, {7, 0.029}, {10, 0.002}, {14, 0.0014}, {21, 0}}, days); - }; + params.get() = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.863}, {1, 0.969}, {7, 0.029}, {10, 0.002}, {14, 0.0014}, {21, 0}}}; //0-4 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.276; diff --git a/cpp/simulations/abm_braunschweig.cpp b/cpp/simulations/abm_braunschweig.cpp index 7d6630f931..2bc7a35264 100644 --- a/cpp/simulations/abm_braunschweig.cpp +++ b/cpp/simulations/abm_braunschweig.cpp @@ -20,13 +20,13 @@ #include "abm/common_abm_loggers.h" #include "abm/location_id.h" #include "abm/lockdown_rules.h" +#include "abm/parameters.h" #include "abm/person.h" #include "abm/simulation.h" #include "abm/world.h" #include "memilio/epidemiology/age_group.h" #include "memilio/io/io.h" #include "memilio/io/result_io.h" -#include "memilio/math/interpolation.h" #include "memilio/utils/uncertain_value.h" #include "boost/algorithm/string/split.hpp" #include "boost/algorithm/string/classification.hpp" @@ -400,10 +400,9 @@ void set_parameters(mio::abm::Parameters params) params.set({{mio::abm::VirusVariant::Count, mio::AgeGroup(num_age_groups)}, 4.}); // Set protection level from high viral load. Information based on: https://doi.org/10.1093/cid/ciaa886 - params.get() = [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.863}, {1, 0.969}, {7, 0.029}, {10, 0.002}, {14, 0.0014}, {21, 0}}, days); - }; + params.get() = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.863}, {1, 0.969}, {7, 0.029}, {10, 0.002}, {14, 0.0014}, {21, 0}}}; //0-4 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.276; @@ -493,58 +492,51 @@ void set_parameters(mio::abm::Parameters params) // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_0_to_4, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.852}, - {180, 0.852}, - {210, 0.845}, - {240, 0.828}, - {270, 0.797}, - {300, 0.759}, - {330, 0.711}, - {360, 0.661}, - {390, 0.616}, - {420, 0.580}, - {450, 0.559}, - {450, 0.550}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.852}, + {180, 0.852}, + {210, 0.845}, + {240, 0.828}, + {270, 0.797}, + {300, 0.759}, + {330, 0.711}, + {360, 0.661}, + {390, 0.616}, + {420, 0.580}, + {450, 0.559}, + {450, 0.550}}}; + // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_0_to_4, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.91}, {60, 0.92}, {90, 0.88}, {120, 0.84}, {150, 0.81}, {180, 0.88}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.5}, {30, 0.91}, {60, 0.92}, {90, 0.88}, {120, 0.84}, {150, 0.81}, {180, 0.88}, {450, 0.5}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_0_to_4, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.967}, - {30, 0.975}, - {60, 0.977}, - {90, 0.974}, - {120, 0.963}, - {150, 0.947}, - {180, 0.93}, - {210, 0.929}, - {240, 0.923}, - {270, 0.908}, - {300, 0.893}, - {330, 0.887}, - {360, 0.887}, - {450, 0.5}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.967}, + {30, 0.975}, + {60, 0.977}, + {90, 0.974}, + {120, 0.963}, + {150, 0.947}, + {180, 0.93}, + {210, 0.929}, + {240, 0.923}, + {270, 0.908}, + {300, 0.893}, + {330, 0.887}, + {360, 0.887}, + {450, 0.5}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_0_to_4, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}}; //5-14 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.161; @@ -560,57 +552,50 @@ void set_parameters(mio::abm::Parameters params) // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_5_to_14, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.852}, - {180, 0.852}, - {210, 0.845}, - {240, 0.828}, - {270, 0.797}, - {300, 0.759}, - {330, 0.711}, - {360, 0.661}, - {390, 0.616}, - {420, 0.580}, - {450, 0.559}, - {450, 0.550}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.852}, + {180, 0.852}, + {210, 0.845}, + {240, 0.828}, + {270, 0.797}, + {300, 0.759}, + {330, 0.711}, + {360, 0.661}, + {390, 0.616}, + {420, 0.580}, + {450, 0.559}, + {450, 0.550}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_5_to_14, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.91}, {60, 0.92}, {90, 0.88}, {120, 0.84}, {150, 0.81}, {180, 0.88}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.5}, {30, 0.91}, {60, 0.92}, {90, 0.88}, {120, 0.84}, {150, 0.81}, {180, 0.88}, {450, 0.5}}}; + // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_5_to_14, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.967}, - {30, 0.975}, - {60, 0.977}, - {90, 0.974}, - {120, 0.963}, - {150, 0.947}, - {180, 0.93}, - {210, 0.929}, - {240, 0.923}, - {270, 0.908}, - {300, 0.893}, - {330, 0.887}, - {360, 0.887}, - {450, 0.5}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.967}, + {30, 0.975}, + {60, 0.977}, + {90, 0.974}, + {120, 0.963}, + {150, 0.947}, + {180, 0.93}, + {210, 0.929}, + {240, 0.923}, + {270, 0.908}, + {300, 0.893}, + {330, 0.887}, + {360, 0.887}, + {450, 0.5}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_5_to_14, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}}; //15-34 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = @@ -626,57 +611,49 @@ void set_parameters(mio::abm::Parameters params) params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.0; // Set up personal infection and vaccine protection levels, based on: https://doi.org/10.1038/s41577-021-00550-x, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_15_to_34, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.852}, - {180, 0.852}, - {210, 0.845}, - {240, 0.828}, - {270, 0.797}, - {300, 0.759}, - {330, 0.711}, - {360, 0.661}, - {390, 0.616}, - {420, 0.580}, - {450, 0.559}, - {450, 0.550}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.852}, + {180, 0.852}, + {210, 0.845}, + {240, 0.828}, + {270, 0.797}, + {300, 0.759}, + {330, 0.711}, + {360, 0.661}, + {390, 0.616}, + {420, 0.580}, + {450, 0.559}, + {450, 0.550}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_15_to_34, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.89}, {60, 0.84}, {90, 0.78}, {120, 0.68}, {150, 0.57}, {180, 0.39}, {450, 0.1}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.5}, {30, 0.89}, {60, 0.84}, {90, 0.78}, {120, 0.68}, {150, 0.57}, {180, 0.39}, {450, 0.1}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_15_to_34, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.967}, - {30, 0.975}, - {60, 0.977}, - {90, 0.974}, - {120, 0.963}, - {150, 0.947}, - {180, 0.93}, - {210, 0.929}, - {240, 0.923}, - {270, 0.908}, - {300, 0.893}, - {330, 0.887}, - {360, 0.887}, - {450, 0.5}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.967}, + {30, 0.975}, + {60, 0.977}, + {90, 0.974}, + {120, 0.963}, + {150, 0.947}, + {180, 0.93}, + {210, 0.929}, + {240, 0.923}, + {270, 0.908}, + {300, 0.893}, + {330, 0.887}, + {360, 0.887}, + {450, 0.5}}}; // Information is from: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_15_to_34, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}}; //35-59 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = @@ -693,58 +670,49 @@ void set_parameters(mio::abm::Parameters params) // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_35_to_59, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.852}, - {180, 0.852}, - {210, 0.845}, - {240, 0.828}, - {270, 0.797}, - {300, 0.759}, - {330, 0.711}, - {360, 0.661}, - {390, 0.616}, - {420, 0.580}, - {450, 0.559}, - {450, 0.550}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.852}, + {180, 0.852}, + {210, 0.845}, + {240, 0.828}, + {270, 0.797}, + {300, 0.759}, + {330, 0.711}, + {360, 0.661}, + {390, 0.616}, + {420, 0.580}, + {450, 0.559}, + {450, 0.550}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_35_to_59, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.89}, {60, 0.84}, {90, 0.78}, {120, 0.68}, {150, 0.57}, {180, 0.39}, {450, 0.1}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.5}, {30, 0.89}, {60, 0.84}, {90, 0.78}, {120, 0.68}, {150, 0.57}, {180, 0.39}, {450, 0.1}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_35_to_59, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.967}, - {30, 0.975}, - {60, 0.977}, - {90, 0.974}, - {120, 0.963}, - {150, 0.947}, - {180, 0.93}, - {210, 0.929}, - {240, 0.923}, - {270, 0.908}, - {300, 0.893}, - {330, 0.887}, - {360, 0.887}, - {450, 0.5}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.967}, + {30, 0.975}, + {60, 0.977}, + {90, 0.974}, + {120, 0.963}, + {150, 0.947}, + {180, 0.93}, + {210, 0.929}, + {240, 0.923}, + {270, 0.908}, + {300, 0.893}, + {330, 0.887}, + {360, 0.887}, + {450, 0.5}}}; // Information is from: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_35_to_59, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}, days); - }; - + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}}; //60-79 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.179; @@ -760,57 +728,49 @@ void set_parameters(mio::abm::Parameters params) // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_60_to_79, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.852}, - {180, 0.852}, - {210, 0.845}, - {240, 0.828}, - {270, 0.797}, - {300, 0.759}, - {330, 0.711}, - {360, 0.661}, - {390, 0.616}, - {420, 0.580}, - {450, 0.559}, - {450, 0.550}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.852}, + {180, 0.852}, + {210, 0.845}, + {240, 0.828}, + {270, 0.797}, + {300, 0.759}, + {330, 0.711}, + {360, 0.661}, + {390, 0.616}, + {420, 0.580}, + {450, 0.559}, + {450, 0.550}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_60_to_79, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.87}, {60, 0.85}, {90, 0.78}, {120, 0.67}, {150, 0.61}, {180, 0.50}, {450, 0.1}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.5}, {30, 0.87}, {60, 0.85}, {90, 0.78}, {120, 0.67}, {150, 0.61}, {180, 0.50}, {450, 0.1}}}; // Set up personal severe protection levels. // Protection of severe infection of age group 65 + is different from other age group, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_60_to_79, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.967}, - {30, 0.975}, - {60, 0.977}, - {90, 0.974}, - {120, 0.963}, - {150, 0.947}, - {180, 0.93}, - {210, 0.929}, - {240, 0.923}, - {270, 0.908}, - {300, 0.893}, - {330, 0.887}, - {360, 0.887}, - {360, 0.5}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.967}, + {30, 0.975}, + {60, 0.977}, + {90, 0.974}, + {120, 0.963}, + {150, 0.947}, + {180, 0.93}, + {210, 0.929}, + {240, 0.923}, + {270, 0.908}, + {300, 0.893}, + {330, 0.887}, + {360, 0.887}, + {360, 0.5}}}; params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_60_to_79, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.91}, {60, 0.86}, {90, 0.91}, {120, 0.94}, {150, 0.95}, {180, 0.90}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.5}, {30, 0.91}, {60, 0.86}, {90, 0.91}, {120, 0.94}, {150, 0.95}, {180, 0.90}, {450, 0.5}}}; //80+ params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.179; @@ -826,58 +786,50 @@ void set_parameters(mio::abm::Parameters params) // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_80_plus, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.852}, - {180, 0.852}, - {210, 0.845}, - {240, 0.828}, - {270, 0.797}, - {300, 0.759}, - {330, 0.711}, - {360, 0.661}, - {390, 0.616}, - {420, 0.580}, - {450, 0.559}, - {450, 0.550}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.852}, + {180, 0.852}, + {210, 0.845}, + {240, 0.828}, + {270, 0.797}, + {300, 0.759}, + {330, 0.711}, + {360, 0.661}, + {390, 0.616}, + {420, 0.580}, + {450, 0.559}, + {450, 0.550}}}; // Information is from: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_80_plus, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.80}, {60, 0.79}, {90, 0.75}, {120, 0.56}, {150, 0.49}, {180, 0.43}, {450, 0.1}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.5}, {30, 0.80}, {60, 0.79}, {90, 0.75}, {120, 0.56}, {150, 0.49}, {180, 0.43}, {450, 0.1}}}; // Set up personal severe protection levels. // Protection of severe infection of age group 65 + is different from other age group, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_0_to_4, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.967}, - {30, 0.975}, - {60, 0.977}, - {90, 0.974}, - {120, 0.963}, - {150, 0.947}, - {180, 0.93}, - {210, 0.929}, - {240, 0.923}, - {270, 0.908}, - {300, 0.893}, - {330, 0.887}, - {360, 0.887}, - {360, 0.5}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.967}, + {30, 0.975}, + {60, 0.977}, + {90, 0.974}, + {120, 0.963}, + {150, 0.947}, + {180, 0.93}, + {210, 0.929}, + {240, 0.923}, + {270, 0.908}, + {300, 0.893}, + {330, 0.887}, + {360, 0.887}, + {360, 0.5}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_80_plus, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.84}, {60, 0.88}, {90, 0.89}, {120, 0.86}, {150, 0.85}, {180, 0.83}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{0, 0.5}, {30, 0.84}, {60, 0.88}, {90, 0.89}, {120, 0.86}, {150, 0.85}, {180, 0.83}, {450, 0.5}}}; } /** diff --git a/cpp/tests/test_abm_infection.cpp b/cpp/tests/test_abm_infection.cpp index 58d8425e2a..0ccf69bd2a 100644 --- a/cpp/tests/test_abm_infection.cpp +++ b/cpp/tests/test_abm_infection.cpp @@ -79,12 +79,10 @@ TEST(TestInfection, init) EXPECT_NEAR(infection.get_infectivity(mio::abm::TimePoint(0) + mio::abm::days(3)), 0.2689414213699951, 1e-14); params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_test, - virus_variant_test}] = [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.91}, {30, 0.81}}, days); - }; - params.get() = [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.91}, {30, 0.81}}, days); - }; + virus_variant_test}] = mio::TimeDependentParameterFunctor{ + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, {{0, 0.91}, {30, 0.81}}}; + params.get() = mio::TimeDependentParameterFunctor{ + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, {{0, 0.91}, {30, 0.81}}}; auto infection_w_previous_exp = mio::abm::Infection(rng, mio::abm::VirusVariant::Wildtype, age_group_test, params, mio::abm::TimePoint(0), mio::abm::InfectionState::InfectedSymptoms, @@ -164,7 +162,8 @@ TEST(TestInfection, drawInfectionCourseBackward) TEST(TestInfection, getPersonalProtectiveFactor) { - auto rng = mio::RandomNumberGenerator(); + const ScalarType eps = 1e-4; + auto rng = mio::RandomNumberGenerator(); auto location = mio::abm::Location(mio::abm::LocationType::School, 0, num_age_groups); auto person = mio::abm::Person(rng, location.get_type(), location.get_id(), age_group_15_to_34); @@ -177,79 +176,79 @@ TEST(TestInfection, getPersonalProtectiveFactor) mio::abm::ExposureType::GenericVaccine, mio::AgeGroup(0), mio::abm::VirusVariant::Wildtype}](0); auto defaut_severity_protection = params.get()[{ mio::abm::ExposureType::GenericVaccine, mio::AgeGroup(0), mio::abm::VirusVariant::Wildtype}](0); - ASSERT_NEAR(defaut_infection_protection, 0, 0.0001); - ASSERT_NEAR(defaut_severity_protection, 0, 0.0001); + EXPECT_NEAR(defaut_infection_protection, 0, eps); + EXPECT_NEAR(defaut_severity_protection, 0, eps); // Test linear interpolation with one node - mio::set_log_level(mio::LogLevel::critical); //this throws an error either way + // mio::set_log_level(mio::LogLevel::critical); //this throws an error either way params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{2, 0.91}}, days); - }; + mio::TimeDependentParameterFunctor{mio::TimeDependentParameterFunctor::Type::LinearInterpolation, {{2, 0.91}}}; auto t = mio::abm::TimePoint(6 * 24 * 60 * 60); - ASSERT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0, 0.001); - mio::set_log_level(mio::LogLevel::warn); //this throws an error either way + // TODO: Discuss: Assumption of interpolation in TDPF is that the function is constant with value at front/back entry outside of [front, back] time range. This works with one node as well and prints no errors + EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.91, eps); + // mio::set_log_level(mio::LogLevel::warn); //this throws an error either way params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{2, 0.91}, {30, 0.81}}, days); - }; + mio::TimeDependentParameterFunctor{mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{2, 0.91}, {30, 0.81}}}; params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{2, 0.91}, {30, 0.81}}, days); - }; - params.get() = [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{2, 0.91}, {30, 0.81}}, days); - }; + mio::TimeDependentParameterFunctor{mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + {{2, 0.91}, {30, 0.81}}}; + params.get() = mio::TimeDependentParameterFunctor{ + mio::TimeDependentParameterFunctor::Type::LinearInterpolation, {{2, 0.91}, {30, 0.81}}}; // Test Parameter InfectionProtectionFactor and get_protection_factor() t = mio::abm::TimePoint(0) + mio::abm::days(2); auto infection_protection_factor = params.get()[{ latest_protection.first, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( t.days() - latest_protection.second.days()); - ASSERT_NEAR(infection_protection_factor, 0.91, 0.0001); - ASSERT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.91, 0.0001); + EXPECT_NEAR(infection_protection_factor, 0.91, eps); + EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.91, eps); t = mio::abm::TimePoint(0) + mio::abm::days(15); infection_protection_factor = params.get()[{ latest_protection.first, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( t.days() - latest_protection.second.days()); - ASSERT_NEAR(infection_protection_factor, 0.8635, 0.0001); - ASSERT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.8635, 0.0001); + EXPECT_NEAR(infection_protection_factor, 0.8635, eps); + EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.8635, eps); t = mio::abm::TimePoint(0) + mio::abm::days(40); infection_protection_factor = params.get()[{ latest_protection.first, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( t.days() - latest_protection.second.days()); - ASSERT_NEAR(infection_protection_factor, 0, 0.0001); - ASSERT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0, 0.0001); + EXPECT_NEAR(infection_protection_factor, 0.81, + eps); // TODO: why was this 0? should there be an instant falloff after last data point? + EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.81, + eps); // TODO: why was this 0? should there be an instant falloff after last data point? // Test Parameter SeverityProtectionFactor t = mio::abm::TimePoint(0) + mio::abm::days(2); auto severity_protection_factor = params.get()[{ latest_protection.first, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( t.days() - latest_protection.second.days()); - ASSERT_NEAR(severity_protection_factor, 0.91, 0.0001); + EXPECT_NEAR(severity_protection_factor, 0.91, eps); t = mio::abm::TimePoint(0) + mio::abm::days(15); severity_protection_factor = params.get()[{ latest_protection.first, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( t.days() - latest_protection.second.days()); - ASSERT_NEAR(severity_protection_factor, 0.8635, 0.0001); + EXPECT_NEAR(severity_protection_factor, 0.8635, eps); t = mio::abm::TimePoint(0) + mio::abm::days(40); severity_protection_factor = params.get()[{ latest_protection.first, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( t.days() - latest_protection.second.days()); - ASSERT_NEAR(severity_protection_factor, 0, 0.0001); + EXPECT_NEAR(severity_protection_factor, 0.81, + eps); // TODO: why was this 0? should there be an instant falloff after last data point? // Test Parameter HighViralLoadProtectionFactor t = mio::abm::TimePoint(0) + mio::abm::days(2); - ASSERT_NEAR(params.get()(t.days()), 0.91, 0.0001); + EXPECT_NEAR(params.get()(t.days()), 0.91, eps); t = mio::abm::TimePoint(0) + mio::abm::days(15); - ASSERT_NEAR(params.get()(t.days()), 0.8635, 0.0001); + EXPECT_NEAR(params.get()(t.days()), 0.8635, eps); t = mio::abm::TimePoint(0) + mio::abm::days(40); - ASSERT_NEAR(params.get()(t.days()), 0, 0.0001); + EXPECT_NEAR(params.get()(t.days()), 0.81, + eps); // TODO: why was this 0? should there be an instant falloff after last data point? } diff --git a/cpp/tests/test_abm_location.cpp b/cpp/tests/test_abm_location.cpp index 2e95746ca9..40494ce2f5 100644 --- a/cpp/tests/test_abm_location.cpp +++ b/cpp/tests/test_abm_location.cpp @@ -23,6 +23,8 @@ #include "abm/person.h" #include "abm/world.h" #include "abm_helpers.h" +#include "matchers.h" +#include "memilio/io/json_serializer.h" #include "memilio/utils/random_number_generator.h" TEST(TestLocation, initCell) @@ -193,3 +195,45 @@ TEST(TestLocation, getGeographicalLocation) ASSERT_EQ(location.get_geographical_location(), geographical_location); } + +TEST(TestLocation, jsonSerialization) +{ + // Test that a json value x representing Location is equal to serialize(deserialize(x)) w.r.t json representation + + // Assuming (de)serialization does not depend on specific values of member variables, and that deserialize is + // injective (meaning two instances with different values do not have the same json representation, which can + // happen e.g. if not all member variables are serialized), + // this sufficiently tests that serialize and deserialize are inverse functions to each other + + unsigned i = 1; // counter s.t. members have different values + + // define a json value for a Location + Json::Value reference_json; // aka x + reference_json["cells"][0]["capacity"]["persons"] = Json::UInt(i++); + reference_json["cells"][0]["capacity"]["volume"] = Json::UInt(i++); + reference_json["geographical_location"]["latitude"] = Json::Value((double)i++); + reference_json["geographical_location"]["longitude"] = Json::Value((double)i++); + reference_json["id"] = Json::UInt(i++); + reference_json["npi_active"] = Json::Value(false); + reference_json["parameters"]["ContactRates"] = + mio::serialize_json(mio::abm::ContactRates::get_default(i++)).value(); + reference_json["parameters"]["MaximumContacts"] = Json::Value((double)i++); + reference_json["parameters"]["UseLocationCapacityForTransmissions"] = Json::Value(false); + reference_json["required_mask"] = Json::UInt(0); + + // check that the json is deserializable (i.e. a valid representation) + auto r = mio::deserialize_json(reference_json, mio::Tag()); + ASSERT_THAT(print_wrap(r), IsSuccess()); + // check that the resulting Person is serializable + auto result = mio::serialize_json(r.value()); + ASSERT_TRUE(result.value()); + // write the resulting json value and the reference value to string to compare their representations. + Json::StreamWriterBuilder swb; + swb["indentation"] = " "; + auto js_writer = std::unique_ptr(swb.newStreamWriter()); + std::stringstream result_str, reference_str; + js_writer->write(reference_json, &reference_str); + js_writer->write(result.value(), &result_str); + // we compare strings here, as e.g. Json::Int(5) != Json::Uint(5), but their json representation is the same + EXPECT_EQ(result_str.str(), reference_str.str()); +} diff --git a/cpp/tests/test_abm_person.cpp b/cpp/tests/test_abm_person.cpp index 6cea44e1d4..39e3970cef 100644 --- a/cpp/tests/test_abm_person.cpp +++ b/cpp/tests/test_abm_person.cpp @@ -24,6 +24,8 @@ #include "abm/person.h" #include "abm/time.h" #include "abm_helpers.h" +#include "matchers.h" +#include "memilio/io/json_serializer.h" #include "memilio/utils/random_number_generator.h" #include @@ -322,3 +324,63 @@ TEST(Person, rng) EXPECT_EQ(p.get_rng_counter(), mio::Counter(1)); EXPECT_EQ(p_rng.get_counter(), mio::rng_totalsequence_counter(13, mio::Counter{1})); } + +TEST(TestPerson, jsonSerialization) +{ + // Test that a json value x representing Person is equal to serialize(deserialize(x)) w.r.t json representation + + // Assuming (de)serialization does not depend on specific values of member variables, and that deserialize is + // injective (meaning two instances with different values do not have the same json representation, which can + // happen e.g. if not all member variables are serialized), + // this sufficiently tests that serialize and deserialize are inverse functions to each other + + auto json_uint_array = [](std::vector values) { + return mio::serialize_json(values).value(); + }; + auto json_double_array = [](std::vector values) { + return mio::serialize_json(values).value(); + }; + + unsigned i = 1; // counter s.t. members have different values + + // define a json value for a Person + Json::Value reference_json; // aka x + reference_json["age_group"] = Json::UInt(i++); + reference_json["assigned_locations"] = json_uint_array({i++, i++, i++, i++, i++, i++, i++, i++, i++, i++, i++}); + reference_json["cells"] = json_uint_array({i++}); + reference_json["id"] = Json::UInt(i++); + reference_json["infections"] = Json::Value(Json::arrayValue); + reference_json["last_transport_mode"] = Json::UInt(i++); + reference_json["location"] = Json::UInt(i++); + reference_json["mask"]["mask_type"] = Json::UInt(0); + reference_json["mask"]["time_used"]["seconds"] = Json::UInt(i++); + reference_json["mask_compliance"] = + json_double_array({(double)i++, (double)i++, (double)i++, (double)i++, (double)i++, (double)i++, (double)i++, + (double)i++, (double)i++, (double)i++, (double)i++}); + reference_json["quarantine_start"]["seconds"] = Json::UInt(i++); + reference_json["rnd_go_to_school_hour"] = Json::Value((double)i++); + reference_json["rnd_go_to_work_hour"] = Json::Value((double)i++); + reference_json["rnd_schoolgroup"] = Json::Value((double)i++); + reference_json["rnd_workgroup"] = Json::Value((double)i++); + reference_json["rng_counter"] = Json::UInt(i++); + reference_json["time_at_location"]["seconds"] = Json::UInt(i++); + reference_json["time_of_last_test"]["seconds"] = Json::UInt(i++); + reference_json["vaccinations"] = Json::Value(Json::arrayValue); + reference_json["wears_mask"] = Json::Value(false); + + // check that the json is deserializable (i.e. a valid representation) + auto r = mio::deserialize_json(reference_json, mio::Tag()); + ASSERT_THAT(print_wrap(r), IsSuccess()); + // check that the resulting Person is serializable + auto result = mio::serialize_json(r.value()); + ASSERT_TRUE(result.value()); + // write the resulting json value and the reference value to string to compare their representations. + Json::StreamWriterBuilder swb; + swb["indentation"] = " "; + auto js_writer = std::unique_ptr(swb.newStreamWriter()); + std::stringstream result_str, reference_str; + js_writer->write(reference_json, &reference_str); + js_writer->write(result.value(), &result_str); + // we compare strings here, as e.g. Json::Int(5) != Json::Uint(5), but their json representation is the same + EXPECT_EQ(result_str.str(), reference_str.str()); +} diff --git a/cpp/tests/test_abm_world.cpp b/cpp/tests/test_abm_world.cpp index 99a3eed6fc..59dfe8c59a 100644 --- a/cpp/tests/test_abm_world.cpp +++ b/cpp/tests/test_abm_world.cpp @@ -19,6 +19,8 @@ */ #include "abm/person.h" #include "abm_helpers.h" +#include "matchers.h" +#include "memilio/io/json_serializer.h" #include "memilio/utils/random_number_generator.h" TEST(TestWorld, init) @@ -558,3 +560,67 @@ TEST(TestWorld, checkParameterConstraints) params.get() = mio::abm::TimePoint(-2); ASSERT_EQ(params.check_constraints(), true); } + +TEST(TestWorld, abmTripJsonSerialization) +{ + mio::abm::Trip trip(0, mio::abm::TimePoint(0) + mio::abm::hours(8), 1, 2); + auto js = mio::serialize_json(trip, true); + Json::Value expected_json; + expected_json["person_id"] = Json::UInt(0); + expected_json["time"] = Json::Int(mio::abm::hours(8).seconds()); + expected_json["destination"] = Json::UInt(1); + expected_json["origin"] = Json::UInt(2); + ASSERT_EQ(js.value(), expected_json); + + auto r = mio::deserialize_json(expected_json, mio::Tag()); + ASSERT_THAT(print_wrap(r), IsSuccess()); + EXPECT_EQ(r.value(), trip); +} + +TEST(TestWorld, jsonSerialization) +{ + // Test that a json value x representing World is equal to serialize(deserialize(x)) w.r.t json representation + + // Assuming (de)serialization does not depend on specific values of member variables, and that deserialize is + // injective (meaning two instances with different values do not have the same json representation, which can + // happen e.g. if not all member variables are serialized), + // this sufficiently tests that serialize and deserialize are inverse functions to each other + + auto json_uint_array = [](std::vector values) { + return mio::serialize_json(values).value(); + }; + + unsigned i = 1; // counter s.t. members have different values + + // define a json value for a World + Json::Value reference_json; // aka x + reference_json["cemetery_id"] = Json::UInt(i++); + reference_json["location_types"] = Json::UInt(i++); + reference_json["locations"] = Json::Value(Json::arrayValue); + reference_json["parameters"] = mio::serialize_json(mio::abm::Parameters(i++)).value(); + reference_json["persons"] = Json::Value(Json::arrayValue); + reference_json["rng"]["counter"] = Json::UInt(i++); + reference_json["rng"]["key"] = Json::UInt(i++); + reference_json["rng"]["seeds"] = json_uint_array({i++, i++, i++, i++, i++, i++}); + reference_json["testing_strategy"]["schemes"] = Json::Value(Json::arrayValue); + reference_json["trip_list"]["index"] = Json::UInt(i++); + reference_json["trip_list"]["trips_weekday"] = Json::Value(Json::arrayValue); + reference_json["trip_list"]["trips_weekend"] = Json::Value(Json::arrayValue); + reference_json["use_migration_rules"] = Json::Value(true); + + // check that the json is deserializable (i.e. a valid representation) + auto r = mio::deserialize_json(reference_json, mio::Tag()); + ASSERT_THAT(print_wrap(r), IsSuccess()); + // check that the resulting Person is serializable + auto result = mio::serialize_json(r.value()); + ASSERT_TRUE(result.value()); + // write the resulting json value and the reference value to string to compare their representations. + Json::StreamWriterBuilder swb; + swb["indentation"] = " "; + auto js_writer = std::unique_ptr(swb.newStreamWriter()); + std::stringstream result_str, reference_str; + js_writer->write(reference_json, &reference_str); + js_writer->write(result.value(), &result_str); + // we compare strings here, as e.g. Json::Int(5) != Json::Uint(5), but their json representation is the same + EXPECT_EQ(result_str.str(), reference_str.str()); +} diff --git a/cpp/tests/test_json_serializer.cpp b/cpp/tests/test_json_serializer.cpp index 9ef34343c3..a676edc0ea 100644 --- a/cpp/tests/test_json_serializer.cpp +++ b/cpp/tests/test_json_serializer.cpp @@ -17,6 +17,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "matchers.h" +#include "distributions_helpers.h" + #include "memilio/io/json_serializer.h" #include "memilio/utils/parameter_distributions.h" #include "memilio/utils/stl_util.h" @@ -24,14 +27,10 @@ #include "memilio/utils/custom_index_array.h" #include "memilio/utils/parameter_set.h" #include "memilio/utils/uncertain_value.h" -#include "abm_helpers.h" -#include "matchers.h" -#include "distributions_helpers.h" + #include "gtest/gtest.h" -#include "json/config.h" #include "gmock/gmock.h" -#include -#include + #include namespace jsontest @@ -472,97 +471,3 @@ TEST(TestJsonSerializer, container_of_objects) ASSERT_THAT(print_wrap(r), IsSuccess()); EXPECT_THAT(r.value(), testing::UnorderedElementsAre(jsontest::Foo{1}, jsontest::Foo{2})); } - -TEST(TestJsonSerializer, abmLocation) -{ - auto location = mio::abm::Location(mio::abm::LocationType::Home, 0, num_age_groups); - auto js = mio::serialize_json(location); - Json::Value expected_json; - expected_json["index"] = Json::UInt64(0); - expected_json["type"] = Json::UInt64(mio::abm::LocationType::Home); - ASSERT_EQ(js.value(), expected_json); - - auto r = mio::deserialize_json(expected_json, mio::Tag()); - ASSERT_THAT(print_wrap(r), IsSuccess()); - EXPECT_EQ(r.value(), location); -} - -// TEST(TestJsonSerializer, abmPerson) // FIXME: (de)serialize is only partially implemented -// { -// auto location = mio::abm::Location(mio::abm::LocationType::School, 0, 6); -// auto person = make_test_person(location); -// auto js = mio::serialize_json(person); -// Json::Value expected_json; -// expected_json["Location"]["index"] = Json::UInt(location.get_id()); -// expected_json["Location"]["type"] = Json::UInt(location.get_type()); -// expected_json["age"] = Json::UInt(2); -// expected_json["id"] = Json::UInt(person.get_id()); -// ASSERT_EQ(js.value(), expected_json); - -// // auto r = mio::deserialize_json(expected_json, mio::Tag()); -// // ASSERT_THAT(print_wrap(r), IsSuccess()); -// // EXPECT_EQ(r.value(), person); -// } - -TEST(TestJsonSerializer, abmTrip) -{ - mio::abm::Location home{mio::abm::LocationType::Home, 0}; - mio::abm::Location work{mio::abm::LocationType::Work, 1}; - auto person = make_test_person(home); - // add a trip from home (0) to work (1) - mio::abm::Trip trip(person.get_id(), mio::abm::TimePoint(0) + mio::abm::hours(8), 1, 0); - auto js = mio::serialize_json(trip, true); - Json::Value expected_json; - expected_json["person_id"] = Json::UInt(person.get_id()); - expected_json["time"] = Json::Int(mio::abm::hours(8).seconds()); - expected_json["destination"] = Json::UInt(1); // work - expected_json["origin"] = Json::UInt(0); // home - ASSERT_EQ(js.value(), expected_json); - - auto r = mio::deserialize_json(expected_json, mio::Tag()); - ASSERT_THAT(print_wrap(r), IsSuccess()); - EXPECT_EQ(r.value(), trip); -} - -// TEST(TestJsonSerializer, abmWorld) // FIXME: (de)serialize is only partially implemented -// { -// auto world = mio::abm::World(num_age_groups); -// auto home_id = world.add_location(mio::abm::LocationType::Home); -// auto work_id = world.add_location(mio::abm::LocationType::Work); -// auto person = world.add_person(home_id, age_group_15_to_34); -// mio::abm::Trip trip1(person, mio::abm::TimePoint(0) + mio::abm::hours(8), work_id, home_id); -// mio::abm::Trip trip2(person, mio::abm::TimePoint(0) + mio::abm::hours(11), work_id, home_id); -// world.get_trip_list().add_trip(trip1, false); -// world.get_trip_list().add_trip(trip2, true); -// auto js = mio::serialize_json(world); -// Json::Value expected_json; -// expected_json["num_agegroups"] = Json::UInt(num_age_groups); -// expected_json["trips"][0]["person_id"] = Json::UInt(person); -// expected_json["trips"][0]["time"] = Json::Int(mio::abm::hours(8).seconds()); -// expected_json["trips"][0]["destination_index"] = Json::UInt(1); // work_id -// expected_json["trips"][0]["destination_type"] = Json::UInt(mio::abm::LocationType::Work); -// expected_json["trips"][0]["origin_index"] = Json::UInt(0); // home_id -// expected_json["trips"][0]["origin_type"] = Json::UInt(mio::abm::LocationType::Home); -// expected_json["trips"][1]["person_id"] = Json::UInt(person); -// expected_json["trips"][1]["time"] = Json::Int(mio::abm::hours(11).seconds()); -// expected_json["trips"][1]["destination_index"] = Json::UInt(1); // work_id -// expected_json["trips"][1]["destination_type"] = Json::UInt(mio::abm::LocationType::Work); -// expected_json["trips"][1]["origin_index"] = Json::UInt(0); // home_id -// expected_json["trips"][1]["origin_type"] = Json::UInt(mio::abm::LocationType::Home); -// expected_json["locations"][0]["index"] = Json::UInt(0); -// expected_json["locations"][0]["type"] = Json::UInt(mio::abm::LocationType::Cemetery); -// expected_json["locations"][1]["index"] = Json::UInt(1); -// expected_json["locations"][1]["type"] = Json::UInt(mio::abm::LocationType::Home); -// expected_json["locations"][2]["index"] = Json::UInt(2); -// expected_json["locations"][2]["type"] = Json::UInt(mio::abm::LocationType::Work); -// expected_json["persons"][0]["Location"]["index"] = Json::UInt(1); -// expected_json["persons"][0]["Location"]["type"] = Json::UInt(mio::abm::LocationType::Home); -// expected_json["persons"][0]["age"] = Json::UInt(2); -// expected_json["persons"][0]["id"] = Json::UInt(person); -// expected_json["use_migration_rules"] = Json::Value(true); -// ASSERT_EQ(js.value(), expected_json); - -// // auto r = mio::deserialize_json(expected_json, mio::Tag()); -// // ASSERT_THAT(print_wrap(r), IsSuccess()); -// // EXPECT_EQ(r.value(), world); -// } From f0598bfad7cb1a24b0d8f0d57eacb6f93bb7f613 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Fri, 19 Jul 2024 19:31:05 +0200 Subject: [PATCH 02/42] CI fixes --- cpp/memilio/io/auto_serialize.h | 10 +++++----- cpp/tests/test_abm_location.cpp | 4 ++++ cpp/tests/test_abm_person.cpp | 4 ++++ cpp/tests/test_abm_world.cpp | 4 ++++ 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/cpp/memilio/io/auto_serialize.h b/cpp/memilio/io/auto_serialize.h index fa823c7e02..aa056320d3 100644 --- a/cpp/memilio/io/auto_serialize.h +++ b/cpp/memilio/io/auto_serialize.h @@ -166,10 +166,10 @@ using has_auto_serialize = is_expression_valid::value && - not has_serialize::value, - AutoSerializable*> = nullptr> +template < + class IOContext, class AutoSerializable, + std::enable_if_t::value && !has_serialize::value, + AutoSerializable*> = nullptr> void serialize_internal(IOContext& io, const AutoSerializable& t) { // Note that this cast is only safe if we do not modify targets. @@ -181,7 +181,7 @@ void serialize_internal(IOContext& io, const AutoSerializable& t) // generates deserialize method depending on NVPs given by auto_serialize template ::value && - not has_deserialize::value, + !has_deserialize::value, AutoSerializable*> = nullptr> IOResult deserialize_internal(IOContext& io, Tag) { diff --git a/cpp/tests/test_abm_location.cpp b/cpp/tests/test_abm_location.cpp index 40494ce2f5..c108b8fbf5 100644 --- a/cpp/tests/test_abm_location.cpp +++ b/cpp/tests/test_abm_location.cpp @@ -196,6 +196,8 @@ TEST(TestLocation, getGeographicalLocation) ASSERT_EQ(location.get_geographical_location(), geographical_location); } +#ifdef MEMILIO_HAS_JSONCPP + TEST(TestLocation, jsonSerialization) { // Test that a json value x representing Location is equal to serialize(deserialize(x)) w.r.t json representation @@ -237,3 +239,5 @@ TEST(TestLocation, jsonSerialization) // we compare strings here, as e.g. Json::Int(5) != Json::Uint(5), but their json representation is the same EXPECT_EQ(result_str.str(), reference_str.str()); } + +#endif // MEMILIO_HAS_JSONCPP diff --git a/cpp/tests/test_abm_person.cpp b/cpp/tests/test_abm_person.cpp index 39e3970cef..c7834dc5c7 100644 --- a/cpp/tests/test_abm_person.cpp +++ b/cpp/tests/test_abm_person.cpp @@ -325,6 +325,8 @@ TEST(Person, rng) EXPECT_EQ(p_rng.get_counter(), mio::rng_totalsequence_counter(13, mio::Counter{1})); } +#ifdef MEMILIO_HAS_JSONCPP + TEST(TestPerson, jsonSerialization) { // Test that a json value x representing Person is equal to serialize(deserialize(x)) w.r.t json representation @@ -384,3 +386,5 @@ TEST(TestPerson, jsonSerialization) // we compare strings here, as e.g. Json::Int(5) != Json::Uint(5), but their json representation is the same EXPECT_EQ(result_str.str(), reference_str.str()); } + +#endif // MEMILIO_HAS_JSONCPP diff --git a/cpp/tests/test_abm_world.cpp b/cpp/tests/test_abm_world.cpp index 59dfe8c59a..f17b0bbae9 100644 --- a/cpp/tests/test_abm_world.cpp +++ b/cpp/tests/test_abm_world.cpp @@ -561,6 +561,8 @@ TEST(TestWorld, checkParameterConstraints) ASSERT_EQ(params.check_constraints(), true); } +#ifdef MEMILIO_HAS_JSONCPP + TEST(TestWorld, abmTripJsonSerialization) { mio::abm::Trip trip(0, mio::abm::TimePoint(0) + mio::abm::hours(8), 1, 2); @@ -624,3 +626,5 @@ TEST(TestWorld, jsonSerialization) // we compare strings here, as e.g. Json::Int(5) != Json::Uint(5), but their json representation is the same EXPECT_EQ(result_str.str(), reference_str.str()); } + +#endif // MEMILIO_HAS_JSONCPP From 5037147b5e2fc3ad866b02e07d0ab77d3aee796e Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 23 Jul 2024 08:51:54 +0200 Subject: [PATCH 03/42] move serialization tests and improve coverage --- cpp/tests/CMakeLists.txt | 1 + cpp/tests/test_abm_location.cpp | 46 ------ cpp/tests/test_abm_person.cpp | 64 ------- cpp/tests/test_abm_serialization.cpp | 238 +++++++++++++++++++++++++++ cpp/tests/test_abm_world.cpp | 71 +------- 5 files changed, 240 insertions(+), 180 deletions(-) create mode 100644 cpp/tests/test_abm_serialization.cpp diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 4e6c389fe8..9fe42496ea 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -36,6 +36,7 @@ set(TESTSOURCES test_abm_migration_rules.cpp test_abm_person.cpp test_abm_simulation.cpp + test_abm_serialization.cpp test_abm_testing_strategy.cpp test_abm_world.cpp test_math_floating_point.cpp diff --git a/cpp/tests/test_abm_location.cpp b/cpp/tests/test_abm_location.cpp index c108b8fbf5..fbba080a73 100644 --- a/cpp/tests/test_abm_location.cpp +++ b/cpp/tests/test_abm_location.cpp @@ -195,49 +195,3 @@ TEST(TestLocation, getGeographicalLocation) ASSERT_EQ(location.get_geographical_location(), geographical_location); } - -#ifdef MEMILIO_HAS_JSONCPP - -TEST(TestLocation, jsonSerialization) -{ - // Test that a json value x representing Location is equal to serialize(deserialize(x)) w.r.t json representation - - // Assuming (de)serialization does not depend on specific values of member variables, and that deserialize is - // injective (meaning two instances with different values do not have the same json representation, which can - // happen e.g. if not all member variables are serialized), - // this sufficiently tests that serialize and deserialize are inverse functions to each other - - unsigned i = 1; // counter s.t. members have different values - - // define a json value for a Location - Json::Value reference_json; // aka x - reference_json["cells"][0]["capacity"]["persons"] = Json::UInt(i++); - reference_json["cells"][0]["capacity"]["volume"] = Json::UInt(i++); - reference_json["geographical_location"]["latitude"] = Json::Value((double)i++); - reference_json["geographical_location"]["longitude"] = Json::Value((double)i++); - reference_json["id"] = Json::UInt(i++); - reference_json["npi_active"] = Json::Value(false); - reference_json["parameters"]["ContactRates"] = - mio::serialize_json(mio::abm::ContactRates::get_default(i++)).value(); - reference_json["parameters"]["MaximumContacts"] = Json::Value((double)i++); - reference_json["parameters"]["UseLocationCapacityForTransmissions"] = Json::Value(false); - reference_json["required_mask"] = Json::UInt(0); - - // check that the json is deserializable (i.e. a valid representation) - auto r = mio::deserialize_json(reference_json, mio::Tag()); - ASSERT_THAT(print_wrap(r), IsSuccess()); - // check that the resulting Person is serializable - auto result = mio::serialize_json(r.value()); - ASSERT_TRUE(result.value()); - // write the resulting json value and the reference value to string to compare their representations. - Json::StreamWriterBuilder swb; - swb["indentation"] = " "; - auto js_writer = std::unique_ptr(swb.newStreamWriter()); - std::stringstream result_str, reference_str; - js_writer->write(reference_json, &reference_str); - js_writer->write(result.value(), &result_str); - // we compare strings here, as e.g. Json::Int(5) != Json::Uint(5), but their json representation is the same - EXPECT_EQ(result_str.str(), reference_str.str()); -} - -#endif // MEMILIO_HAS_JSONCPP diff --git a/cpp/tests/test_abm_person.cpp b/cpp/tests/test_abm_person.cpp index c7834dc5c7..f52b03138f 100644 --- a/cpp/tests/test_abm_person.cpp +++ b/cpp/tests/test_abm_person.cpp @@ -324,67 +324,3 @@ TEST(Person, rng) EXPECT_EQ(p.get_rng_counter(), mio::Counter(1)); EXPECT_EQ(p_rng.get_counter(), mio::rng_totalsequence_counter(13, mio::Counter{1})); } - -#ifdef MEMILIO_HAS_JSONCPP - -TEST(TestPerson, jsonSerialization) -{ - // Test that a json value x representing Person is equal to serialize(deserialize(x)) w.r.t json representation - - // Assuming (de)serialization does not depend on specific values of member variables, and that deserialize is - // injective (meaning two instances with different values do not have the same json representation, which can - // happen e.g. if not all member variables are serialized), - // this sufficiently tests that serialize and deserialize are inverse functions to each other - - auto json_uint_array = [](std::vector values) { - return mio::serialize_json(values).value(); - }; - auto json_double_array = [](std::vector values) { - return mio::serialize_json(values).value(); - }; - - unsigned i = 1; // counter s.t. members have different values - - // define a json value for a Person - Json::Value reference_json; // aka x - reference_json["age_group"] = Json::UInt(i++); - reference_json["assigned_locations"] = json_uint_array({i++, i++, i++, i++, i++, i++, i++, i++, i++, i++, i++}); - reference_json["cells"] = json_uint_array({i++}); - reference_json["id"] = Json::UInt(i++); - reference_json["infections"] = Json::Value(Json::arrayValue); - reference_json["last_transport_mode"] = Json::UInt(i++); - reference_json["location"] = Json::UInt(i++); - reference_json["mask"]["mask_type"] = Json::UInt(0); - reference_json["mask"]["time_used"]["seconds"] = Json::UInt(i++); - reference_json["mask_compliance"] = - json_double_array({(double)i++, (double)i++, (double)i++, (double)i++, (double)i++, (double)i++, (double)i++, - (double)i++, (double)i++, (double)i++, (double)i++}); - reference_json["quarantine_start"]["seconds"] = Json::UInt(i++); - reference_json["rnd_go_to_school_hour"] = Json::Value((double)i++); - reference_json["rnd_go_to_work_hour"] = Json::Value((double)i++); - reference_json["rnd_schoolgroup"] = Json::Value((double)i++); - reference_json["rnd_workgroup"] = Json::Value((double)i++); - reference_json["rng_counter"] = Json::UInt(i++); - reference_json["time_at_location"]["seconds"] = Json::UInt(i++); - reference_json["time_of_last_test"]["seconds"] = Json::UInt(i++); - reference_json["vaccinations"] = Json::Value(Json::arrayValue); - reference_json["wears_mask"] = Json::Value(false); - - // check that the json is deserializable (i.e. a valid representation) - auto r = mio::deserialize_json(reference_json, mio::Tag()); - ASSERT_THAT(print_wrap(r), IsSuccess()); - // check that the resulting Person is serializable - auto result = mio::serialize_json(r.value()); - ASSERT_TRUE(result.value()); - // write the resulting json value and the reference value to string to compare their representations. - Json::StreamWriterBuilder swb; - swb["indentation"] = " "; - auto js_writer = std::unique_ptr(swb.newStreamWriter()); - std::stringstream result_str, reference_str; - js_writer->write(reference_json, &reference_str); - js_writer->write(result.value(), &result_str); - // we compare strings here, as e.g. Json::Int(5) != Json::Uint(5), but their json representation is the same - EXPECT_EQ(result_str.str(), reference_str.str()); -} - -#endif // MEMILIO_HAS_JSONCPP diff --git a/cpp/tests/test_abm_serialization.cpp b/cpp/tests/test_abm_serialization.cpp new file mode 100644 index 0000000000..a1f9eb1608 --- /dev/null +++ b/cpp/tests/test_abm_serialization.cpp @@ -0,0 +1,238 @@ +#include "abm/vaccine.h" +#include "matchers.h" +#include "memilio/io/json_serializer.h" +#include "models/abm/location.h" +#include "models/abm/person.h" +#include "models/abm/time.h" +#include "models/abm/trip_list.h" +#include "models/abm/world.h" +#include "json/config.h" +#include "json/value.h" + +#ifdef MEMILIO_HAS_JSONCPP + +void test_equal_json_representation(const Json::Value& test_json, const Json::Value& reference_json) +{ + // write the resulting json value and the reference value to string to compare their representations. + Json::StreamWriterBuilder swb; + swb["indentation"] = " "; + auto js_writer = std::unique_ptr(swb.newStreamWriter()); + std::stringstream test_str, reference_str; + js_writer->write(reference_json, &reference_str); + js_writer->write(test_json, &test_str); + // we compare strings here, as e.g. Json::Int(5) != Json::Uint(5), but their json representation is the same + EXPECT_EQ(test_str.str(), reference_str.str()); +} + +/** + * @brief Test de- and serialization of an object by comparing its json representation. + * + * Test that a json value x representing type T is equal to serialize(deserialize(x)) w.r.t json representation. + * + * Assuming the (de)serialization functions' general behavior is independent of specific values of member variables, + * i.e. the function does not contain conditionals (`if (t > 0)`), optionals (`add_optional`/`expect_optional`), etc., + * and assuming that deserialize is injective (meaning that two unequal instances of T do not have the same json + * representation, which can happen e.g. if not all member variables are serialized), + * this sufficiently tests that serialize and deserialize are inverse functions to each other. + * + * @tparam T The type to test. + * @param reference_json A json value representing an instance of T. + */ +template +void test_json_serialization_by_representation(const Json::Value& reference_json) +{ + // check that the json is deserializable (i.e. a valid representation) + auto t_result = mio::deserialize_json(reference_json, mio::Tag()); + ASSERT_THAT(print_wrap(t_result), IsSuccess()); + + // check that the resulting type T is serializable + auto json_result = mio::serialize_json(t_result.value()); + ASSERT_TRUE(json_result); + + test_equal_json_representation(json_result.value(), reference_json); +} + +/** + * @brief Test de- and serialization of an object by comparing its json representation and using its equality operator. + * + * First, test that serializing the reference_object is equal to the reference_json (w.r.t. their representation), + * and that deserializing the reference_json results in an object equal to the reference_object. + * Then, repeat this step using its own results as arguments to (de)serialize, to check that serialization and + * deserialization are inverse functions to each other. + * + * @tparam T The type to test. + * @param reference_object An instance of T. + * @param reference_json A json value representing reference_object. + */ +template +void test_json_serialization_full(const T& reference_object, const Json::Value& reference_json) +{ + // check that the reference type T is serializable + auto json_result = mio::serialize_json(reference_object); + ASSERT_TRUE(json_result); + + // check that the reference json is deserializable + auto t_result = mio::deserialize_json(reference_json, mio::Tag()); + ASSERT_THAT(print_wrap(t_result), IsSuccess()); + + // compare both results with other reference values + EXPECT_EQ(t_result.value(), reference_object); + test_equal_json_representation(json_result.value(), reference_json); + + // do the same once more using the results from above + auto json_result_2 = mio::serialize_json(t_result.value()); + ASSERT_TRUE(json_result_2); + auto t_result_2 = mio::deserialize_json(json_result.value(), mio::Tag()); + ASSERT_THAT(print_wrap(t_result_2), IsSuccess()); + + EXPECT_EQ(t_result_2.value(), reference_object); + test_equal_json_representation(json_result_2.value(), reference_json); +} + +TEST(TestAbmSerialization, Trip) +{ + // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. + // See test_json_serialization_by_representation for more detail. + + mio::abm::Trip trip(1, mio::abm::TimePoint(0) + mio::abm::hours(2), 3, 4); + + Json::Value reference_json; // aka x + reference_json["person_id"] = Json::UInt(1); + reference_json["time"] = Json::Int(mio::abm::hours(2).seconds()); + reference_json["destination"] = Json::UInt(3); + reference_json["origin"] = Json::UInt(4); + + test_json_serialization_full(trip, reference_json); +} + +TEST(TestAbmSerialization, Vaccination) +{ + // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. + // See test_json_serialization_by_representation for more detail. + + Json::Value reference_json; // aka x + reference_json["exposure_type"] = Json::Int(1); + reference_json["time"]["seconds"] = Json::UInt(2); + + test_json_serialization_by_representation(reference_json); +} + +TEST(TestAbmSerialization, Infection) +{ + // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. + // See test_json_serialization_by_representation for more detail. + + unsigned i = 1; // counter s.t. members have different values + + Json::Value viral_load; + viral_load["decline"] = Json::Value((double)i++); + viral_load["end_date"]["seconds"] = Json::UInt(i++); + viral_load["incline"] = Json::Value((double)i++); + viral_load["peak"] = Json::Value((double)i++); + viral_load["start_date"]["seconds"] = Json::UInt(i++); + + Json::Value reference_json; // aka x + reference_json["infection_course"] = Json::Value(Json::arrayValue); + reference_json["virus_variant"] = Json::UInt(0); + reference_json["viral_load"] = viral_load; + reference_json["log_norm_alpha"] = Json::Value((double)i++); + reference_json["log_norm_beta"] = Json::Value((double)i++); + reference_json["detected"] = Json::Value((bool)0); + + test_json_serialization_by_representation(reference_json); +} + +TEST(TestAbmSerialization, Person) +{ + // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. + // See test_json_serialization_by_representation for more detail. + + auto json_uint_array = [](std::vector values) { + return mio::serialize_json(values).value(); + }; + auto json_double_array = [](std::vector values) { + return mio::serialize_json(values).value(); + }; + + unsigned i = 1; // counter s.t. members have different values + + Json::Value reference_json; // aka x + reference_json["age_group"] = Json::UInt(i++); + reference_json["assigned_locations"] = json_uint_array({i++, i++, i++, i++, i++, i++, i++, i++, i++, i++, i++}); + reference_json["cells"] = json_uint_array({i++}); + reference_json["id"] = Json::UInt(i++); + reference_json["infections"] = Json::Value(Json::arrayValue); + reference_json["last_transport_mode"] = Json::UInt(i++); + reference_json["location"] = Json::UInt(i++); + reference_json["mask"]["mask_type"] = Json::UInt(0); + reference_json["mask"]["time_used"]["seconds"] = Json::UInt(i++); + reference_json["mask_compliance"] = + json_double_array({(double)i++, (double)i++, (double)i++, (double)i++, (double)i++, (double)i++, (double)i++, + (double)i++, (double)i++, (double)i++, (double)i++}); + reference_json["quarantine_start"]["seconds"] = Json::UInt(i++); + reference_json["rnd_go_to_school_hour"] = Json::Value((double)i++); + reference_json["rnd_go_to_work_hour"] = Json::Value((double)i++); + reference_json["rnd_schoolgroup"] = Json::Value((double)i++); + reference_json["rnd_workgroup"] = Json::Value((double)i++); + reference_json["rng_counter"] = Json::UInt(i++); + reference_json["time_at_location"]["seconds"] = Json::UInt(i++); + reference_json["time_of_last_test"]["seconds"] = Json::UInt(i++); + reference_json["vaccinations"] = Json::Value(Json::arrayValue); + reference_json["wears_mask"] = Json::Value(false); + + test_json_serialization_by_representation(reference_json); +} + +TEST(TestAbmSerialization, Location) +{ + // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. + // See test_json_serialization_by_representation for more detail. + + unsigned i = 1; // counter s.t. members have different values + + Json::Value reference_json; // aka x + reference_json["cells"][0]["capacity"]["persons"] = Json::UInt(i++); + reference_json["cells"][0]["capacity"]["volume"] = Json::UInt(i++); + reference_json["geographical_location"]["latitude"] = Json::Value((double)i++); + reference_json["geographical_location"]["longitude"] = Json::Value((double)i++); + reference_json["id"] = Json::UInt(i++); + reference_json["npi_active"] = Json::Value(false); + reference_json["parameters"]["ContactRates"] = + mio::serialize_json(mio::abm::ContactRates::get_default(i++)).value(); + reference_json["parameters"]["MaximumContacts"] = Json::Value((double)i++); + reference_json["parameters"]["UseLocationCapacityForTransmissions"] = Json::Value(false); + reference_json["required_mask"] = Json::UInt(0); + + test_json_serialization_by_representation(reference_json); +} + +TEST(TestAbmSerialization, World) +{ + // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. + // See test_json_serialization_by_representation for more detail. + + auto json_uint_array = [](std::vector values) { + return mio::serialize_json(values).value(); + }; + + unsigned i = 1; // counter s.t. members have different values + + Json::Value reference_json; // aka x + reference_json["cemetery_id"] = Json::UInt(i++); + reference_json["location_types"] = Json::UInt(i++); + reference_json["locations"] = Json::Value(Json::arrayValue); + reference_json["parameters"] = mio::serialize_json(mio::abm::Parameters(i++)).value(); + reference_json["persons"] = Json::Value(Json::arrayValue); + reference_json["rng"]["counter"] = Json::UInt(i++); + reference_json["rng"]["key"] = Json::UInt(i++); + reference_json["rng"]["seeds"] = json_uint_array({i++, i++, i++, i++, i++, i++}); + reference_json["testing_strategy"]["schemes"] = Json::Value(Json::arrayValue); + reference_json["trip_list"]["index"] = Json::UInt(i++); + reference_json["trip_list"]["trips_weekday"] = Json::Value(Json::arrayValue); + reference_json["trip_list"]["trips_weekend"] = Json::Value(Json::arrayValue); + reference_json["use_migration_rules"] = Json::Value(true); + + test_json_serialization_by_representation(reference_json); +} + +#endif diff --git a/cpp/tests/test_abm_world.cpp b/cpp/tests/test_abm_world.cpp index f17b0bbae9..8eb47b8914 100644 --- a/cpp/tests/test_abm_world.cpp +++ b/cpp/tests/test_abm_world.cpp @@ -18,9 +18,8 @@ * limitations under the License. */ #include "abm/person.h" +#include "abm/world.h" #include "abm_helpers.h" -#include "matchers.h" -#include "memilio/io/json_serializer.h" #include "memilio/utils/random_number_generator.h" TEST(TestWorld, init) @@ -560,71 +559,3 @@ TEST(TestWorld, checkParameterConstraints) params.get() = mio::abm::TimePoint(-2); ASSERT_EQ(params.check_constraints(), true); } - -#ifdef MEMILIO_HAS_JSONCPP - -TEST(TestWorld, abmTripJsonSerialization) -{ - mio::abm::Trip trip(0, mio::abm::TimePoint(0) + mio::abm::hours(8), 1, 2); - auto js = mio::serialize_json(trip, true); - Json::Value expected_json; - expected_json["person_id"] = Json::UInt(0); - expected_json["time"] = Json::Int(mio::abm::hours(8).seconds()); - expected_json["destination"] = Json::UInt(1); - expected_json["origin"] = Json::UInt(2); - ASSERT_EQ(js.value(), expected_json); - - auto r = mio::deserialize_json(expected_json, mio::Tag()); - ASSERT_THAT(print_wrap(r), IsSuccess()); - EXPECT_EQ(r.value(), trip); -} - -TEST(TestWorld, jsonSerialization) -{ - // Test that a json value x representing World is equal to serialize(deserialize(x)) w.r.t json representation - - // Assuming (de)serialization does not depend on specific values of member variables, and that deserialize is - // injective (meaning two instances with different values do not have the same json representation, which can - // happen e.g. if not all member variables are serialized), - // this sufficiently tests that serialize and deserialize are inverse functions to each other - - auto json_uint_array = [](std::vector values) { - return mio::serialize_json(values).value(); - }; - - unsigned i = 1; // counter s.t. members have different values - - // define a json value for a World - Json::Value reference_json; // aka x - reference_json["cemetery_id"] = Json::UInt(i++); - reference_json["location_types"] = Json::UInt(i++); - reference_json["locations"] = Json::Value(Json::arrayValue); - reference_json["parameters"] = mio::serialize_json(mio::abm::Parameters(i++)).value(); - reference_json["persons"] = Json::Value(Json::arrayValue); - reference_json["rng"]["counter"] = Json::UInt(i++); - reference_json["rng"]["key"] = Json::UInt(i++); - reference_json["rng"]["seeds"] = json_uint_array({i++, i++, i++, i++, i++, i++}); - reference_json["testing_strategy"]["schemes"] = Json::Value(Json::arrayValue); - reference_json["trip_list"]["index"] = Json::UInt(i++); - reference_json["trip_list"]["trips_weekday"] = Json::Value(Json::arrayValue); - reference_json["trip_list"]["trips_weekend"] = Json::Value(Json::arrayValue); - reference_json["use_migration_rules"] = Json::Value(true); - - // check that the json is deserializable (i.e. a valid representation) - auto r = mio::deserialize_json(reference_json, mio::Tag()); - ASSERT_THAT(print_wrap(r), IsSuccess()); - // check that the resulting Person is serializable - auto result = mio::serialize_json(r.value()); - ASSERT_TRUE(result.value()); - // write the resulting json value and the reference value to string to compare their representations. - Json::StreamWriterBuilder swb; - swb["indentation"] = " "; - auto js_writer = std::unique_ptr(swb.newStreamWriter()); - std::stringstream result_str, reference_str; - js_writer->write(reference_json, &reference_str); - js_writer->write(result.value(), &result_str); - // we compare strings here, as e.g. Json::Int(5) != Json::Uint(5), but their json representation is the same - EXPECT_EQ(result_str.str(), reference_str.str()); -} - -#endif // MEMILIO_HAS_JSONCPP From 2810cf9dd3993d9e89f6fef034871e11a102969a Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 23 Jul 2024 09:15:33 +0200 Subject: [PATCH 04/42] fix and cover testing_strategy serialization --- cpp/models/abm/testing_strategy.h | 17 +++++----- cpp/tests/test_abm_serialization.cpp | 46 ++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index 44270d855b..87b8aab454 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -188,8 +188,7 @@ class TestingScheme obj.add_element("min_time_since_last_test", m_minimal_time_since_last_test); obj.add_element("start_date", m_start_date); obj.add_element("end_date", m_end_date); - obj.add_element("test_type", - m_test_type.get_default()); // FIXME: m_test_type should contain TestParameters directly + obj.add_element("test_params", m_test_parameters); obj.add_element("probability", m_probability); obj.add_element("is_active", m_is_active); } @@ -206,18 +205,18 @@ class TestingScheme auto min_time_since_last_test = obj.expect_element("min_time_since_last_test", Tag{}); auto start_date = obj.expect_element("start_date", Tag{}); auto end_date = obj.expect_element("end_date", Tag{}); - auto test_type = obj.expect_element( - "test_type", Tag{}); // FIXME: m_test_type should contain TestParameters directly - auto probability = obj.expect_element("probability", Tag{}); - auto is_active = obj.expect_element("is_active", Tag{}); + auto test_params = obj.expect_element("test_params", Tag{}); + auto probability = obj.expect_element("probability", Tag{}); + auto is_active = obj.expect_element("is_active", Tag{}); return apply( io, [](auto&& criteria_, auto&& min_time_since_last_test_, auto&& start_date_, auto&& end_date_, - auto&& test_type_, auto&& probability_, auto&& is_active_) { + auto&& test_params_, auto&& probability_, auto&& is_active_) { return TestingScheme{ - criteria_, min_time_since_last_test_, start_date_, end_date_, test_type_, probability_, is_active_}; + criteria_, min_time_since_last_test_, start_date_, end_date_, test_params_, probability_, + is_active_}; }, - criteria, min_time_since_last_test, start_date, end_date, test_type, probability, is_active); + criteria, min_time_since_last_test, start_date, end_date, test_params, probability, is_active); } private: diff --git a/cpp/tests/test_abm_serialization.cpp b/cpp/tests/test_abm_serialization.cpp index a1f9eb1608..0c71d8165f 100644 --- a/cpp/tests/test_abm_serialization.cpp +++ b/cpp/tests/test_abm_serialization.cpp @@ -1,6 +1,9 @@ +#include "abm/parameters.h" +#include "abm/testing_strategy.h" #include "abm/vaccine.h" #include "matchers.h" #include "memilio/io/json_serializer.h" +#include "memilio/utils/uncertain_value.h" #include "models/abm/location.h" #include "models/abm/person.h" #include "models/abm/time.h" @@ -142,6 +145,49 @@ TEST(TestAbmSerialization, Infection) test_json_serialization_by_representation(reference_json); } +TEST(TestAbmSerialization, TestingScheme) +{ + // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. + // See test_json_serialization_by_representation for more detail. + + mio::abm::TestingScheme testing_scheme(mio::abm::TestingCriteria({}, {}), mio::abm::TimeSpan(1), + mio::abm::TimePoint(2), mio::abm::TimePoint(3), + mio::abm::TestParameters{{4.0}, {5.0}}, 6.0); + + Json::Value test_parameters; + test_parameters["sensitivity"] = mio::serialize_json(mio::UncertainValue{4.0}).value(); + test_parameters["specitivity"] = mio::serialize_json(mio::UncertainValue{5.0}).value(); + + Json::Value reference_json; // aka x + reference_json["criteria"] = Json::Value(Json::arrayValue); + reference_json["min_time_since_last_test"]["seconds"] = Json::UInt(1); + reference_json["start_date"]["seconds"] = Json::UInt(2); + reference_json["end_date"]["seconds"] = Json::UInt(3); + reference_json["test_params"] = test_parameters; + reference_json["probability"] = Json::Value((double)6); + reference_json["is_active"] = Json::Value((bool)0); + + test_json_serialization_full(testing_scheme, reference_json); +} + +TEST(TestAbmSerialization, TestingStrategy) +{ + // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. + // See test_json_serialization_by_representation for more detail. + + unsigned i = 1; // counter s.t. members have different values + + Json::Value local_strategy; + local_strategy["id"] = Json::UInt(i++); + local_strategy["schemes"] = Json::Value(Json::arrayValue); + local_strategy["type"] = Json::UInt(i++); + + Json::Value reference_json; // aka x + reference_json["schemes"][0] = local_strategy; + + test_json_serialization_by_representation(reference_json); +} + TEST(TestAbmSerialization, Person) { // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. From 64f566ee346764330de9570013cc76d94391f12c Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 23 Jul 2024 09:22:58 +0200 Subject: [PATCH 05/42] remove unwanted includes --- cpp/tests/test_abm_serialization.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/tests/test_abm_serialization.cpp b/cpp/tests/test_abm_serialization.cpp index 0c71d8165f..416ea6d70e 100644 --- a/cpp/tests/test_abm_serialization.cpp +++ b/cpp/tests/test_abm_serialization.cpp @@ -9,8 +9,6 @@ #include "models/abm/time.h" #include "models/abm/trip_list.h" #include "models/abm/world.h" -#include "json/config.h" -#include "json/value.h" #ifdef MEMILIO_HAS_JSONCPP From 2f7bc792852ba987bd0e51226a24f3f93b8078e1 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 23 Jul 2024 09:55:28 +0200 Subject: [PATCH 06/42] fix and cover testing_strategy serialization v2 --- cpp/models/abm/testing_strategy.h | 95 ++++++++++++++++------------ cpp/tests/test_abm_serialization.cpp | 55 +++++++++------- 2 files changed, 87 insertions(+), 63 deletions(-) diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index 87b8aab454..a71e84b25f 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -119,7 +119,10 @@ class TestingCriteria return apply( io, [](auto&& ages_, auto&& infection_states_) { - return TestingCriteria{ages_, infection_states_}; + TestingCriteria c; + c.m_ages = ages_; + c.m_infection_states = infection_states_; + return c; }, ages, infection_states); } @@ -176,50 +179,62 @@ class TestingScheme */ bool run_scheme(PersonalRandomNumberGenerator& rng, Person& person, TimePoint t) const; - /** - * serialize this. - * @see mio::serialize - */ - template - void serialize(IOContext& io) const - { - auto obj = io.create_object("TestingScheme"); - obj.add_element("criteria", m_testing_criteria); - obj.add_element("min_time_since_last_test", m_minimal_time_since_last_test); - obj.add_element("start_date", m_start_date); - obj.add_element("end_date", m_end_date); - obj.add_element("test_params", m_test_parameters); - obj.add_element("probability", m_probability); - obj.add_element("is_active", m_is_active); - } + // /** + // * serialize this. + // * @see mio::serialize + // */ + // template + // void serialize(IOContext& io) const + // { + // auto obj = io.create_object("TestingScheme"); + // obj.add_element("criteria", m_testing_criteria), ; + // obj.add_element("min_time_since_last_test", m_minimal_time_since_last_test), ; + // obj.add_element("start_date", m_start_date), ; + // obj.add_element("end_date", m_end_date), ; + // obj.add_element("test_params", m_test_parameters), ; + // obj.add_element("probability", m_probability), ; + // obj.add_element("is_active", m_is_active), ; + // } - /** - * deserialize an object of this class. - * @see mio::deserialize - */ - template - static IOResult deserialize(IOContext& io) + auto auto_serialize() { - auto obj = io.expect_object("TestingScheme"); - auto criteria = obj.expect_element("criteria", Tag{}); - auto min_time_since_last_test = obj.expect_element("min_time_since_last_test", Tag{}); - auto start_date = obj.expect_element("start_date", Tag{}); - auto end_date = obj.expect_element("end_date", Tag{}); - auto test_params = obj.expect_element("test_params", Tag{}); - auto probability = obj.expect_element("probability", Tag{}); - auto is_active = obj.expect_element("is_active", Tag{}); - return apply( - io, - [](auto&& criteria_, auto&& min_time_since_last_test_, auto&& start_date_, auto&& end_date_, - auto&& test_params_, auto&& probability_, auto&& is_active_) { - return TestingScheme{ - criteria_, min_time_since_last_test_, start_date_, end_date_, test_params_, probability_, - is_active_}; - }, - criteria, min_time_since_last_test, start_date, end_date, test_params, probability, is_active); + return make_auto_serialization("TestingScheme", NVP("criteria", m_testing_criteria), + NVP("min_time_since_last_test", m_minimal_time_since_last_test), + NVP("start_date", m_start_date), NVP("end_date", m_end_date), + NVP("test_params", m_test_parameters), NVP("probability", m_probability), + NVP("is_active", m_is_active)); } + // /** + // * deserialize an object of this class. + // * @see mio::deserialize + // */ + // template + // static IOResult deserialize(IOContext& io) + // { + // auto obj = io.expect_object("TestingScheme"); + // auto criteria = obj.expect_element("criteria", Tag{}); + // auto min_time_since_last_test = obj.expect_element("min_time_since_last_test", Tag{}); + // auto start_date = obj.expect_element("start_date", Tag{}); + // auto end_date = obj.expect_element("end_date", Tag{}); + // auto test_params = obj.expect_element("test_params", Tag{}); + // auto probability = obj.expect_element("probability", Tag{}); + // auto is_active = obj.expect_element("is_active", Tag{}); + // return apply( + // io, + // [](auto&& criteria_, auto&& min_time_since_last_test_, auto&& start_date_, auto&& end_date_, + // auto&& test_params_, auto&& probability_, auto&& is_active_) { + // return TestingScheme{ + // criteria_, min_time_since_last_test_, start_date_, end_date_, test_params_, probability_, + // is_active_}; + // }, + // criteria, min_time_since_last_test, start_date, end_date, test_params, probability, is_active); + // } + private: + friend AutoSerializableFactory; + TestingScheme() = default; + TestingCriteria m_testing_criteria; ///< TestingCriteria of the scheme. TimeSpan m_minimal_time_since_last_test; ///< Shortest period of time between two tests. TimePoint m_start_date; ///< Starting date of the scheme. diff --git a/cpp/tests/test_abm_serialization.cpp b/cpp/tests/test_abm_serialization.cpp index 416ea6d70e..667ba293ee 100644 --- a/cpp/tests/test_abm_serialization.cpp +++ b/cpp/tests/test_abm_serialization.cpp @@ -1,7 +1,9 @@ +#include "abm/infection_state.h" #include "abm/parameters.h" #include "abm/testing_strategy.h" #include "abm/vaccine.h" #include "matchers.h" +#include "memilio/epidemiology/age_group.h" #include "memilio/io/json_serializer.h" #include "memilio/utils/uncertain_value.h" #include "models/abm/location.h" @@ -48,7 +50,7 @@ void test_json_serialization_by_representation(const Json::Value& reference_json // check that the resulting type T is serializable auto json_result = mio::serialize_json(t_result.value()); - ASSERT_TRUE(json_result); + ASSERT_THAT(print_wrap(json_result), IsSuccess()); test_equal_json_representation(json_result.value(), reference_json); } @@ -70,7 +72,7 @@ void test_json_serialization_full(const T& reference_object, const Json::Value& { // check that the reference type T is serializable auto json_result = mio::serialize_json(reference_object); - ASSERT_TRUE(json_result); + ASSERT_THAT(print_wrap(json_result), IsSuccess()); // check that the reference json is deserializable auto t_result = mio::deserialize_json(reference_json, mio::Tag()); @@ -82,7 +84,7 @@ void test_json_serialization_full(const T& reference_object, const Json::Value& // do the same once more using the results from above auto json_result_2 = mio::serialize_json(t_result.value()); - ASSERT_TRUE(json_result_2); + ASSERT_THAT(print_wrap(json_result_2), IsSuccess()); auto t_result_2 = mio::deserialize_json(json_result.value(), mio::Tag()); ASSERT_THAT(print_wrap(t_result_2), IsSuccess()); @@ -148,21 +150,25 @@ TEST(TestAbmSerialization, TestingScheme) // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. // See test_json_serialization_by_representation for more detail. - mio::abm::TestingScheme testing_scheme(mio::abm::TestingCriteria({}, {}), mio::abm::TimeSpan(1), - mio::abm::TimePoint(2), mio::abm::TimePoint(3), - mio::abm::TestParameters{{4.0}, {5.0}}, 6.0); + mio::abm::TestingScheme testing_scheme(mio::abm::TestingCriteria({mio::AgeGroup(1)}, {mio::abm::InfectionState(2)}), + mio::abm::TimeSpan(3), mio::abm::TimePoint(4), mio::abm::TimePoint(5), + mio::abm::TestParameters{{6.0}, {7.0}}, 8.0); + + Json::Value testing_criteria; + testing_criteria["ages"] = Json::UInt(1 << 1); + testing_criteria["infection_states"] = Json::UInt(1 << 2); Json::Value test_parameters; - test_parameters["sensitivity"] = mio::serialize_json(mio::UncertainValue{4.0}).value(); - test_parameters["specitivity"] = mio::serialize_json(mio::UncertainValue{5.0}).value(); + test_parameters["sensitivity"] = mio::serialize_json(mio::UncertainValue{6.0}).value(); + test_parameters["specificity"] = mio::serialize_json(mio::UncertainValue{7.0}).value(); Json::Value reference_json; // aka x - reference_json["criteria"] = Json::Value(Json::arrayValue); - reference_json["min_time_since_last_test"]["seconds"] = Json::UInt(1); - reference_json["start_date"]["seconds"] = Json::UInt(2); - reference_json["end_date"]["seconds"] = Json::UInt(3); + reference_json["criteria"] = testing_criteria; + reference_json["min_time_since_last_test"]["seconds"] = Json::UInt(3); + reference_json["start_date"]["seconds"] = Json::UInt(4); + reference_json["end_date"]["seconds"] = Json::UInt(5); reference_json["test_params"] = test_parameters; - reference_json["probability"] = Json::Value((double)6); + reference_json["probability"] = Json::Value((double)8); reference_json["is_active"] = Json::Value((bool)0); test_json_serialization_full(testing_scheme, reference_json); @@ -183,7 +189,7 @@ TEST(TestAbmSerialization, TestingStrategy) Json::Value reference_json; // aka x reference_json["schemes"][0] = local_strategy; - test_json_serialization_by_representation(reference_json); + test_json_serialization_by_representation(reference_json); } TEST(TestAbmSerialization, Person) @@ -234,15 +240,16 @@ TEST(TestAbmSerialization, Location) unsigned i = 1; // counter s.t. members have different values + Json::Value contact_rates = mio::serialize_json(mio::abm::ContactRates::get_default(i++)).value(); + Json::Value reference_json; // aka x - reference_json["cells"][0]["capacity"]["persons"] = Json::UInt(i++); - reference_json["cells"][0]["capacity"]["volume"] = Json::UInt(i++); - reference_json["geographical_location"]["latitude"] = Json::Value((double)i++); - reference_json["geographical_location"]["longitude"] = Json::Value((double)i++); - reference_json["id"] = Json::UInt(i++); - reference_json["npi_active"] = Json::Value(false); - reference_json["parameters"]["ContactRates"] = - mio::serialize_json(mio::abm::ContactRates::get_default(i++)).value(); + reference_json["cells"][0]["capacity"]["persons"] = Json::UInt(i++); + reference_json["cells"][0]["capacity"]["volume"] = Json::UInt(i++); + reference_json["geographical_location"]["latitude"] = Json::Value((double)i++); + reference_json["geographical_location"]["longitude"] = Json::Value((double)i++); + reference_json["id"] = Json::UInt(i++); + reference_json["npi_active"] = Json::Value(false); + reference_json["parameters"]["ContactRates"] = contact_rates; reference_json["parameters"]["MaximumContacts"] = Json::Value((double)i++); reference_json["parameters"]["UseLocationCapacityForTransmissions"] = Json::Value(false); reference_json["required_mask"] = Json::UInt(0); @@ -261,11 +268,13 @@ TEST(TestAbmSerialization, World) unsigned i = 1; // counter s.t. members have different values + Json::Value abm_parameters = mio::serialize_json(mio::abm::Parameters(i++)).value(); + Json::Value reference_json; // aka x reference_json["cemetery_id"] = Json::UInt(i++); reference_json["location_types"] = Json::UInt(i++); reference_json["locations"] = Json::Value(Json::arrayValue); - reference_json["parameters"] = mio::serialize_json(mio::abm::Parameters(i++)).value(); + reference_json["parameters"] = abm_parameters; reference_json["persons"] = Json::Value(Json::arrayValue); reference_json["rng"]["counter"] = Json::UInt(i++); reference_json["rng"]["key"] = Json::UInt(i++); From a1832ca457a99aac14c70ae8b9a93911988b09da Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 23 Jul 2024 14:43:40 +0200 Subject: [PATCH 07/42] add TDPF tests --- .../math/time_dependent_parameter_functor.h | 12 +-- cpp/tests/CMakeLists.txt | 1 + cpp/tests/test_math_tdpf.cpp | 76 +++++++++++++++++++ 3 files changed, 84 insertions(+), 5 deletions(-) create mode 100644 cpp/tests/test_math_tdpf.cpp diff --git a/cpp/memilio/math/time_dependent_parameter_functor.h b/cpp/memilio/math/time_dependent_parameter_functor.h index 3b442df8fa..1c4afee1a7 100644 --- a/cpp/memilio/math/time_dependent_parameter_functor.h +++ b/cpp/memilio/math/time_dependent_parameter_functor.h @@ -40,7 +40,7 @@ class TimeDependentParameterFunctor }; using DataType = std::vector>; - TimeDependentParameterFunctor(Type type, DataType data) + TimeDependentParameterFunctor(Type type, const DataType& data) : m_type(type) , m_data(data) { @@ -69,9 +69,11 @@ class TimeDependentParameterFunctor ScalarType operator()(ScalarType time) const { + ScalarType value = 0.0; switch (m_type) { case Type::Zero: - return 0.0; + // value is explicitly zero-initialized + break; case Type::LinearInterpolation: // find next time point in m_data (strictly) after time const auto next_tp = std::upper_bound(m_data.begin(), m_data.end(), time, [](auto&& t, auto&& tp) { @@ -84,10 +86,10 @@ class TimeDependentParameterFunctor return m_data.back()[1]; } const auto tp = next_tp - 1; - return linear_interpolation(time, (*tp)[0], (*next_tp)[0], (*tp)[1], (*next_tp)[1]); + value = linear_interpolation(time, (*tp)[0], (*next_tp)[0], (*tp)[1], (*next_tp)[1]); + break; } - - return 0.0; // should be unreachable, but without this the compiler may complain about a missing return. + return value; } /// This method is used by the auto-serialization feature. diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 9fe42496ea..b2a53a202a 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -40,6 +40,7 @@ set(TESTSOURCES test_abm_testing_strategy.cpp test_abm_world.cpp test_math_floating_point.cpp + test_math_tdpf.cpp test_analyze_result.cpp test_contact_matrix.cpp test_type_safe.cpp diff --git a/cpp/tests/test_math_tdpf.cpp b/cpp/tests/test_math_tdpf.cpp new file mode 100644 index 0000000000..953d07fa45 --- /dev/null +++ b/cpp/tests/test_math_tdpf.cpp @@ -0,0 +1,76 @@ +#include "memilio/math/time_dependent_parameter_functor.h" +#include "memilio/utils/random_number_generator.h" + +#include + +#include +#include + +class TestMathTdpf : public ::testing::Test +{ +public: + const int num_evals = 1000; + + double fuzzy_val(double min, double max) + { + return mio::UniformDistribution::get_instance()(m_rng, min, max); + } + +protected: + void SetUp() override + { + log_rng_seeds(m_rng, mio::LogLevel::warn); + } + +private: + mio::RandomNumberGenerator m_rng{}; +}; + +TEST_F(TestMathTdpf, zero) +{ + // Test that the Zero-TDPF always returns zero, using a random evaluation point. + + // initialize + mio::TimeDependentParameterFunctor tdpf; + + // verify output + for (int i = 0; i < this->num_evals; i++) { + auto random_t_eval = fuzzy_val(-std::numeric_limits::max(), std::numeric_limits::max()); + EXPECT_EQ(tdpf(random_t_eval), 0.0); + } +} + +TEST_F(TestMathTdpf, linearInterpolation) +{ + // Test that the LinearInterpolation-TDPF correctly reproduces a (piecewise) linear function, using random samples. + // Since the initialization uses unsorted data, this also checks that the data gets sorted + + const double min = -1e+3, max = 1e+3; // reasonably large values for lin_fct height and slopes + const double t_min = -1, t_max = 1, t_mid = fuzzy_val(t_min, t_max); + const double slope1 = fuzzy_val(min, max), slope2 = fuzzy_val(min, max), height = fuzzy_val(min, max); + + const auto pcw_lin_fct = [&](double t) { + // continuous function with different slopes between t_min, t_mid and t_max, constant otherwise + return height + slope1 * std::clamp(t - t_min, 0.0, t_mid - t_min) + + slope2 * std::clamp(t - t_mid, 0.0, t_max - t_mid); + }; + + // initialize the data with the critical points + std::vector> unsorted_data{ + {t_max, pcw_lin_fct(t_max)}, {t_min, pcw_lin_fct(t_min)}, {t_mid, pcw_lin_fct(t_mid)}}; + // randomly add a few more evaluations in between + for (int i = 0; i < 10; i++) { + const double t = fuzzy_val(-1.0, 1.0); + unsorted_data.push_back({t, pcw_lin_fct(t)}); + } + + mio::TimeDependentParameterFunctor tdpf(mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + unsorted_data); + + // verify output + for (int i = 0; i < this->num_evals; i++) { + // sample in the interval [t_min - (t_max - t_min) / 4, t_max + (t_max - tmin) / 4] + double random_t_eval = fuzzy_val(1.25 * t_min - 0.25 * t_max, 1.25 * t_max - 0.25 * t_min); + EXPECT_NEAR(tdpf(random_t_eval), pcw_lin_fct(random_t_eval), 1e-10); + } +} From ed0ec32ebc53b51de12a47ce5097a5fc3f140c1d Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 23 Jul 2024 14:56:25 +0200 Subject: [PATCH 08/42] fix logging level accross tests --- cpp/tests/test_abm_infection.cpp | 3 -- cpp/tests/test_abm_location.cpp | 2 - cpp/tests/test_abm_person.cpp | 2 - cpp/tests/test_abm_world.cpp | 79 ++++++++++++++++---------------- cpp/tests/test_flows.cpp | 14 +++--- cpp/tests/test_odesecir.cpp | 1 + 6 files changed, 49 insertions(+), 52 deletions(-) diff --git a/cpp/tests/test_abm_infection.cpp b/cpp/tests/test_abm_infection.cpp index 39a050bd9b..1e060afee7 100644 --- a/cpp/tests/test_abm_infection.cpp +++ b/cpp/tests/test_abm_infection.cpp @@ -21,7 +21,6 @@ #include "abm/location_type.h" #include "abm/person.h" #include "abm_helpers.h" -#include "memilio/math/interpolation.h" #include "memilio/utils/random_number_generator.h" #include "abm_helpers.h" @@ -190,14 +189,12 @@ TEST(TestInfection, getPersonalProtectiveFactor) EXPECT_NEAR(defaut_severity_protection, 0, eps); // Test linear interpolation with one node - // mio::set_log_level(mio::LogLevel::critical); //this throws an error either way params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = mio::TimeDependentParameterFunctor{mio::TimeDependentParameterFunctor::Type::LinearInterpolation, {{2, 0.91}}}; auto t = mio::abm::TimePoint(6 * 24 * 60 * 60); // TODO: Discuss: Assumption of interpolation in TDPF is that the function is constant with value at front/back entry outside of [front, back] time range. This works with one node as well and prints no errors EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.91, eps); - // mio::set_log_level(mio::LogLevel::warn); //this throws an error either way params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = mio::TimeDependentParameterFunctor{mio::TimeDependentParameterFunctor::Type::LinearInterpolation, diff --git a/cpp/tests/test_abm_location.cpp b/cpp/tests/test_abm_location.cpp index fbba080a73..2e95746ca9 100644 --- a/cpp/tests/test_abm_location.cpp +++ b/cpp/tests/test_abm_location.cpp @@ -23,8 +23,6 @@ #include "abm/person.h" #include "abm/world.h" #include "abm_helpers.h" -#include "matchers.h" -#include "memilio/io/json_serializer.h" #include "memilio/utils/random_number_generator.h" TEST(TestLocation, initCell) diff --git a/cpp/tests/test_abm_person.cpp b/cpp/tests/test_abm_person.cpp index d80a30aafd..625c3080fb 100644 --- a/cpp/tests/test_abm_person.cpp +++ b/cpp/tests/test_abm_person.cpp @@ -24,8 +24,6 @@ #include "abm/person.h" #include "abm/time.h" #include "abm_helpers.h" -#include "matchers.h" -#include "memilio/io/json_serializer.h" #include "memilio/utils/random_number_generator.h" #include diff --git a/cpp/tests/test_abm_world.cpp b/cpp/tests/test_abm_world.cpp index 79eb88e0ce..ce7c7a2b08 100644 --- a/cpp/tests/test_abm_world.cpp +++ b/cpp/tests/test_abm_world.cpp @@ -39,8 +39,8 @@ TEST(TestWorld, addLocation) auto work_id = world.add_location(mio::abm::LocationType::Work); auto home_id = world.add_location(mio::abm::LocationType::Home); - ASSERT_EQ(school_id1.get(), 1u); - ASSERT_EQ(school_id2.get(), 2u); + EXPECT_EQ(school_id1.get(), 1u); + EXPECT_EQ(school_id2.get(), 2u); auto& school1 = world.get_location(school_id1); auto& school2 = world.get_location(school_id2); @@ -53,12 +53,12 @@ TEST(TestWorld, addLocation) count_schools++; } } - ASSERT_EQ(count_schools, 2); + EXPECT_EQ(count_schools, 2); - ASSERT_EQ(world.get_locations()[1], school1); - ASSERT_EQ(world.get_locations()[2], school2); - ASSERT_EQ(world.get_locations()[3], work); - ASSERT_EQ(world.get_locations()[4], home); + EXPECT_EQ(world.get_locations()[1], school1); + EXPECT_EQ(world.get_locations()[2], school2); + EXPECT_EQ(world.get_locations()[3], work); + EXPECT_EQ(world.get_locations()[4], home); } TEST(TestWorld, addPerson) @@ -69,9 +69,9 @@ TEST(TestWorld, addPerson) world.add_person(location, age_group_15_to_34); world.add_person(location, age_group_35_to_59); - ASSERT_EQ(world.get_persons().size(), 2); - ASSERT_EQ(world.get_person(0).get_age(), age_group_15_to_34); - ASSERT_EQ(world.get_person(1).get_age(), age_group_35_to_59); + EXPECT_EQ(world.get_persons().size(), 2); + EXPECT_EQ(world.get_person(0).get_age(), age_group_15_to_34); + EXPECT_EQ(world.get_person(1).get_age(), age_group_35_to_59); } TEST(TestWorld, getSubpopulationCombined) @@ -89,13 +89,13 @@ TEST(TestWorld, getSubpopulationCombined) add_test_person(world, school3, age_group_15_to_34, mio::abm::InfectionState::InfectedNoSymptoms); add_test_person(world, home1, age_group_15_to_34, mio::abm::InfectionState::InfectedNoSymptoms); - ASSERT_EQ(world.get_subpopulation_combined_per_location_type(t, mio::abm::InfectionState::Susceptible, + EXPECT_EQ(world.get_subpopulation_combined_per_location_type(t, mio::abm::InfectionState::Susceptible, mio::abm::LocationType::School), 3); - ASSERT_EQ(world.get_subpopulation_combined_per_location_type(t, mio::abm::InfectionState::InfectedNoSymptoms, + EXPECT_EQ(world.get_subpopulation_combined_per_location_type(t, mio::abm::InfectionState::InfectedNoSymptoms, mio::abm::LocationType::School), 2); - ASSERT_EQ(world.get_subpopulation_combined(t, mio::abm::InfectionState::InfectedNoSymptoms), 3); + EXPECT_EQ(world.get_subpopulation_combined(t, mio::abm::InfectionState::InfectedNoSymptoms), 3); } TEST(TestWorld, findLocation) @@ -448,13 +448,13 @@ TEST(TestWorldTestingCriteria, testAddingAndUpdatingAndRunningTestingSchemes) const auto start_date = mio::abm::TimePoint(20); const auto end_date = mio::abm::TimePoint(60 * 60 * 24 * 3); const auto probability = 1.0; - const auto test_params_pcr = mio::abm::TestParameters{0.9, 0.99}; + const auto test_params_pcr = mio::abm::TestParameters{0.9, 0.99}; - auto testing_scheme = - mio::abm::TestingScheme(testing_criteria, testing_frequency, start_date, end_date, test_params_pcr, probability); + auto testing_scheme = mio::abm::TestingScheme(testing_criteria, testing_frequency, start_date, end_date, + test_params_pcr, probability); world.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme); - ASSERT_EQ(world.get_testing_strategy().run_strategy(rng_person, person, work, current_time), + EXPECT_EQ(world.get_testing_strategy().run_strategy(rng_person, person, work, current_time), true); // no active testing scheme -> person can enter current_time = mio::abm::TimePoint(30); world.get_testing_strategy().update_activity_status(current_time); @@ -463,12 +463,12 @@ TEST(TestWorldTestingCriteria, testAddingAndUpdatingAndRunningTestingSchemes) .Times(testing::AtLeast(2)) .WillOnce(testing::Return(0.7)) .WillOnce(testing::Return(0.4)); - ASSERT_EQ(world.get_testing_strategy().run_strategy(rng_person, person, work, current_time), false); + EXPECT_EQ(world.get_testing_strategy().run_strategy(rng_person, person, work, current_time), false); world.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme); //doesn't get added because of == operator world.get_testing_strategy().remove_testing_scheme(mio::abm::LocationType::Work, testing_scheme); - ASSERT_EQ(world.get_testing_strategy().run_strategy(rng_person, person, work, current_time), + EXPECT_EQ(world.get_testing_strategy().run_strategy(rng_person, person, work, current_time), true); // no more testing_schemes } @@ -497,65 +497,66 @@ TEST(TestWorld, checkParameterConstraints) params.get()[mio::abm::MaskType::FFP2] = 0.6; params.get()[mio::abm::MaskType::Surgical] = 0.7; params.get() = mio::abm::TimePoint(0); - ASSERT_EQ(params.check_constraints(), false); + EXPECT_EQ(params.check_constraints(), false); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -1.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 1.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -2.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 2.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -3.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 3.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -4.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 4.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -5.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 5.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -6.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 6.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -7.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 7.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -8.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 8.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -9.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 9.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -10.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 10.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 1.1; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.3; params.get()[age_group_35_to_59] = mio::abm::hours(30); - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[age_group_35_to_59] = mio::abm::hours(4); params.get()[age_group_35_to_59] = mio::abm::hours(30); - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[age_group_35_to_59] = mio::abm::hours(8); params.get()[age_group_0_to_4] = mio::abm::hours(30); - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[age_group_0_to_4] = mio::abm::hours(3); params.get()[age_group_0_to_4] = mio::abm::hours(30); - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[age_group_0_to_4] = mio::abm::hours(6); params.get()[mio::abm::MaskType::Community] = 1.2; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[mio::abm::MaskType::Community] = 0.5; params.get()[mio::abm::MaskType::FFP2] = 1.2; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[mio::abm::MaskType::FFP2] = 0.6; params.get()[mio::abm::MaskType::Surgical] = 1.2; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[mio::abm::MaskType::Surgical] = 0.7; params.get() = mio::abm::TimePoint(-2); - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); + mio::set_log_level(mio::LogLevel::warn); } diff --git a/cpp/tests/test_flows.cpp b/cpp/tests/test_flows.cpp index ca74b1db61..e86df429c4 100644 --- a/cpp/tests/test_flows.cpp +++ b/cpp/tests/test_flows.cpp @@ -96,8 +96,6 @@ TEST(TestFlows, FlowChart) TEST(TestFlows, FlowSimulation) { - mio::set_log_level(mio::LogLevel::off); - double t0 = 0; double tmax = 1; double dt = 0.001; @@ -119,9 +117,12 @@ TEST(TestFlows, FlowSimulation) model.parameters.set>(0.04); model.parameters.get>().get_cont_freq_mat()[0].get_baseline().setConstant(10); + mio::set_log_level(mio::LogLevel::off); model.check_constraints(); auto IC = std::make_shared>(); auto seir = mio::simulate_flows>(t0, tmax, dt, model, IC); + mio::set_log_level(mio::LogLevel::warn); + // verify results (computed using flows) auto results = seir[0].get_last_value(); EXPECT_NEAR(results[0], 9660.5835936179408, 1e-14); @@ -137,8 +138,6 @@ TEST(TestFlows, FlowSimulation) TEST(TestFlows, CompareSimulations) { - mio::set_log_level(mio::LogLevel::off); - double t0 = 0; double tmax = 1; double dt = 0.001; @@ -162,11 +161,14 @@ TEST(TestFlows, CompareSimulations) model.parameters.get>().get_cont_freq_mat(); contact_matrix[0].get_baseline().setConstant(10); + mio::set_log_level(mio::LogLevel::off); model.check_constraints(); auto seir_sim_flows = simulate_flows(t0, tmax, dt, model); auto seir_sim = simulate(t0, tmax, dt, model); - auto results_flows = seir_sim_flows[0].get_last_value(); - auto results = seir_sim.get_last_value(); + mio::set_log_level(mio::LogLevel::warn); + + auto results_flows = seir_sim_flows[0].get_last_value(); + auto results = seir_sim.get_last_value(); EXPECT_NEAR(results[0], results_flows[0], 1e-10); EXPECT_NEAR(results[1], results_flows[1], 1e-10); diff --git a/cpp/tests/test_odesecir.cpp b/cpp/tests/test_odesecir.cpp index ef10f5b7ca..341296f289 100644 --- a/cpp/tests/test_odesecir.cpp +++ b/cpp/tests/test_odesecir.cpp @@ -728,6 +728,7 @@ TEST(TestOdeSecir, testModelConstraints) EXPECT_LE(secihurd.get_value(i)[5], 9000) << " at row " << i; } } + mio::set_log_level(mio::LogLevel::warn); } TEST(Secir, testAndTraceCapacity) From 9ee8ae4a10dbfbbfd09decab1101e192233655cc Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 23 Jul 2024 15:53:07 +0200 Subject: [PATCH 09/42] improve TDPF implementation --- cpp/memilio/math/time_dependent_parameter_functor.h | 12 +++++++----- cpp/tests/test_math_tdpf.cpp | 11 ++++++----- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/cpp/memilio/math/time_dependent_parameter_functor.h b/cpp/memilio/math/time_dependent_parameter_functor.h index 1c4afee1a7..675207cad9 100644 --- a/cpp/memilio/math/time_dependent_parameter_functor.h +++ b/cpp/memilio/math/time_dependent_parameter_functor.h @@ -80,13 +80,15 @@ class TimeDependentParameterFunctor return t < tp[0]; }); if (next_tp == m_data.begin()) { // time is before first data point - return m_data.front()[1]; + value = m_data.front()[1]; } - if (next_tp == m_data.end()) { // time is past last data point - return m_data.back()[1]; + else if (next_tp == m_data.end()) { // time is past last data point + value = m_data.back()[1]; + } + else { // time is in between data points + const auto tp = next_tp - 1; + value = linear_interpolation(time, (*tp)[0], (*next_tp)[0], (*tp)[1], (*next_tp)[1]); } - const auto tp = next_tp - 1; - value = linear_interpolation(time, (*tp)[0], (*next_tp)[0], (*tp)[1], (*next_tp)[1]); break; } return value; diff --git a/cpp/tests/test_math_tdpf.cpp b/cpp/tests/test_math_tdpf.cpp index 953d07fa45..3399c80f7a 100644 --- a/cpp/tests/test_math_tdpf.cpp +++ b/cpp/tests/test_math_tdpf.cpp @@ -35,7 +35,7 @@ TEST_F(TestMathTdpf, zero) // verify output for (int i = 0; i < this->num_evals; i++) { - auto random_t_eval = fuzzy_val(-std::numeric_limits::max(), std::numeric_limits::max()); + auto random_t_eval = this->fuzzy_val(-std::numeric_limits::max(), std::numeric_limits::max()); EXPECT_EQ(tdpf(random_t_eval), 0.0); } } @@ -46,8 +46,9 @@ TEST_F(TestMathTdpf, linearInterpolation) // Since the initialization uses unsorted data, this also checks that the data gets sorted const double min = -1e+3, max = 1e+3; // reasonably large values for lin_fct height and slopes - const double t_min = -1, t_max = 1, t_mid = fuzzy_val(t_min, t_max); - const double slope1 = fuzzy_val(min, max), slope2 = fuzzy_val(min, max), height = fuzzy_val(min, max); + const double t_min = -1, t_max = 1, t_mid = this->fuzzy_val(t_min, t_max); + const double slope1 = this->fuzzy_val(min, max), slope2 = this->fuzzy_val(min, max), + height = this->fuzzy_val(min, max); const auto pcw_lin_fct = [&](double t) { // continuous function with different slopes between t_min, t_mid and t_max, constant otherwise @@ -60,7 +61,7 @@ TEST_F(TestMathTdpf, linearInterpolation) {t_max, pcw_lin_fct(t_max)}, {t_min, pcw_lin_fct(t_min)}, {t_mid, pcw_lin_fct(t_mid)}}; // randomly add a few more evaluations in between for (int i = 0; i < 10; i++) { - const double t = fuzzy_val(-1.0, 1.0); + const double t = this->fuzzy_val(-1.0, 1.0); unsorted_data.push_back({t, pcw_lin_fct(t)}); } @@ -70,7 +71,7 @@ TEST_F(TestMathTdpf, linearInterpolation) // verify output for (int i = 0; i < this->num_evals; i++) { // sample in the interval [t_min - (t_max - t_min) / 4, t_max + (t_max - tmin) / 4] - double random_t_eval = fuzzy_val(1.25 * t_min - 0.25 * t_max, 1.25 * t_max - 0.25 * t_min); + double random_t_eval = this->fuzzy_val(1.25 * t_min - 0.25 * t_max, 1.25 * t_max - 0.25 * t_min); EXPECT_NEAR(tdpf(random_t_eval), pcw_lin_fct(random_t_eval), 1e-10); } } From dd88980877cce5cc784389848c5d286302d9543e Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 23 Jul 2024 16:26:51 +0200 Subject: [PATCH 10/42] debug CI --- cpp/tests/test_math_tdpf.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/cpp/tests/test_math_tdpf.cpp b/cpp/tests/test_math_tdpf.cpp index 3399c80f7a..96b34287b0 100644 --- a/cpp/tests/test_math_tdpf.cpp +++ b/cpp/tests/test_math_tdpf.cpp @@ -1,4 +1,5 @@ #include "memilio/math/time_dependent_parameter_functor.h" +#include "memilio/utils/logging.h" #include "memilio/utils/random_number_generator.h" #include @@ -29,15 +30,19 @@ class TestMathTdpf : public ::testing::Test TEST_F(TestMathTdpf, zero) { // Test that the Zero-TDPF always returns zero, using a random evaluation point. - + mio::set_log_level(mio::LogLevel::trace); + mio::log_info("entered test"); // initialize mio::TimeDependentParameterFunctor tdpf; + mio::log_info("created tdpf"); // verify output for (int i = 0; i < this->num_evals; i++) { - auto random_t_eval = this->fuzzy_val(-std::numeric_limits::max(), std::numeric_limits::max()); + double random_t_eval = this->fuzzy_val(-std::numeric_limits::max(), std::numeric_limits::max()); + mio::log_debug("testing {}", random_t_eval); EXPECT_EQ(tdpf(random_t_eval), 0.0); } + mio::log_info("finished test"); } TEST_F(TestMathTdpf, linearInterpolation) From 7310b2edebf8cc7448523467c199e3c6e768a8a9 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:02:32 +0200 Subject: [PATCH 11/42] CI --- .github/workflows/main.yml | 6 +++--- cpp/tests/test_math_tdpf.cpp | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index de1e66fc3c..b1889079a6 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: - uses: pre-commit/action@v3.0.1 build-cpp-gcc_clang: - if: github.event.pull_request.draft == false + if: false strategy: matrix: compiler: ["gcc", "clang"] @@ -51,7 +51,7 @@ jobs: sanitizers: ${{ (matrix.compiler == 'gcc' && matrix.config == 'Debug' && matrix.version == 'latest') && 'ON' || 'OFF' }} build-cpp-gcc-no-optional-deps: - if: github.event.pull_request.draft == false + if: false runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -63,7 +63,7 @@ jobs: optional-dependencies: OFF build-cpp-gcc-openmp: - if: github.event.pull_request.draft == false + if: false runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/cpp/tests/test_math_tdpf.cpp b/cpp/tests/test_math_tdpf.cpp index 96b34287b0..f882061435 100644 --- a/cpp/tests/test_math_tdpf.cpp +++ b/cpp/tests/test_math_tdpf.cpp @@ -37,8 +37,8 @@ TEST_F(TestMathTdpf, zero) mio::log_info("created tdpf"); // verify output - for (int i = 0; i < this->num_evals; i++) { - double random_t_eval = this->fuzzy_val(-std::numeric_limits::max(), std::numeric_limits::max()); + for (int i = 0; i < 10; i++) { + double random_t_eval = this->fuzzy_val(-1e+5, 1e+5); mio::log_debug("testing {}", random_t_eval); EXPECT_EQ(tdpf(random_t_eval), 0.0); } From 6e99e0811ff722dfae4be79cc39a320037011448 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:34:02 +0200 Subject: [PATCH 12/42] Revert "CI" This reverts commit 7310b2edebf8cc7448523467c199e3c6e768a8a9. --- .github/workflows/main.yml | 6 +++--- cpp/tests/test_math_tdpf.cpp | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b1889079a6..de1e66fc3c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: - uses: pre-commit/action@v3.0.1 build-cpp-gcc_clang: - if: false + if: github.event.pull_request.draft == false strategy: matrix: compiler: ["gcc", "clang"] @@ -51,7 +51,7 @@ jobs: sanitizers: ${{ (matrix.compiler == 'gcc' && matrix.config == 'Debug' && matrix.version == 'latest') && 'ON' || 'OFF' }} build-cpp-gcc-no-optional-deps: - if: false + if: github.event.pull_request.draft == false runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -63,7 +63,7 @@ jobs: optional-dependencies: OFF build-cpp-gcc-openmp: - if: false + if: github.event.pull_request.draft == false runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/cpp/tests/test_math_tdpf.cpp b/cpp/tests/test_math_tdpf.cpp index f882061435..96b34287b0 100644 --- a/cpp/tests/test_math_tdpf.cpp +++ b/cpp/tests/test_math_tdpf.cpp @@ -37,8 +37,8 @@ TEST_F(TestMathTdpf, zero) mio::log_info("created tdpf"); // verify output - for (int i = 0; i < 10; i++) { - double random_t_eval = this->fuzzy_val(-1e+5, 1e+5); + for (int i = 0; i < this->num_evals; i++) { + double random_t_eval = this->fuzzy_val(-std::numeric_limits::max(), std::numeric_limits::max()); mio::log_debug("testing {}", random_t_eval); EXPECT_EQ(tdpf(random_t_eval), 0.0); } From e669dde346a62cb9ecc5075fb536f683f14d53bc Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:34:25 +0200 Subject: [PATCH 13/42] Revert "debug CI" This reverts commit dd88980877cce5cc784389848c5d286302d9543e. --- cpp/tests/test_math_tdpf.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/cpp/tests/test_math_tdpf.cpp b/cpp/tests/test_math_tdpf.cpp index 96b34287b0..3399c80f7a 100644 --- a/cpp/tests/test_math_tdpf.cpp +++ b/cpp/tests/test_math_tdpf.cpp @@ -1,5 +1,4 @@ #include "memilio/math/time_dependent_parameter_functor.h" -#include "memilio/utils/logging.h" #include "memilio/utils/random_number_generator.h" #include @@ -30,19 +29,15 @@ class TestMathTdpf : public ::testing::Test TEST_F(TestMathTdpf, zero) { // Test that the Zero-TDPF always returns zero, using a random evaluation point. - mio::set_log_level(mio::LogLevel::trace); - mio::log_info("entered test"); + // initialize mio::TimeDependentParameterFunctor tdpf; - mio::log_info("created tdpf"); // verify output for (int i = 0; i < this->num_evals; i++) { - double random_t_eval = this->fuzzy_val(-std::numeric_limits::max(), std::numeric_limits::max()); - mio::log_debug("testing {}", random_t_eval); + auto random_t_eval = this->fuzzy_val(-std::numeric_limits::max(), std::numeric_limits::max()); EXPECT_EQ(tdpf(random_t_eval), 0.0); } - mio::log_info("finished test"); } TEST_F(TestMathTdpf, linearInterpolation) From ec78f5893cfbe5ccb9f337e427edaf8f80c41d02 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:37:12 +0200 Subject: [PATCH 14/42] limit TDPF Zero evaluation to avoid potential msvc bug --- cpp/tests/test_math_tdpf.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/tests/test_math_tdpf.cpp b/cpp/tests/test_math_tdpf.cpp index 3399c80f7a..25b5201edb 100644 --- a/cpp/tests/test_math_tdpf.cpp +++ b/cpp/tests/test_math_tdpf.cpp @@ -3,7 +3,6 @@ #include -#include #include class TestMathTdpf : public ::testing::Test @@ -35,7 +34,7 @@ TEST_F(TestMathTdpf, zero) // verify output for (int i = 0; i < this->num_evals; i++) { - auto random_t_eval = this->fuzzy_val(-std::numeric_limits::max(), std::numeric_limits::max()); + auto random_t_eval = this->fuzzy_val(-1e+3, 1e+3); EXPECT_EQ(tdpf(random_t_eval), 0.0); } } From 4d1d92405f144b15fe460b95594af1210064a5ef Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Wed, 24 Jul 2024 09:58:01 +0200 Subject: [PATCH 15/42] try out a codecov arg to fix some lines incorrectly marked as uncovered --- .github/workflows/main.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index de1e66fc3c..17c4680285 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -347,6 +347,7 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} directory: ./coverage_python files: ./coverage_python/**,./coverage.info + functionalities: fixes verbose: true pages: From 3dd06e64bb7b99177aacbe2246da485a60cb533c Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Wed, 24 Jul 2024 10:39:19 +0200 Subject: [PATCH 16/42] remove codecov arg, instead disable some gcc optimisations --- .github/workflows/main.yml | 1 - cpp/CMakeLists.txt | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 17c4680285..de1e66fc3c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -347,7 +347,6 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} directory: ./coverage_python files: ./coverage_python/**,./coverage.info - functionalities: fixes verbose: true pages: diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index b4e3595e82..ec32aa8686 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -76,6 +76,8 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug" OR CMAKE_BUILD_TYPE STREQUAL "DEBUG") message(STATUS "Coverage enabled") include(CodeCoverage) append_coverage_compiler_flags() + # also disable elision and inlining to prevent e.g. closing brackets being marked as uncovered + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-elide-constructors -fno-default-inline") setup_target_for_coverage_lcov( NAME coverage EXECUTABLE memilio-test From 07d863b57a41a26a972174b934051cc940d4a175 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Thu, 1 Aug 2024 14:49:01 +0200 Subject: [PATCH 17/42] move NVPs by value, use add/expect_list --- cpp/memilio/io/auto_serialize.h | 60 ++++++++++++++++++++++----------- 1 file changed, 41 insertions(+), 19 deletions(-) diff --git a/cpp/memilio/io/auto_serialize.h b/cpp/memilio/io/auto_serialize.h index aa056320d3..949e2dc1a6 100644 --- a/cpp/memilio/io/auto_serialize.h +++ b/cpp/memilio/io/auto_serialize.h @@ -56,7 +56,7 @@ struct NVP { Type value; NVP() = delete; - NVP(const NVP&) = delete; + NVP(const NVP&) = default; NVP(NVP&&) = default; NVP& operator=(const NVP&) = delete; NVP& operator=(NVP&&) = delete; @@ -77,9 +77,9 @@ struct NVP { * @return Collection of all name views and value references used for auto-(de)serialization. */ template -[[nodiscard]] inline auto make_auto_serialization(const std::string_view class_name, NVP&&... class_members) +[[nodiscard]] inline auto make_auto_serialization(const std::string_view&& class_name, NVP&&... class_members) { - return std::make_pair(class_name, std::make_tuple(std::move(class_members)...)); + return std::make_pair(class_name, std::make_tuple(class_members...)); } /** @@ -112,33 +112,54 @@ namespace details template using auto_serialize_expr_t = decltype(std::declval().auto_serialize()); -template -void add_nvp(IOObject& obj, NVP const&& nvp) +template +using serialize_internal_expr_t = decltype(serialize_internal(std::declval(), std::declval())); + +template +void add_nvp(IOObject& obj, const NVP nvp) { - obj.add_element(std::string{nvp.name}, nvp.value); + if constexpr (is_container::value && + !is_expression_valid::value) { + obj.add_list(std::string{nvp.name}, nvp.value.begin(), nvp.value.end()); + } + else { + obj.add_element(std::string{nvp.name}, nvp.value); + } } template -void auto_serialize_impl(IOContext& io, const std::string_view name, std::tuple...> const&& targets) +void auto_serialize_impl(IOContext& io, const std::string_view name, const std::tuple...> targets) { auto obj = io.create_object(std::string{name}); std::apply( - [&obj](NVP const&&... nvps) { - (add_nvp(obj, std::move(nvps)), ...); + [&obj](const NVP... nvps) { + (add_nvp(obj, nvps), ...); }, - std::move(targets)); + targets); } -template -IOResult expect_nvp(IOObject& obj, NVP&& nvp) +template +IOResult expect_nvp(IOObject& obj, const NVP nvp) { - return obj.expect_element(std::string{nvp.name}, Tag{}); + if constexpr (is_container::value && + !is_expression_valid::value) { + return obj.expect_list(std::string{nvp.name}, Tag{}); + } + else { + return obj.expect_element(std::string{nvp.name}, Tag{}); + } } +// template +// void assign_nvp(NVP nvp, const Target& value) +// { +// nvp.value = value; +// } + template IOResult auto_deserialize_impl(IOContext& io, AutoSerializable& a, std::string_view name, - std::tuple...>&& targets) + std::tuple...> targets) { auto obj = io.expect_object(std::string{name}); @@ -146,13 +167,14 @@ IOResult auto_deserialize_impl(IOContext& io, AutoSerializable return apply( io, [&a, &nvps...](const Targets&... values) { + // (assign_nvp(nvps, values), ...); ((nvps.value = values), ...); return a; }, - expect_nvp(obj, std::move(nvps))...); + expect_nvp(obj, nvps)...); }; - return std::apply(unpacked_apply, std::move(targets)); + return std::apply(unpacked_apply, targets); } } // namespace details @@ -173,8 +195,8 @@ template < void serialize_internal(IOContext& io, const AutoSerializable& t) { // Note that this cast is only safe if we do not modify targets. - const auto targets = const_cast(&t)->auto_serialize(); - details::auto_serialize_impl(io, targets.first, std::move(targets.second)); + const auto targets = const_cast(t).auto_serialize(); + details::auto_serialize_impl(io, targets.first, targets.second); } // disables itself if a deserialize member is present or if there is no auto_serialize member @@ -187,7 +209,7 @@ IOResult deserialize_internal(IOContext& io, Tag::create(); auto targets = a.auto_serialize(); - return details::auto_deserialize_impl(io, a, targets.first, std::move(targets.second)); + return details::auto_deserialize_impl(io, a, targets.first, targets.second); } } // namespace mio From d945bf4ddd6ecae2109c41a5334c01693cb81272 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Wed, 7 Aug 2024 11:18:50 +0200 Subject: [PATCH 18/42] Enable auto_serialize to add lists directly. Add and use mio::apply with a tuple of results in auto_deserialize_impl. --- cpp/memilio/io/auto_serialize.h | 93 ++++++++++++++++++++------------- cpp/memilio/io/io.h | 12 ++++- 2 files changed, 69 insertions(+), 36 deletions(-) diff --git a/cpp/memilio/io/auto_serialize.h b/cpp/memilio/io/auto_serialize.h index 949e2dc1a6..114b5b420b 100644 --- a/cpp/memilio/io/auto_serialize.h +++ b/cpp/memilio/io/auto_serialize.h @@ -60,6 +60,12 @@ struct NVP { NVP(NVP&&) = default; NVP& operator=(const NVP&) = delete; NVP& operator=(NVP&&) = delete; + + NVP& operator=(const ValueType& v) + { + this->value = v; + return *this; + } }; /** @@ -112,14 +118,24 @@ namespace details template using auto_serialize_expr_t = decltype(std::declval().auto_serialize()); -template -using serialize_internal_expr_t = decltype(serialize_internal(std::declval(), std::declval())); +} // namespace details + +/** + * @brief Detect whether T has a auto_serialize member function. + * @tparam T Any type. + */ +template +using has_auto_serialize = is_expression_valid; + +namespace details +{ +/// add a name-value pair to an io object template void add_nvp(IOObject& obj, const NVP nvp) { if constexpr (is_container::value && - !is_expression_valid::value) { + !(has_serialize::value || has_auto_serialize::value)) { obj.add_list(std::string{nvp.name}, nvp.value.begin(), nvp.value.end()); } else { @@ -127,6 +143,7 @@ void add_nvp(IOObject& obj, const NVP nvp) } } +/// unpack all name-value pairs from the tuple and add them to a new io object with the given name template void auto_serialize_impl(IOContext& io, const std::string_view name, const std::tuple...> targets) { @@ -139,11 +156,12 @@ void auto_serialize_impl(IOContext& io, const std::string_view name, const std:: targets); } +/// retrieve a name-value pair from an io object template IOResult expect_nvp(IOObject& obj, const NVP nvp) { if constexpr (is_container::value && - !is_expression_valid::value) { + !(has_serialize::value || has_auto_serialize::value)) { return obj.expect_list(std::string{nvp.name}, Tag{}); } else { @@ -151,62 +169,67 @@ IOResult expect_nvp(IOObject& obj, const NVP nvp) } } -// template -// void assign_nvp(NVP nvp, const Target& value) -// { -// nvp.value = value; -// } - +/// read an io object and its members from the io context using the given names and assign the values to a template IOResult auto_deserialize_impl(IOContext& io, AutoSerializable& a, std::string_view name, std::tuple...> targets) { auto obj = io.expect_object(std::string{name}); - auto unpacked_apply = [&io, &a, &obj](NVP... nvps) { - return apply( - io, - [&a, &nvps...](const Targets&... values) { - // (assign_nvp(nvps, values), ...); - ((nvps.value = values), ...); - return a; - }, - expect_nvp(obj, nvps)...); - }; - - return std::apply(unpacked_apply, targets); + const auto results = std::apply( + [&obj](NVP... nvps) { + return std::make_tuple(expect_nvp(obj, nvps)...); + }, + targets); + + return apply( + io, + [&a, &targets](const Targets&... values) { + targets = std::make_tuple(values...); + return a; + }, + results); } } // namespace details /** - * @brief Detect whether T has a auto_serialize member function. - * @tparam T Any type. + * @brief Serialization implementation for the auto-serialization feature. + * Disables itself (SFINAE) if there is no auto_serialize member or if a serialize member is present. + * Generates the serialize method depending on the NVPs given by auto_serialize. + * @tparam IOContext A type that models the IOContext concept. + * @tparam AutoSerializable A type that can be auto-serialized. + * @param io An IO context. + * @param a An instance of AutoSerializable to be serialized. */ -template -using has_auto_serialize = is_expression_valid; - -// disables itself if a deserialize member is present or if there is no auto_serialize member -// generates serialize method depending on NVPs given by auto_serialize template < class IOContext, class AutoSerializable, std::enable_if_t::value && !has_serialize::value, AutoSerializable*> = nullptr> -void serialize_internal(IOContext& io, const AutoSerializable& t) +void serialize_internal(IOContext& io, const AutoSerializable& a) { - // Note that this cast is only safe if we do not modify targets. - const auto targets = const_cast(t).auto_serialize(); + // Note that this cons_cast is only safe if we do not modify targets. + const auto targets = const_cast(a).auto_serialize(); details::auto_serialize_impl(io, targets.first, targets.second); } -// disables itself if a deserialize member is present or if there is no auto_serialize member -// generates deserialize method depending on NVPs given by auto_serialize +/** + * @brief Deserialization implementation for the auto-serialization feature. + * Disables itself (SFINAE) if there is no auto_serialize member or if a deserialize member is present. + * Generates the deserialize method depending on the NVPs given by auto_serialize. + * @tparam IOContext A type that models the IOContext concept. + * @tparam AutoSerializable A type that can be auto-serialized. + * @param io An IO context. + * @param tag Defines the type of the object that is to be deserialized (i.e. AutoSerializble). + * @return The restored object if successful, an error otherwise. + */ template ::value && !has_deserialize::value, AutoSerializable*> = nullptr> -IOResult deserialize_internal(IOContext& io, Tag) +IOResult deserialize_internal(IOContext& io, Tag tag) { + mio::unused(tag); AutoSerializable a = AutoSerializableFactory::create(); auto targets = a.auto_serialize(); return details::auto_deserialize_impl(io, a, targets.first, targets.second); diff --git a/cpp/memilio/io/io.h b/cpp/memilio/io/io.h index 7e8e9dbf82..664646aa58 100644 --- a/cpp/memilio/io/io.h +++ b/cpp/memilio/io/io.h @@ -482,7 +482,7 @@ details::ApplyResultT apply(IOContext& io, F f, const IOResult&... r IOStatus status[] = {(rs ? IOStatus{} : rs.error())...}; auto iter_err = std::find_if(std::begin(status), std::end(status), [](auto& s) { return s.is_error(); - }); + }); //evaluate f if all succesful auto result = @@ -497,6 +497,16 @@ details::ApplyResultT apply(IOContext& io, F f, const IOResult&... r return result; } +template +details::ApplyResultT apply(IOContext& io, F f, const std::tuple...>& rs) +{ + return std::apply( + [&io, f](const IOResult&... args) { + return apply(io, f, args...); + }, + rs); +} + //utility for (de-)serializing tuple-like objects namespace details { From 8fbe7dc7b6d7afc7ddb9741fba2195a528ea06d7 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Wed, 7 Aug 2024 15:55:14 +0200 Subject: [PATCH 19/42] Simplify auto_(de)serialize_impl by moving tuple unpacking to (de)serialize_internal --- cpp/memilio/io/auto_serialize.h | 69 +++++++++++++-------------------- cpp/memilio/io/io.h | 10 ----- 2 files changed, 28 insertions(+), 51 deletions(-) diff --git a/cpp/memilio/io/auto_serialize.h b/cpp/memilio/io/auto_serialize.h index 114b5b420b..371157c25c 100644 --- a/cpp/memilio/io/auto_serialize.h +++ b/cpp/memilio/io/auto_serialize.h @@ -130,65 +130,42 @@ using has_auto_serialize = is_expression_valid +/// Add a name-value pair to an io object. +template void add_nvp(IOObject& obj, const NVP nvp) { - if constexpr (is_container::value && - !(has_serialize::value || has_auto_serialize::value)) { - obj.add_list(std::string{nvp.name}, nvp.value.begin(), nvp.value.end()); - } - else { - obj.add_element(std::string{nvp.name}, nvp.value); - } + obj.add_element(std::string{nvp.name}, nvp.value); } -/// unpack all name-value pairs from the tuple and add them to a new io object with the given name +/// Unpack all name-value pairs from the tuple and add them to a new io object with the given name. template -void auto_serialize_impl(IOContext& io, const std::string_view name, const std::tuple...> targets) +void auto_serialize_impl(IOContext& io, const std::string_view name, const NVP... nvps) { auto obj = io.create_object(std::string{name}); - - std::apply( - [&obj](const NVP... nvps) { - (add_nvp(obj, nvps), ...); - }, - targets); + (add_nvp(obj, nvps), ...); } -/// retrieve a name-value pair from an io object -template +/// Retrieve a name-value pair from an io object. +template IOResult expect_nvp(IOObject& obj, const NVP nvp) { - if constexpr (is_container::value && - !(has_serialize::value || has_auto_serialize::value)) { - return obj.expect_list(std::string{nvp.name}, Tag{}); - } - else { - return obj.expect_element(std::string{nvp.name}, Tag{}); - } + return obj.expect_element(std::string{nvp.name}, Tag{}); } -/// read an io object and its members from the io context using the given names and assign the values to a +/// Read an io object and its members from the io context using the given names and assign the values to a. template IOResult auto_deserialize_impl(IOContext& io, AutoSerializable& a, std::string_view name, - std::tuple...> targets) + NVP... nvps) { auto obj = io.expect_object(std::string{name}); - const auto results = std::apply( - [&obj](NVP... nvps) { - return std::make_tuple(expect_nvp(obj, nvps)...); - }, - targets); - return apply( io, - [&a, &targets](const Targets&... values) { - targets = std::make_tuple(values...); + [&a, &nvps...](const Targets&... values) { + ((nvps.value = values), ...); return a; }, - results); + expect_nvp(obj, nvps)...); } } // namespace details @@ -208,14 +185,19 @@ template < AutoSerializable*> = nullptr> void serialize_internal(IOContext& io, const AutoSerializable& a) { - // Note that this cons_cast is only safe if we do not modify targets. + // Note that the following cons_cast is only safe if we do not modify targets. const auto targets = const_cast(a).auto_serialize(); - details::auto_serialize_impl(io, targets.first, targets.second); + // unpack targets and serialize + std::apply( + [&io, &targets](auto... nvps) { + details::auto_serialize_impl(io, targets.first, nvps...); + }, + targets.second); } /** * @brief Deserialization implementation for the auto-serialization feature. - * Disables itself (SFINAE) if there is no auto_serialize member or if a deserialize member is present. + * Disables itself (SFINAE) if there is no auto_serialize member or if a deserialize meember is present. * Generates the deserialize method depending on the NVPs given by auto_serialize. * @tparam IOContext A type that models the IOContext concept. * @tparam AutoSerializable A type that can be auto-serialized. @@ -232,7 +214,12 @@ IOResult deserialize_internal(IOContext& io, Tag::create(); auto targets = a.auto_serialize(); - return details::auto_deserialize_impl(io, a, targets.first, targets.second); + // unpack targets and deserialize + return std::apply( + [&io, &targets, &a](auto... nvps) { + return details::auto_deserialize_impl(io, a, targets.first, nvps...); + }, + targets.second); } } // namespace mio diff --git a/cpp/memilio/io/io.h b/cpp/memilio/io/io.h index 664646aa58..8347030af9 100644 --- a/cpp/memilio/io/io.h +++ b/cpp/memilio/io/io.h @@ -497,16 +497,6 @@ details::ApplyResultT apply(IOContext& io, F f, const IOResult&... r return result; } -template -details::ApplyResultT apply(IOContext& io, F f, const std::tuple...>& rs) -{ - return std::apply( - [&io, f](const IOResult&... args) { - return apply(io, f, args...); - }, - rs); -} - //utility for (de-)serializing tuple-like objects namespace details { From 736cd1be7b54d38f6a6723daa60b05f8cdef9c76 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Wed, 7 Aug 2024 15:58:08 +0200 Subject: [PATCH 20/42] Rework TDPF to use TimeSeries and rename it. --- cpp/memilio/CMakeLists.txt | 2 +- ...ameter_functor.h => time_series_functor.h} | 61 +++++++------- cpp/memilio/utils/time_series.h | 37 ++++++++- cpp/models/abm/parameters.h | 19 ++--- cpp/simulations/abm.cpp | 3 +- cpp/simulations/abm_braunschweig.cpp | 52 ++++++------ cpp/tests/CMakeLists.txt | 2 +- cpp/tests/test_abm_infection.cpp | 18 ++--- cpp/tests/test_math_tdpf.cpp | 76 ------------------ cpp/tests/test_math_time_series_functor.cpp | 80 +++++++++++++++++++ 10 files changed, 194 insertions(+), 156 deletions(-) rename cpp/memilio/math/{time_dependent_parameter_functor.h => time_series_functor.h} (51%) delete mode 100644 cpp/tests/test_math_tdpf.cpp create mode 100644 cpp/tests/test_math_time_series_functor.cpp diff --git a/cpp/memilio/CMakeLists.txt b/cpp/memilio/CMakeLists.txt index aef7abb775..4e121f5d84 100644 --- a/cpp/memilio/CMakeLists.txt +++ b/cpp/memilio/CMakeLists.txt @@ -58,7 +58,7 @@ add_library(memilio math/matrix_shape.cpp math/interpolation.h math/interpolation.cpp - math/time_dependent_parameter_functor.h + math/time_series_functor.h mobility/metapopulation_mobility_instant.h mobility/metapopulation_mobility_instant.cpp mobility/metapopulation_mobility_stochastic.h diff --git a/cpp/memilio/math/time_dependent_parameter_functor.h b/cpp/memilio/math/time_series_functor.h similarity index 51% rename from cpp/memilio/math/time_dependent_parameter_functor.h rename to cpp/memilio/math/time_series_functor.h index 675207cad9..74d745a6dc 100644 --- a/cpp/memilio/math/time_dependent_parameter_functor.h +++ b/cpp/memilio/math/time_series_functor.h @@ -17,30 +17,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MIO_MATH_TIME_DEPENDENT_PARAMETER_FUNCTOR_H -#define MIO_MATH_TIME_DEPENDENT_PARAMETER_FUNCTOR_H +#ifndef MIO_MATH_TIME_SERIES_FUNCTOR_H +#define MIO_MATH_TIME_SERIES_FUNCTOR_H -#include "memilio/config.h" #include "memilio/io/auto_serialize.h" #include "memilio/math/interpolation.h" +#include "memilio/utils/time_series.h" #include +#include #include namespace mio { -class TimeDependentParameterFunctor +template +class TimeSeriesFunctor { public: - enum class Type + enum Type { Zero, LinearInterpolation, }; - using DataType = std::vector>; - TimeDependentParameterFunctor(Type type, const DataType& data) + TimeSeriesFunctor(Type type, const TimeSeries& data) : m_type(type) , m_data(data) { @@ -51,43 +52,45 @@ class TimeDependentParameterFunctor break; case Type::LinearInterpolation: // make sure data has the correct shape, i.e. a list of (time, value) pairs - assert(m_data.size() > 0); - assert(std::all_of(m_data.begin(), m_data.end(), [](auto&& a) { - return a.size() == 2; - })); - // sort by time - std::sort(m_data.begin(), m_data.end(), [](auto&& a, auto&& b) { - return a[0] < b[0]; - }); + assert(m_data.get_num_time_points() > 0 && "Need at least one time point for LinearInterpolation."); + assert(m_data.get_num_elements() == 1 && "LinearInterpolation requires exactly one value per time point."); + assert(m_data.is_sorted()); } } - TimeDependentParameterFunctor() - : TimeDependentParameterFunctor(Type::Zero, {}) + TimeSeriesFunctor(Type type, std::vector>&& table) + : TimeSeriesFunctor(type, TimeSeries{table}) + { + } + + TimeSeriesFunctor() + : TimeSeriesFunctor(Type::Zero, TimeSeries{0}) { } - ScalarType operator()(ScalarType time) const + FP operator()(FP time) const { - ScalarType value = 0.0; + FP value = 0.0; switch (m_type) { case Type::Zero: // value is explicitly zero-initialized break; case Type::LinearInterpolation: // find next time point in m_data (strictly) after time - const auto next_tp = std::upper_bound(m_data.begin(), m_data.end(), time, [](auto&& t, auto&& tp) { - return t < tp[0]; + auto tp_range = m_data.get_times(); + const auto next_tp = std::upper_bound(tp_range.begin(), tp_range.end(), time, [](auto&& t, auto&& tp) { + return t < tp; }); - if (next_tp == m_data.begin()) { // time is before first data point - value = m_data.front()[1]; + if (next_tp == tp_range.begin()) { // time is before first data point + value = m_data.get_value(0)[0]; } - else if (next_tp == m_data.end()) { // time is past last data point - value = m_data.back()[1]; + else if (next_tp == tp_range.end()) { // time is past last data point + value = m_data.get_last_value()[0]; } else { // time is in between data points - const auto tp = next_tp - 1; - value = linear_interpolation(time, (*tp)[0], (*next_tp)[0], (*tp)[1], (*next_tp)[1]); + const auto i = next_tp - tp_range.begin(); + value = linear_interpolation(time, m_data.get_time(i - 1), m_data.get_time(i), + m_data.get_value(i - 1)[0], m_data.get_value(i)[0]); } break; } @@ -97,12 +100,12 @@ class TimeDependentParameterFunctor /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("TimeDependentParameterFunctor", NVP("type", m_type), NVP("data", m_data)); + return make_auto_serialization("TimeSeriesFunctor", NVP("type", m_type), NVP("data", m_data)); } private: Type m_type; - DataType m_data; + TimeSeries m_data; }; } // namespace mio diff --git a/cpp/memilio/utils/time_series.h b/cpp/memilio/utils/time_series.h index 0016e38d71..d835cdba91 100644 --- a/cpp/memilio/utils/time_series.h +++ b/cpp/memilio/utils/time_series.h @@ -21,10 +21,10 @@ #define EPI_TIME_SERIES_H #include "memilio/io/io.h" -#include "memilio/math/eigen.h" #include "memilio/utils/stl_util.h" #include "memilio/math/floating_point.h" +#include #include #include #include @@ -102,6 +102,31 @@ class TimeSeries col.tail(expr.rows()) = expr; } + /** + * @brief Initialize a TimeSeries with a table. + * @param table Consists of a list of time points, each of the form (time, value_0, value_1, ..., value_n) for + * some fixed n >= 0. + */ + TimeSeries(std::vector> table) + : m_data() + , m_num_time_points(table.size()) + { + assert(table.size() > 0); + assert(std::all_of(table.begin(), table.end(), [&table](auto&& a) { + return a.size() == table.front().size(); + })); + m_data.resize(table.front().size(), table.size()); + // sort by time + std::sort(table.begin(), table.end(), [](auto&& a, auto&& b) { + return a[0] < b[0]; + }); + for (Eigen::Index tp = 0; tp < m_data.cols(); tp++) { + for (Eigen::Index i = 0; i < m_data.rows(); i++) { + m_data(i, tp) = table[tp][i]; + } + } + } + /** copy ctor */ TimeSeries(const TimeSeries& other) : m_data(other.get_num_elements() + 1, details::next_pow2(other.m_num_time_points)) @@ -148,6 +173,16 @@ class TimeSeries TimeSeries(TimeSeries&& other) = default; TimeSeries& operator=(TimeSeries&& other) = default; + /// Check if the time is strictly monotonic increasing. + bool is_sorted() + { + const auto times = get_times(); + auto time_itr = times.begin(); + return std::all_of(++times.begin(), times.end(), [&](const auto& t) { + return *(time_itr++) < t; + }); + } + /** * number of time points in the series */ diff --git a/cpp/models/abm/parameters.h b/cpp/models/abm/parameters.h index c68b3a5973..a85d169476 100644 --- a/cpp/models/abm/parameters.h +++ b/cpp/models/abm/parameters.h @@ -28,7 +28,7 @@ #include "memilio/config.h" #include "memilio/io/auto_serialize.h" #include "memilio/io/io.h" -#include "memilio/math/time_dependent_parameter_functor.h" +#include "memilio/math/time_series_functor.h" #include "memilio/utils/custom_index_array.h" #include "memilio/utils/uncertain_value.h" #include "memilio/utils/parameter_set.h" @@ -298,19 +298,15 @@ struct AerosolTransmissionRates { } }; -// using InputFunctionForProtectionLevel = std::function; -using InputFunctionForProtectionLevel = TimeDependentParameterFunctor; - /** * @brief Personal protection factor against #Infection% after #Infection and #Vaccination, which depends on #ExposureType, * #AgeGroup and #VirusVariant. Its value is between 0 and 1. */ struct InfectionProtectionFactor { - using Type = CustomIndexArray; + using Type = CustomIndexArray, ExposureType, AgeGroup, VirusVariant>; static auto get_default(AgeGroup size) { - return Type({ExposureType::Count, size, VirusVariant::Count}, - Type::value_type(TimeDependentParameterFunctor::Type::Zero, {})); + return Type({ExposureType::Count, size, VirusVariant::Count}, TimeSeriesFunctor()); } static std::string name() { @@ -323,11 +319,10 @@ struct InfectionProtectionFactor { * #AgeGroup and #VirusVariant. Its value is between 0 and 1. */ struct SeverityProtectionFactor { - using Type = CustomIndexArray; + using Type = CustomIndexArray, ExposureType, AgeGroup, VirusVariant>; static auto get_default(AgeGroup size) { - return Type({ExposureType::Count, size, VirusVariant::Count}, - Type::value_type(TimeDependentParameterFunctor::Type::Zero, {})); + return Type({ExposureType::Count, size, VirusVariant::Count}, TimeSeriesFunctor()); } static std::string name() { @@ -339,10 +334,10 @@ struct SeverityProtectionFactor { * @brief Personal protective factor against high viral load. Its value is between 0 and 1. */ struct HighViralLoadProtectionFactor { - using Type = InputFunctionForProtectionLevel; + using Type = TimeSeriesFunctor; static auto get_default() { - return Type(TimeDependentParameterFunctor::Type::Zero, {}); + return Type(); } static std::string name() { diff --git a/cpp/simulations/abm.cpp b/cpp/simulations/abm.cpp index 00cd7eb327..b04e560b83 100644 --- a/cpp/simulations/abm.cpp +++ b/cpp/simulations/abm.cpp @@ -21,6 +21,7 @@ #include "abm/common_abm_loggers.h" #include "abm/household.h" #include "abm/lockdown_rules.h" +#include "memilio/config.h" #include "memilio/io/result_io.h" #include "memilio/math/interpolation.h" #include "memilio/utils/random_number_generator.h" @@ -472,7 +473,7 @@ void set_parameters(mio::abm::Parameters params) // Set protection level from high viral load. Information based on: https://doi.org/10.1093/cid/ciaa886 params.get() = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.863}, {1, 0.969}, {7, 0.029}, {10, 0.002}, {14, 0.0014}, {21, 0}}}; //0-4 diff --git a/cpp/simulations/abm_braunschweig.cpp b/cpp/simulations/abm_braunschweig.cpp index 51672387b2..a39def895f 100644 --- a/cpp/simulations/abm_braunschweig.cpp +++ b/cpp/simulations/abm_braunschweig.cpp @@ -401,7 +401,7 @@ void set_parameters(mio::abm::Parameters params) // Set protection level from high viral load. Information based on: https://doi.org/10.1093/cid/ciaa886 params.get() = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.863}, {1, 0.969}, {7, 0.029}, {10, 0.002}, {14, 0.0014}, {21, 0}}}; //0-4 @@ -493,7 +493,7 @@ void set_parameters(mio::abm::Parameters params) // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_0_to_4, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.852}, {180, 0.852}, {210, 0.845}, @@ -510,14 +510,14 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_0_to_4, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.5}, {30, 0.91}, {60, 0.92}, {90, 0.88}, {120, 0.84}, {150, 0.81}, {180, 0.88}, {450, 0.5}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_0_to_4, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.967}, {30, 0.975}, {60, 0.977}, @@ -535,7 +535,7 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_0_to_4, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}}; //5-14 @@ -553,7 +553,7 @@ void set_parameters(mio::abm::Parameters params) // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_5_to_14, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.852}, {180, 0.852}, {210, 0.845}, @@ -569,14 +569,14 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_5_to_14, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.5}, {30, 0.91}, {60, 0.92}, {90, 0.88}, {120, 0.84}, {150, 0.81}, {180, 0.88}, {450, 0.5}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_5_to_14, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.967}, {30, 0.975}, {60, 0.977}, @@ -594,7 +594,7 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_5_to_14, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}}; //15-34 @@ -612,7 +612,7 @@ void set_parameters(mio::abm::Parameters params) // Set up personal infection and vaccine protection levels, based on: https://doi.org/10.1038/s41577-021-00550-x, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.852}, {180, 0.852}, {210, 0.845}, @@ -628,13 +628,13 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.5}, {30, 0.89}, {60, 0.84}, {90, 0.78}, {120, 0.68}, {150, 0.57}, {180, 0.39}, {450, 0.1}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.967}, {30, 0.975}, {60, 0.977}, @@ -652,7 +652,7 @@ void set_parameters(mio::abm::Parameters params) // Information is from: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}}; //35-59 @@ -671,7 +671,7 @@ void set_parameters(mio::abm::Parameters params) // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_35_to_59, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.852}, {180, 0.852}, {210, 0.845}, @@ -687,13 +687,13 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_35_to_59, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.5}, {30, 0.89}, {60, 0.84}, {90, 0.78}, {120, 0.68}, {150, 0.57}, {180, 0.39}, {450, 0.1}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_35_to_59, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.967}, {30, 0.975}, {60, 0.977}, @@ -711,7 +711,7 @@ void set_parameters(mio::abm::Parameters params) // Information is from: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_35_to_59, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}}; //60-79 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = @@ -729,7 +729,7 @@ void set_parameters(mio::abm::Parameters params) // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_60_to_79, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.852}, {180, 0.852}, {210, 0.845}, @@ -745,14 +745,14 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_60_to_79, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.5}, {30, 0.87}, {60, 0.85}, {90, 0.78}, {120, 0.67}, {150, 0.61}, {180, 0.50}, {450, 0.1}}}; // Set up personal severe protection levels. // Protection of severe infection of age group 65 + is different from other age group, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_60_to_79, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.967}, {30, 0.975}, {60, 0.977}, @@ -769,7 +769,7 @@ void set_parameters(mio::abm::Parameters params) {360, 0.5}}}; params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_60_to_79, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.5}, {30, 0.91}, {60, 0.86}, {90, 0.91}, {120, 0.94}, {150, 0.95}, {180, 0.90}, {450, 0.5}}}; //80+ @@ -787,7 +787,7 @@ void set_parameters(mio::abm::Parameters params) // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_80_plus, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.852}, {180, 0.852}, {210, 0.845}, @@ -803,14 +803,14 @@ void set_parameters(mio::abm::Parameters params) // Information is from: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_80_plus, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.5}, {30, 0.80}, {60, 0.79}, {90, 0.75}, {120, 0.56}, {150, 0.49}, {180, 0.43}, {450, 0.1}}}; // Set up personal severe protection levels. // Protection of severe infection of age group 65 + is different from other age group, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_0_to_4, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.967}, {30, 0.975}, {60, 0.977}, @@ -828,7 +828,7 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_80_plus, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.5}, {30, 0.84}, {60, 0.88}, {90, 0.89}, {120, 0.86}, {150, 0.85}, {180, 0.83}, {450, 0.5}}}; } @@ -904,7 +904,7 @@ void write_log_to_file_trip_data(const T& history) int start_index = mobility_data_index - 1; using Type = std::tuple; + mio::abm::TransportMode, mio::abm::ActivityType, mio::abm::InfectionState>; while (!std::binary_search(std::begin(mobility_data[start_index]), std::end(mobility_data[start_index]), mobility_data[mobility_data_index][trip_index], [](const Type& v1, const Type& v2) { diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index b8696233a1..e51c0bc028 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -40,7 +40,7 @@ set(TESTSOURCES test_abm_testing_strategy.cpp test_abm_model.cpp test_math_floating_point.cpp - test_math_tdpf.cpp + test_math_time_series_functor.cpp test_analyze_result.cpp test_contact_matrix.cpp test_type_safe.cpp diff --git a/cpp/tests/test_abm_infection.cpp b/cpp/tests/test_abm_infection.cpp index 1e060afee7..e0a637a5da 100644 --- a/cpp/tests/test_abm_infection.cpp +++ b/cpp/tests/test_abm_infection.cpp @@ -78,10 +78,10 @@ TEST(TestInfection, init) EXPECT_NEAR(infection.get_infectivity(mio::abm::TimePoint(0) + mio::abm::days(3)), 0.2689414213699951, 1e-14); params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_test, - virus_variant_test}] = mio::TimeDependentParameterFunctor{ - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, {{0, 0.91}, {30, 0.81}}}; - params.get() = mio::TimeDependentParameterFunctor{ - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, {{0, 0.91}, {30, 0.81}}}; + virus_variant_test}] = mio::TimeSeriesFunctor{ + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.91}, {30, 0.81}}}; + params.get() = mio::TimeSeriesFunctor{ + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.91}, {30, 0.81}}}; auto infection_w_previous_exp = mio::abm::Infection(rng, mio::abm::VirusVariant::Wildtype, age_group_test, params, mio::abm::TimePoint(0), mio::abm::InfectionState::InfectedSymptoms, @@ -191,20 +191,20 @@ TEST(TestInfection, getPersonalProtectiveFactor) // Test linear interpolation with one node params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = - mio::TimeDependentParameterFunctor{mio::TimeDependentParameterFunctor::Type::LinearInterpolation, {{2, 0.91}}}; + mio::TimeSeriesFunctor{mio::TimeSeriesFunctor::Type::LinearInterpolation, {{2, 0.91}}}; auto t = mio::abm::TimePoint(6 * 24 * 60 * 60); // TODO: Discuss: Assumption of interpolation in TDPF is that the function is constant with value at front/back entry outside of [front, back] time range. This works with one node as well and prints no errors EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.91, eps); params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = - mio::TimeDependentParameterFunctor{mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor{mio::TimeSeriesFunctor::Type::LinearInterpolation, {{2, 0.91}, {30, 0.81}}}; params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = - mio::TimeDependentParameterFunctor{mio::TimeDependentParameterFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctor{mio::TimeSeriesFunctor::Type::LinearInterpolation, {{2, 0.91}, {30, 0.81}}}; - params.get() = mio::TimeDependentParameterFunctor{ - mio::TimeDependentParameterFunctor::Type::LinearInterpolation, {{2, 0.91}, {30, 0.81}}}; + params.get() = mio::TimeSeriesFunctor{ + mio::TimeSeriesFunctor::Type::LinearInterpolation, {{2, 0.91}, {30, 0.81}}}; // Test Parameter InfectionProtectionFactor and get_protection_factor() t = mio::abm::TimePoint(0) + mio::abm::days(2); diff --git a/cpp/tests/test_math_tdpf.cpp b/cpp/tests/test_math_tdpf.cpp deleted file mode 100644 index 25b5201edb..0000000000 --- a/cpp/tests/test_math_tdpf.cpp +++ /dev/null @@ -1,76 +0,0 @@ -#include "memilio/math/time_dependent_parameter_functor.h" -#include "memilio/utils/random_number_generator.h" - -#include - -#include - -class TestMathTdpf : public ::testing::Test -{ -public: - const int num_evals = 1000; - - double fuzzy_val(double min, double max) - { - return mio::UniformDistribution::get_instance()(m_rng, min, max); - } - -protected: - void SetUp() override - { - log_rng_seeds(m_rng, mio::LogLevel::warn); - } - -private: - mio::RandomNumberGenerator m_rng{}; -}; - -TEST_F(TestMathTdpf, zero) -{ - // Test that the Zero-TDPF always returns zero, using a random evaluation point. - - // initialize - mio::TimeDependentParameterFunctor tdpf; - - // verify output - for (int i = 0; i < this->num_evals; i++) { - auto random_t_eval = this->fuzzy_val(-1e+3, 1e+3); - EXPECT_EQ(tdpf(random_t_eval), 0.0); - } -} - -TEST_F(TestMathTdpf, linearInterpolation) -{ - // Test that the LinearInterpolation-TDPF correctly reproduces a (piecewise) linear function, using random samples. - // Since the initialization uses unsorted data, this also checks that the data gets sorted - - const double min = -1e+3, max = 1e+3; // reasonably large values for lin_fct height and slopes - const double t_min = -1, t_max = 1, t_mid = this->fuzzy_val(t_min, t_max); - const double slope1 = this->fuzzy_val(min, max), slope2 = this->fuzzy_val(min, max), - height = this->fuzzy_val(min, max); - - const auto pcw_lin_fct = [&](double t) { - // continuous function with different slopes between t_min, t_mid and t_max, constant otherwise - return height + slope1 * std::clamp(t - t_min, 0.0, t_mid - t_min) + - slope2 * std::clamp(t - t_mid, 0.0, t_max - t_mid); - }; - - // initialize the data with the critical points - std::vector> unsorted_data{ - {t_max, pcw_lin_fct(t_max)}, {t_min, pcw_lin_fct(t_min)}, {t_mid, pcw_lin_fct(t_mid)}}; - // randomly add a few more evaluations in between - for (int i = 0; i < 10; i++) { - const double t = this->fuzzy_val(-1.0, 1.0); - unsorted_data.push_back({t, pcw_lin_fct(t)}); - } - - mio::TimeDependentParameterFunctor tdpf(mio::TimeDependentParameterFunctor::Type::LinearInterpolation, - unsorted_data); - - // verify output - for (int i = 0; i < this->num_evals; i++) { - // sample in the interval [t_min - (t_max - t_min) / 4, t_max + (t_max - tmin) / 4] - double random_t_eval = this->fuzzy_val(1.25 * t_min - 0.25 * t_max, 1.25 * t_max - 0.25 * t_min); - EXPECT_NEAR(tdpf(random_t_eval), pcw_lin_fct(random_t_eval), 1e-10); - } -} diff --git a/cpp/tests/test_math_time_series_functor.cpp b/cpp/tests/test_math_time_series_functor.cpp new file mode 100644 index 0000000000..f22ed7cfed --- /dev/null +++ b/cpp/tests/test_math_time_series_functor.cpp @@ -0,0 +1,80 @@ +#include "memilio/math/time_series_functor.h" +#include "memilio/utils/random_number_generator.h" + +#include + +class TestMathTimeSeriesFunctor : public ::testing::Test +{ +public: + using TSF = mio::TimeSeriesFunctor; + + const int num_evals = 1000; + const double min = -1e+3, max = 1e+3; // a reasonably large range for fuzzy_val + + double fuzzy_val(double min_, double max_) + { + return mio::UniformDistribution::get_instance()(m_rng, min_, max_); + } + +protected: + void SetUp() override + { + log_rng_seeds(m_rng, mio::LogLevel::warn); + } + +private: + mio::RandomNumberGenerator m_rng{}; +}; + +TEST_F(TestMathTimeSeriesFunctor, zero) +{ + // Test that the Zero-functor always returns zero, using a random evaluation point. + + // initialize functor using the default ctor + TSF tsf; + + // check one deterministic value first to avoid flooding the test output with failed tests + ASSERT_EQ(tsf(0.0), 0.0); + // verify output + for (int i = 0; i < this->num_evals; i++) { + auto random_t_eval = this->fuzzy_val(this->min, this->max); + EXPECT_EQ(tsf(random_t_eval), 0.0); + } +} + +TEST_F(TestMathTimeSeriesFunctor, linearInterpolation) +{ + // Test that the LinearInterpolation-functor correctly reproduces a (piecewise) linear function, using random + // samples. Since the initialization uses unsorted data, this also checks that the data gets sorted + + const double t_min = -1, t_max = 1, t_mid = this->fuzzy_val(t_min, t_max); + const double slope1 = this->fuzzy_val(this->min, this->max); + const double slope2 = this->fuzzy_val(this->min, this->max); + const double height = this->fuzzy_val(this->min, this->max); + + const auto pcw_lin_fct = [&](double t) { + // continuous function with different slopes between t_min, t_mid and t_max, constant otherwise + return height + slope1 * std::clamp(t - t_min, 0.0, t_mid - t_min) + + slope2 * std::clamp(t - t_mid, 0.0, t_max - t_mid); + }; + + // initialize the data with the critical points + std::vector> unsorted_data{ + {t_max, pcw_lin_fct(t_max)}, {t_min, pcw_lin_fct(t_min)}, {t_mid, pcw_lin_fct(t_mid)}}; + // randomly add a few more evaluations in between + for (int i = 0; i < 10; i++) { + const double t = this->fuzzy_val(-1.0, 1.0); + unsorted_data.push_back({t, pcw_lin_fct(t)}); + } + // initialize functor + TSF tsf(TSF::Type::LinearInterpolation, unsorted_data); + + // check one deterministic value first to avoid flooding the test output with failed tests + ASSERT_NEAR(tsf(0.5 * (t_max - t_min)), pcw_lin_fct(0.5 * (t_max - t_min)), 1e-10); + // verify output + for (int i = 0; i < this->num_evals; i++) { + // sample in the interval [t_min - (t_max - t_min) / 4, t_max + (t_max - tmin) / 4] + double random_t_eval = this->fuzzy_val(1.25 * t_min - 0.25 * t_max, 1.25 * t_max - 0.25 * t_min); + EXPECT_NEAR(tsf(random_t_eval), pcw_lin_fct(random_t_eval), 1e-10) << "i = " << i; + } +} From b37f65326f9979e3ebe9956683c394ef74efc9e8 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Thu, 8 Aug 2024 13:43:33 +0200 Subject: [PATCH 21/42] clean up, comments, documentation --- cpp/memilio/io/README.md | 73 +++++++++++++++++--------- cpp/memilio/io/auto_serialize.h | 31 ++++------- cpp/memilio/math/time_series_functor.h | 40 ++++++++++++-- cpp/memilio/utils/time_series.h | 19 ++++--- cpp/models/abm/testing_strategy.h | 44 +--------------- cpp/models/abm/trip_list.h | 37 ++----------- 6 files changed, 110 insertions(+), 134 deletions(-) diff --git a/cpp/memilio/io/README.md b/cpp/memilio/io/README.md index 828315bda7..ef07d54833 100644 --- a/cpp/memilio/io/README.md +++ b/cpp/memilio/io/README.md @@ -4,7 +4,24 @@ This directory contains utilities for reading and writing data from and to files ## The Serialization framework -## Main functions and types +### Using serialization + +In the next sections we will explain how to implement serialization (both for types and formats), here we quickly show +how to use it once it already is implemented for a type. Currently, there is support for the Json and a binary format, +which can be used through the `serialize_json`/`deserialize_json` and `serialize_binary`/`deserialize_binary`, +respectively. For example + +```cpp +Foo foo{5}; +mio::IOResult js_result = mio::serialize_json(foo); +``` +```cpp +Json::Value js_value; +js_value["i"] = Json::Int(5); +mio::IOResult foo_result = mio::deserialize_json(js_value, mio::Tag{}); +``` + +### Main functions and types - functions serialize and deserialize: Main entry points to the framework to write and read values, respectively. The functions expect an IOContext @@ -14,7 +31,33 @@ This directory contains utilities for reading and writing data from and to files - IOStatus and IOResult: Used for error handling, see section "Error Handling" below. -## Concepts +### Auto-serialization + +Before we get into the details of the framework, this feature provides an easy and convenient alternative to the +serialize and deserialize functions. To give an example: + +```cpp +struct Foo { + int i; + auto auto_serialize() { + return make_auto_serialization("Foo", NVP("i", i)); + } +}; +``` +The auto-serialization is less flexible than the serialize and deserialize functions and has additional requirements: +- The class must be trivially constructible. + - Alternatively, you may provide a specialization of the struct `AutoSerializableFactory`. For more details, + view the struct's documentation. +- There is exactly one NVP for every class member (though the names and their order are arbitrary). + - Values must be passed directly, like in the example. No copies, accessors, etc. +- Every class member itself is both (auto-)(de)serializable and assignable. + +As to the feature set, auto-serialization only supports the `add_element` and `expect_element` operations defined in +the Concepts section below, where each operation's arguments are provided by the name-value pairs (NVPs). Note that the +value part of an NVP is also used to assign a value during deserialization, hence the class members must be used +directly in the NVP constructor (i.e. as a non-const lvalue reference). + +### Concepts 1. IOContext Stores data that describes serialized objects of any type in some unspecified format and provides structured @@ -66,7 +109,7 @@ for an IOObject `obj`: value or it may be empty. Otherwise returns an error. Note that for some formats a wrong key is indistinguishable from an empty optional, so make sure to provide the correct key. -## Error handling +### Error handling Errors are handled by returning error codes. The type IOStatus contains an error code and an optional string with additional information. The type IOResult contains either a value or an IOStatus that describes an error. Operations that can fail return @@ -78,7 +121,7 @@ inspected, so `expect_...` operations return an IOResult. The `apply` utility fu of multiple `expect_...` operations and use the values if all are succesful. See the documentation of `IOStatus`, `IOResult` and `apply` below for more details. -## Adding a new data type to be serialized +### Adding a new data type to be serialized Serialization of a new type T can be customized by providing _either_ member functions `serialize` and `deserialize` _or_ free functions `serialize_internal` and `deserialize_internal`. @@ -121,28 +164,6 @@ more efficiently than the provided general free functions. - HDF5 support classes for C++ - Reading of mobility matrix files -## Auto-serialization - -This feature provides an easy and convenient method to serialize and deserialize classes, but with additional requirements and a reduced feature set. To give an example: - -```cpp -struct Foo { - int i; - auto auto_serialize() { - return make_auto_serialization("Foo", NVP("i", i)); - } -}; -``` - -The auto-serialization effectively only supports the `add_element` and `expect_element` operations defined in the Concepts section, where the function arguments are provided by the name-value pairs (NVPs). Note that the value part of an NVP is also used to assign a value during deserialization, hence the class members must be used directly in the NVP constructor (i.e. as a non-const lvalue reference). - -The requirements for auto-serialization are: -- The class must be trivially constructible. - - Alternatively, you may provide a spezialisation of the struct `AutoSerializableFactory`. -- There is exactly one NVP for every class member (but the names and their order is arbitrary). - - Values must be passed directly. -- Every class member itself is both (auto-)serializable and assignable. - ## The command line interface We provide a function `mio::command_line_interface` in the header `memilio/io/cli.h`, that can be used to write to or read from a parameter set. It can take parameters from command line arguments (i.e. the content of `argv` in the main function), and assign them to or get them from a `mio::ParameterSet`. A small example can be seen in `cpp/examples/cli.cpp`. diff --git a/cpp/memilio/io/auto_serialize.h b/cpp/memilio/io/auto_serialize.h index 371157c25c..8c3dbb5f90 100644 --- a/cpp/memilio/io/auto_serialize.h +++ b/cpp/memilio/io/auto_serialize.h @@ -41,6 +41,10 @@ namespace mio template struct NVP { using Type = ValueType&; + + const std::string_view name; + Type value; + /** * @brief Create a (name, value) pair. * @@ -52,20 +56,12 @@ struct NVP { , value(v) { } - const std::string_view name; - Type value; NVP() = delete; NVP(const NVP&) = default; NVP(NVP&&) = default; NVP& operator=(const NVP&) = delete; NVP& operator=(NVP&&) = delete; - - NVP& operator=(const ValueType& v) - { - this->value = v; - return *this; - } }; /** @@ -118,18 +114,6 @@ namespace details template using auto_serialize_expr_t = decltype(std::declval().auto_serialize()); -} // namespace details - -/** - * @brief Detect whether T has a auto_serialize member function. - * @tparam T Any type. - */ -template -using has_auto_serialize = is_expression_valid; - -namespace details -{ - /// Add a name-value pair to an io object. template void add_nvp(IOObject& obj, const NVP nvp) @@ -170,6 +154,13 @@ IOResult auto_deserialize_impl(IOContext& io, AutoSerializable } // namespace details +/** + * @brief Detect whether T has a auto_serialize member function. + * @tparam T Any type. + */ +template +using has_auto_serialize = is_expression_valid; + /** * @brief Serialization implementation for the auto-serialization feature. * Disables itself (SFINAE) if there is no auto_serialize member or if a serialize member is present. diff --git a/cpp/memilio/math/time_series_functor.h b/cpp/memilio/math/time_series_functor.h index 74d745a6dc..a2f99110e6 100644 --- a/cpp/memilio/math/time_series_functor.h +++ b/cpp/memilio/math/time_series_functor.h @@ -35,20 +35,36 @@ template class TimeSeriesFunctor { public: + /** + * @brief Type of a TimeSeriesFunctor. + * The available types are: + * - Zero: + * - No data used. + * - Always returns 0. + * - LinearInterpolation: + * - Requires at least one time point with exactly one value each. Time must be strictly monotic increasing. + * - Linearly interpolates between data points. Stays constant with first/last value outside of provided data. + */ enum Type { Zero, LinearInterpolation, }; + /** + * @brief Creates a functor using the given data. + * Note the data requirements of the given functor type. + * @param type The type of the functor. + * @param table A list of time points, passed to the TimeSeries constructor. + */ TimeSeriesFunctor(Type type, const TimeSeries& data) : m_type(type) , m_data(data) { - // data preprocessing + // data shape checks and preprocessing switch (m_type) { case Type::Zero: - // no preprocessing needed + // no checks needed break; case Type::LinearInterpolation: // make sure data has the correct shape, i.e. a list of (time, value) pairs @@ -58,16 +74,30 @@ class TimeSeriesFunctor } } + /** + * @brief Creates a functor using the given table. + * Note the data requirements of the given functor type. + * @param type The type of the functor. + * @param table A list of time points, passed to the TimeSeries constructor. + */ TimeSeriesFunctor(Type type, std::vector>&& table) : TimeSeriesFunctor(type, TimeSeries{table}) { } + /** + * @brief Creates a Zero functor. + */ TimeSeriesFunctor() : TimeSeriesFunctor(Type::Zero, TimeSeries{0}) { } + /** + * @brief Function returning a scalar value. + * @param time A scalar value. + * @return A scalar value computed from data, depending on the functor's type. + */ FP operator()(FP time) const { FP value = 0.0; @@ -76,8 +106,8 @@ class TimeSeriesFunctor // value is explicitly zero-initialized break; case Type::LinearInterpolation: + auto tp_range = m_data.get_times(); // find next time point in m_data (strictly) after time - auto tp_range = m_data.get_times(); const auto next_tp = std::upper_bound(tp_range.begin(), tp_range.end(), time, [](auto&& t, auto&& tp) { return t < tp; }); @@ -104,8 +134,8 @@ class TimeSeriesFunctor } private: - Type m_type; - TimeSeries m_data; + Type m_type; ///< Determines what kind of functor this is, e.g. linear interpolation. + TimeSeries m_data; ///< Data used by the functor to compute its values. Its shape depends on type. }; } // namespace mio diff --git a/cpp/memilio/utils/time_series.h b/cpp/memilio/utils/time_series.h index d835cdba91..883b879c7d 100644 --- a/cpp/memilio/utils/time_series.h +++ b/cpp/memilio/utils/time_series.h @@ -105,21 +105,26 @@ class TimeSeries /** * @brief Initialize a TimeSeries with a table. * @param table Consists of a list of time points, each of the form (time, value_0, value_1, ..., value_n) for - * some fixed n >= 0. + * some fixed n >= 0. */ TimeSeries(std::vector> table) - : m_data() + : m_data() // resized in body , m_num_time_points(table.size()) { - assert(table.size() > 0); - assert(std::all_of(table.begin(), table.end(), [&table](auto&& a) { - return a.size() == table.front().size(); - })); + // check table sizes + assert(table.size() > 0 && "At least one entry is required to determine the number of elements."); + assert(std::all_of(table.begin(), table.end(), + [&table](auto&& a) { + return a.size() == table.front().size(); + }) && + "All table entries must have the same size."); + // resize data. note that the table entries contain both time and values m_data.resize(table.front().size(), table.size()); - // sort by time + // sort table by time std::sort(table.begin(), table.end(), [](auto&& a, auto&& b) { return a[0] < b[0]; }); + // assign table to data for (Eigen::Index tp = 0; tp < m_data.cols(); tp++) { for (Eigen::Index i = 0; i < m_data.rows(); i++) { m_data(i, tp) = table[tp][i]; diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index a71e84b25f..33dc3c215b 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -179,23 +179,7 @@ class TestingScheme */ bool run_scheme(PersonalRandomNumberGenerator& rng, Person& person, TimePoint t) const; - // /** - // * serialize this. - // * @see mio::serialize - // */ - // template - // void serialize(IOContext& io) const - // { - // auto obj = io.create_object("TestingScheme"); - // obj.add_element("criteria", m_testing_criteria), ; - // obj.add_element("min_time_since_last_test", m_minimal_time_since_last_test), ; - // obj.add_element("start_date", m_start_date), ; - // obj.add_element("end_date", m_end_date), ; - // obj.add_element("test_params", m_test_parameters), ; - // obj.add_element("probability", m_probability), ; - // obj.add_element("is_active", m_is_active), ; - // } - + /// This method is used by the auto-serialization feature. auto auto_serialize() { return make_auto_serialization("TestingScheme", NVP("criteria", m_testing_criteria), @@ -205,32 +189,6 @@ class TestingScheme NVP("is_active", m_is_active)); } - // /** - // * deserialize an object of this class. - // * @see mio::deserialize - // */ - // template - // static IOResult deserialize(IOContext& io) - // { - // auto obj = io.expect_object("TestingScheme"); - // auto criteria = obj.expect_element("criteria", Tag{}); - // auto min_time_since_last_test = obj.expect_element("min_time_since_last_test", Tag{}); - // auto start_date = obj.expect_element("start_date", Tag{}); - // auto end_date = obj.expect_element("end_date", Tag{}); - // auto test_params = obj.expect_element("test_params", Tag{}); - // auto probability = obj.expect_element("probability", Tag{}); - // auto is_active = obj.expect_element("is_active", Tag{}); - // return apply( - // io, - // [](auto&& criteria_, auto&& min_time_since_last_test_, auto&& start_date_, auto&& end_date_, - // auto&& test_params_, auto&& probability_, auto&& is_active_) { - // return TestingScheme{ - // criteria_, min_time_since_last_test_, start_date_, end_date_, test_params_, probability_, - // is_active_}; - // }, - // criteria, min_time_since_last_test, start_date, end_date, test_params, probability, is_active); - // } - private: friend AutoSerializableFactory; TestingScheme() = default; diff --git a/cpp/models/abm/trip_list.h b/cpp/models/abm/trip_list.h index bdfd361a7f..ecb01d4f48 100644 --- a/cpp/models/abm/trip_list.h +++ b/cpp/models/abm/trip_list.h @@ -192,40 +192,11 @@ class TripList return m_current_index; } - /** - * serialize this. - * @see mio::serialize - */ - template - void serialize(IOContext& io) const + /// This method is used by the auto-serialization feature. + auto auto_serialize() { - auto obj = io.create_object("TripList"); - obj.add_list("trips_weekday", m_trips_weekday.cbegin(), m_trips_weekday.cend()); - obj.add_list("trips_weekend", m_trips_weekend.cbegin(), m_trips_weekend.cend()); - obj.add_element("index", m_current_index); - } - - /** - * deserialize an object of this class. - * @see mio::deserialize - */ - template - static IOResult deserialize(IOContext& io) - { - auto obj = io.expect_object("TripList"); - auto trips_wd = obj.expect_list("trips_weekday", Tag{}); - auto trips_we = obj.expect_list("trips_weekend", Tag{}); - auto index = obj.expect_element("index", Tag{}); - return apply( - io, - [](auto&& trips_wd_, auto&& trips_we_, auto&& index_) { - TripList tl; - tl.m_trips_weekday = trips_wd_; - tl.m_trips_weekend = trips_we_; - tl.m_current_index = index_; - return tl; - }, - trips_wd, trips_we, index); + return make_auto_serialization("TestingScheme", NVP("trips_weekday", m_trips_weekday), + NVP("trips_weekend", m_trips_weekend), NVP("index", m_current_index)); } private: From 57af8ba6b85e2fd9c5eeec7d7562d692125047c7 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Thu, 8 Aug 2024 17:29:31 +0200 Subject: [PATCH 22/42] combine PickleType specialization for small ints. was ambiguous when serializing abm::TestType --- cpp/models/abm/parameters.h | 12 +++++++----- .../memilio/simulation/pickle_serializer.h | 8 ++------ 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/cpp/models/abm/parameters.h b/cpp/models/abm/parameters.h index 10411efe9c..1748fc6746 100644 --- a/cpp/models/abm/parameters.h +++ b/cpp/models/abm/parameters.h @@ -38,21 +38,23 @@ #include #include +#include namespace mio { -template -void serialize_internal(IOContext& io, const UniformDistribution::ParamType& p) +template ::ParamType, T>, void*> = nullptr> +void serialize_internal(IOContext& io, const T& p) { auto obj = io.create_object("UniformDistributionParams"); obj.add_element("a", p.params.a()); obj.add_element("b", p.params.b()); } -template -IOResult::ParamType> deserialize_internal(IOContext& io, - Tag::ParamType>) +template ::ParamType, T>, void*> = nullptr> +IOResult::ParamType> deserialize_internal(IOContext& io, Tag) { auto obj = io.expect_object("UniformDistributionParams"); auto a = obj.expect_element("a", Tag{}); diff --git a/pycode/memilio-simulation/memilio/simulation/pickle_serializer.h b/pycode/memilio-simulation/memilio/simulation/pickle_serializer.h index 8ad6975235..88a93a22fb 100644 --- a/pycode/memilio-simulation/memilio/simulation/pickle_serializer.h +++ b/pycode/memilio-simulation/memilio/simulation/pickle_serializer.h @@ -52,14 +52,10 @@ struct PickleType : std::true_type { template using is_small_integral = std::integral_constant::value && sizeof(T) <= 4)>; -//signed small ints -template -struct PickleType::value>> : std::true_type { -}; -//unsigned small ints +//small ints template -struct PickleType, std::is_unsigned>>> : std::true_type { +struct PickleType::value>> : std::true_type { }; //signed big ints From 49034b99821e6f4214d96421d4313b5c8630a889 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Fri, 9 Aug 2024 14:05:55 +0200 Subject: [PATCH 23/42] add serialization test for TestResult --- cpp/tests/test_abm_serialization.cpp | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/cpp/tests/test_abm_serialization.cpp b/cpp/tests/test_abm_serialization.cpp index 4644c05b16..75c821c83f 100644 --- a/cpp/tests/test_abm_serialization.cpp +++ b/cpp/tests/test_abm_serialization.cpp @@ -13,6 +13,8 @@ #include "models/abm/time.h" #include "models/abm/trip_list.h" #include "models/abm/model.h" +#include "json/config.h" +#include "json/value.h" #ifdef MEMILIO_HAS_JSONCPP @@ -96,8 +98,8 @@ void test_json_serialization_full(const T& reference_object, const Json::Value& TEST(TestAbmSerialization, Trip) { - // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. - // See test_json_serialization_by_representation for more detail. + // Test (de)serialization w.r.t json representation and the types own equality operator. + // See test_json_serialization_full for more detail. mio::abm::Trip trip(1, mio::abm::TimePoint(0) + mio::abm::hours(2), 3, 4); @@ -149,8 +151,8 @@ TEST(TestAbmSerialization, Infection) TEST(TestAbmSerialization, TestingScheme) { - // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. - // See test_json_serialization_by_representation for more detail. + // Test (de)serialization w.r.t json representation and the types own equality operator. + // See test_json_serialization_full for more detail. mio::abm::TestingScheme testing_scheme( mio::abm::TestingCriteria({mio::AgeGroup(1)}, {mio::abm::InfectionState(2)}), mio::abm::TimeSpan(3), @@ -197,6 +199,18 @@ TEST(TestAbmSerialization, TestingStrategy) test_json_serialization_by_representation(reference_json); } +TEST(TestAbmSerialization, TestResult) +{ + // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. + // See test_json_serialization_by_representation for more detail. + + Json::Value reference_json; // aka x + reference_json["result"] = Json::Value(false); + reference_json["time_of_testing"]["seconds"] = Json::UInt(1); + + test_json_serialization_by_representation(reference_json); +} + TEST(TestAbmSerialization, Person) { // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. From cef8cba8159618fc4a463c71157f233217969db0 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Thu, 15 Aug 2024 19:44:23 +0200 Subject: [PATCH 24/42] implement most review suggestions --- cpp/memilio/io/README.md | 14 +- cpp/memilio/io/auto_serialize.h | 147 +++++++++-------- cpp/memilio/io/io.h | 52 ++++++ cpp/memilio/math/interpolation.h | 47 +++++- cpp/memilio/math/time_series_functor.h | 78 ++++----- cpp/memilio/utils/random_number_generator.h | 31 +++- cpp/memilio/utils/time_series.h | 6 +- cpp/models/abm/infection.h | 21 ++- cpp/models/abm/location.h | 19 ++- cpp/models/abm/mask.h | 4 +- cpp/models/abm/parameters.h | 48 ++---- cpp/models/abm/person.h | 48 +++--- cpp/models/abm/test_type.h | 2 +- cpp/models/abm/testing_strategy.h | 50 ++---- cpp/models/abm/time.h | 4 +- cpp/models/abm/trip_list.h | 53 +++--- cpp/models/abm/vaccine.h | 4 +- cpp/simulations/abm.cpp | 3 +- cpp/simulations/abm_braunschweig.cpp | 50 +++--- cpp/tests/CMakeLists.txt | 16 +- cpp/tests/matchers.cpp | 64 ++++++++ cpp/tests/matchers.h | 57 +++++++ cpp/tests/random_number_test.h | 53 ++++++ cpp/tests/test_abm_infection.cpp | 20 ++- cpp/tests/test_abm_serialization.cpp | 170 ++++++++++---------- cpp/tests/test_math_time_series_functor.cpp | 92 ++++++----- 26 files changed, 688 insertions(+), 465 deletions(-) create mode 100644 cpp/tests/matchers.cpp create mode 100644 cpp/tests/random_number_test.h diff --git a/cpp/memilio/io/README.md b/cpp/memilio/io/README.md index ef07d54833..5a21a4ffbe 100644 --- a/cpp/memilio/io/README.md +++ b/cpp/memilio/io/README.md @@ -40,22 +40,22 @@ serialize and deserialize functions. To give an example: struct Foo { int i; auto auto_serialize() { - return make_auto_serialization("Foo", NVP("i", i)); + return Members("Foo").add("i", i); } }; ``` The auto-serialization is less flexible than the serialize and deserialize functions and has additional requirements: - The class must be trivially constructible. - - Alternatively, you may provide a specialization of the struct `AutoSerializableFactory`. For more details, + - Alternatively, you may provide a specialization of the struct `DefaultFactory`. For more details, view the struct's documentation. -- There is exactly one NVP for every class member (though the names and their order are arbitrary). - - Values must be passed directly, like in the example. No copies, accessors, etc. +- Every class member must be added to Members exactly once (though the names and their order are arbitrary). + - The members must be passed directly, like in the example. No copies, accessors, etc. - Every class member itself is both (auto-)(de)serializable and assignable. As to the feature set, auto-serialization only supports the `add_element` and `expect_element` operations defined in -the Concepts section below, where each operation's arguments are provided by the name-value pairs (NVPs). Note that the -value part of an NVP is also used to assign a value during deserialization, hence the class members must be used -directly in the NVP constructor (i.e. as a non-const lvalue reference). +the Concepts section below, where each operation's arguments are provided through the `add` function. Note that the +value provided to `add` is also used to assign a value during deserialization, hence the class members must be used +directly in the function (i.e. as a non-const lvalue reference). ### Concepts diff --git a/cpp/memilio/io/auto_serialize.h b/cpp/memilio/io/auto_serialize.h index 8c3dbb5f90..ad14e2abef 100644 --- a/cpp/memilio/io/auto_serialize.h +++ b/cpp/memilio/io/auto_serialize.h @@ -23,8 +23,6 @@ #include "memilio/io/io.h" #include "memilio/utils/metaprogramming.h" -#include -#include #include #include #include @@ -33,74 +31,75 @@ namespace mio { /** - * @brief Pair of name and value used for auto-(de)serialization. - * - * This object holds a view of a name and reference of a value. Mind their lifetime! + * @brief A pair of name and reference. + * + * Used for auto-(de)serialization. + * This object holds a pointer to a name and reference to value. Mind their lifetime! * @tparam ValueType The (non-cv, non-reference) type of the value. */ template -struct NVP { - using Type = ValueType&; +struct NamedRef { + using Reference = ValueType&; - const std::string_view name; - Type value; + const char* name; + Reference value; /** - * @brief Create a (name, value) pair. + * @brief Create a named reference. * - * @param n A view of the name. + * @param n A string literal. * @param v A non-const lvalue reference to the value. */ - explicit NVP(const std::string_view n, Type v) + explicit NamedRef(const char* n, Reference v) : name(n) , value(v) { } - - NVP() = delete; - NVP(const NVP&) = default; - NVP(NVP&&) = default; - NVP& operator=(const NVP&) = delete; - NVP& operator=(NVP&&) = delete; }; -/** - * @brief Provide names and values for auto-(de)serialization. - * - * This function packages the class name and a name-value pair for each class member together to define both a - * serialize and deserialize function (with limited features). - * - * Note that auto-serialization requires that all class members participate in serialization, and that - * each class member is (auto-)serializable and assignable. - * - * @tparam Targets List of each class member's type. - * @param class_name The name of the class to auto-serialize. - * @param class_members A name-value pair (NVP) for each class member. - * @return Collection of all name views and value references used for auto-(de)serialization. - */ -template -[[nodiscard]] inline auto make_auto_serialization(const std::string_view&& class_name, NVP&&... class_members) -{ - return std::make_pair(class_name, std::make_tuple(class_members...)); -} +template +struct Members { + + Members(const char* class_name) + : name(class_name) + , name_value_pairs() + { + } + + Members(const char* class_name, std::tuple...> nvps) + : name(class_name) + , name_value_pairs(nvps) + { + } + + template + [[nodiscard]] Members add(const char* member_name, T& member) + { + return Members{name, std::tuple_cat(name_value_pairs, std::tuple(NamedRef{member_name, member}))}; + } + + const char* name; + std::tuple...> name_value_pairs; +}; /** - * @brief Creates an instance of AutoSerializable for auto-deserialization. + * @brief Creates an instance of T for later initialization. * - * The default implementation uses the default constructor of AutoSerializable, if available. If there is no default - * constructor, this class must be spezialized to provide the method `static AutoSerializable create()`. If there is - * a default constructor, but it is private, AutoSerializableFactory can be marked as friend instead. + * The default implementation uses the default constructor of T, if available. If there is no default constructor, this + * class can be spezialized to provide the method `static T create()`. If there is a default constructor, but it is + * private, DefaultFactory can be marked as friend instead. * - * The state of the object retured by `create()` is completely arbitrary, as it is expected that auto-deserialization - * will overwrite the value of each class member. + * The state of the object retured by `create()` is completely arbitrary, and may be invalid. Make sure to set it to a + * valid state before using it further. * - * @tparam AutoSerializable A type with an auto_serialize member. + * @tparam T The type to create. */ -template -struct AutoSerializableFactory { - static AutoSerializable create() +template +struct DefaultFactory { + /// @brief Creates a new instance of T. + static T create() { - return AutoSerializable{}; + return T{}; } }; @@ -115,37 +114,37 @@ template using auto_serialize_expr_t = decltype(std::declval().auto_serialize()); /// Add a name-value pair to an io object. -template -void add_nvp(IOObject& obj, const NVP nvp) +template +void add_nvp(IOObject& obj, const NamedRef nvp) { - obj.add_element(std::string{nvp.name}, nvp.value); + obj.add_element(nvp.name, nvp.value); } /// Unpack all name-value pairs from the tuple and add them to a new io object with the given name. -template -void auto_serialize_impl(IOContext& io, const std::string_view name, const NVP... nvps) +template +void auto_serialize_impl(IOContext& io, const char* name, const NamedRef... nvps) { - auto obj = io.create_object(std::string{name}); + auto obj = io.create_object(name); (add_nvp(obj, nvps), ...); } /// Retrieve a name-value pair from an io object. -template -IOResult expect_nvp(IOObject& obj, const NVP nvp) +template +IOResult expect_nvp(IOObject& obj, const NamedRef nvp) { - return obj.expect_element(std::string{nvp.name}, Tag{}); + return obj.expect_element(nvp.name, Tag{}); } /// Read an io object and its members from the io context using the given names and assign the values to a. -template -IOResult auto_deserialize_impl(IOContext& io, AutoSerializable& a, std::string_view name, - NVP... nvps) +template +IOResult auto_deserialize_impl(IOContext& io, AutoSerializable& a, const char* name, + NamedRef... nvps) { - auto obj = io.expect_object(std::string{name}); + auto obj = io.expect_object(name); return apply( io, - [&a, &nvps...](const Targets&... values) { + [&a, &nvps...](const Members&... values) { ((nvps.value = values), ...); return a; }, @@ -176,14 +175,14 @@ template < AutoSerializable*> = nullptr> void serialize_internal(IOContext& io, const AutoSerializable& a) { - // Note that the following cons_cast is only safe if we do not modify targets. - const auto targets = const_cast(a).auto_serialize(); - // unpack targets and serialize + // Note that the following cons_cast is only safe if we do not modify members. + const auto members = const_cast(a).auto_serialize(); + // unpack members and serialize std::apply( - [&io, &targets](auto... nvps) { - details::auto_serialize_impl(io, targets.first, nvps...); + [&io, &members](auto... nvps) { + details::auto_serialize_impl(io, members.name, nvps...); }, - targets.second); + members.name_value_pairs); } /** @@ -203,14 +202,14 @@ template deserialize_internal(IOContext& io, Tag tag) { mio::unused(tag); - AutoSerializable a = AutoSerializableFactory::create(); - auto targets = a.auto_serialize(); - // unpack targets and deserialize + AutoSerializable a = DefaultFactory::create(); + auto members = a.auto_serialize(); + // unpack members and deserialize return std::apply( - [&io, &targets, &a](auto... nvps) { - return details::auto_deserialize_impl(io, a, targets.first, nvps...); + [&io, &members, &a](auto... nvps) { + return details::auto_deserialize_impl(io, a, members.name, nvps...); }, - targets.second); + members.name_value_pairs); } } // namespace mio diff --git a/cpp/memilio/io/io.h b/cpp/memilio/io/io.h index 8347030af9..0092fe9b89 100644 --- a/cpp/memilio/io/io.h +++ b/cpp/memilio/io/io.h @@ -27,6 +27,9 @@ #include "boost/outcome/result.hpp" #include "boost/outcome/try.hpp" #include "boost/optional.hpp" + +#include +#include #include #include @@ -632,6 +635,55 @@ IOResult deserialize_internal(IOContext& io, Tag /*tag*/) rows, cols, elements); } +/** + * @brief Serialize an std::bitset. + * @tparam IOContext A type that models the IOContext concept. + * @tparam N The size of the bitset. + * @param io An IO context. + * @param bitset A bitset to be serialized. + */ +template +void serialize_internal(IOContext& io, const std::bitset bitset) +{ + std::array bits; + for (size_t i = 0; i < N; i++) { + bits[i] = bitset[i]; + } + auto obj = io.create_object("BitSet"); + obj.add_list("bitset", bits.begin(), bits.end()); +} + +/** + * @brief Deserialize an std::bitset. + * @tparam IOContext A type that models the IOContext concept. + * @tparam N The size of the bitset. + * @param io An IO context. + * @param tag Defines the type of the object that is to be deserialized. + * @return The restored object if successful, an error otherwise. + */ +template +IOResult> deserialize_internal(IOContext& io, Tag> tag) +{ + mio::unused(tag); + auto obj = io.expect_object("BitSet"); + auto bits = obj.expect_list("bitset", Tag{}); + if (bits && bits.value().size() != N) { // "!bits" is handled by apply + return failure(StatusCode::InvalidValue, "Incorrent number of booleans to deserialize bitset. Expected " + + std::to_string(N) + ", got " + + std::to_string(bits.value().size()) + "."); + } + return apply( + io, + [](auto&& bits_) { + std::bitset bitset; + for (size_t i = 0; i < N; i++) { + bitset[i] = bits_[i]; + } + return bitset; + }, + bits); +} + /** * serialize an enum value as its underlying type. * @tparam IOContext a type that models the IOContext concept. diff --git a/cpp/memilio/math/interpolation.h b/cpp/memilio/math/interpolation.h index ad08eed373..518ace5d9e 100644 --- a/cpp/memilio/math/interpolation.h +++ b/cpp/memilio/math/interpolation.h @@ -17,11 +17,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef INTERPOLATION_H_ -#define INTERPOLATION_H_ +#ifndef MIO_MATH_INTERPOLATION_H_ +#define MIO_MATH_INTERPOLATION_H_ +#include "memilio/utils/logging.h" +#include "memilio/utils/time_series.h" + +#include #include #include -#include "memilio/utils/logging.h" namespace mio { @@ -34,7 +37,7 @@ namespace mio * @param[in] x_2 Right node of interpolation. * @param[in] y_1 Value at left node. * @param[in] y_2 Value at right node. - * @param[out] unnamed Interpolation result. + * @return Interpolation result. */ template auto linear_interpolation(const X& x_eval, const X& x_1, const X& x_2, const V& y1, const V& y2) @@ -43,12 +46,44 @@ auto linear_interpolation(const X& x_eval, const X& x_1, const X& x_2, const V& return y1 + weight * (y2 - y1); } +/** + * @brief Linear interpolation of a TimeSeries. + * Assumes that the time points are monotonic increasing. If the time series is strictly monotonic, this function is + * continuous in time. + * If time is outside of the provided time points, this function has a constant value of the first/last time point. + * @param[in] time The time at which to evaluate. + * @param[in] data Time points to interpolate. At least one is required. + * @return Interpolation result. + */ +template +typename TimeSeries::Vector linear_interpolation(FP time, const TimeSeries& data) +{ + assert(data.get_num_time_points() > 0 && "Interpolation requires at least one time point."); + auto tp_range = data.get_times(); + // find next time point in data (strictly) after time + const auto next_tp = std::upper_bound(tp_range.begin(), tp_range.end(), time, [](auto&& t, auto&& tp) { + return t < tp; + }); + // interpolate in between values if possible, otherwise return first/last value + if (next_tp == tp_range.begin()) { // time is before first data point + return data.get_value(0); + } + else if (next_tp == tp_range.end()) { // time is past last data point + return data.get_last_value(); + } + else { // time is in between data points + const auto i = next_tp - tp_range.begin(); + return linear_interpolation(time, data.get_time(i - 1), data.get_time(i), data.get_value(i - 1), + data.get_value(i)); + } +} + /** * @brief Linear interpolation between two points of a dataset, which is represented by a vector of pairs of node and value. * Return 0 if there is less than two points in the dataset. * @param[in] vector Vector of pairs of node and value. * @param[in] x_eval Location to evaluate interpolation. - * @param[out] unnamed Interpolation result. + * @return Interpolation result. */ template Y linear_interpolation_of_data_set(std::vector> vector, const X& x_eval) @@ -80,4 +115,4 @@ Y linear_interpolation_of_data_set(std::vector> vector, const X& } // namespace mio -#endif +#endif // MIO_MATH_INTERPOLATION_H_ diff --git a/cpp/memilio/math/time_series_functor.h b/cpp/memilio/math/time_series_functor.h index a2f99110e6..96f9792291 100644 --- a/cpp/memilio/math/time_series_functor.h +++ b/cpp/memilio/math/time_series_functor.h @@ -24,53 +24,49 @@ #include "memilio/math/interpolation.h" #include "memilio/utils/time_series.h" -#include #include #include namespace mio { +/** + * @brief Type of a TimeSeriesFunctor. + * The available types are: + * - LinearInterpolation: + * - Requires at least one time point with exactly one value each. Time must be strictly monotic increasing. + * - Linearly interpolates between data points. Stays constant outside of provided data with first/last value. + */ +enum TimeSeriesFunctorType +{ + LinearInterpolation, +}; + template class TimeSeriesFunctor { public: - /** - * @brief Type of a TimeSeriesFunctor. - * The available types are: - * - Zero: - * - No data used. - * - Always returns 0. - * - LinearInterpolation: - * - Requires at least one time point with exactly one value each. Time must be strictly monotic increasing. - * - Linearly interpolates between data points. Stays constant with first/last value outside of provided data. - */ - enum Type - { - Zero, - LinearInterpolation, - }; - /** * @brief Creates a functor using the given data. * Note the data requirements of the given functor type. * @param type The type of the functor. * @param table A list of time points, passed to the TimeSeries constructor. */ - TimeSeriesFunctor(Type type, const TimeSeries& data) + TimeSeriesFunctor(TimeSeriesFunctorType type, const TimeSeries& data) : m_type(type) , m_data(data) { // data shape checks and preprocessing switch (m_type) { - case Type::Zero: - // no checks needed - break; - case Type::LinearInterpolation: + case TimeSeriesFunctorType::LinearInterpolation: // make sure data has the correct shape, i.e. a list of (time, value) pairs assert(m_data.get_num_time_points() > 0 && "Need at least one time point for LinearInterpolation."); assert(m_data.get_num_elements() == 1 && "LinearInterpolation requires exactly one value per time point."); - assert(m_data.is_sorted()); + assert(m_data.is_strictly_monotonic()); + break; + default: + assert(false && "Unhandled TimeSeriesFunctorType!"); + break; } } @@ -80,7 +76,7 @@ class TimeSeriesFunctor * @param type The type of the functor. * @param table A list of time points, passed to the TimeSeries constructor. */ - TimeSeriesFunctor(Type type, std::vector>&& table) + TimeSeriesFunctor(TimeSeriesFunctorType type, std::vector>&& table) : TimeSeriesFunctor(type, TimeSeries{table}) { } @@ -89,7 +85,7 @@ class TimeSeriesFunctor * @brief Creates a Zero functor. */ TimeSeriesFunctor() - : TimeSeriesFunctor(Type::Zero, TimeSeries{0}) + : TimeSeriesFunctor(TimeSeriesFunctorType::LinearInterpolation, {{FP(0.0), FP(0.0)}}) { } @@ -100,41 +96,23 @@ class TimeSeriesFunctor */ FP operator()(FP time) const { - FP value = 0.0; switch (m_type) { - case Type::Zero: - // value is explicitly zero-initialized - break; - case Type::LinearInterpolation: - auto tp_range = m_data.get_times(); - // find next time point in m_data (strictly) after time - const auto next_tp = std::upper_bound(tp_range.begin(), tp_range.end(), time, [](auto&& t, auto&& tp) { - return t < tp; - }); - if (next_tp == tp_range.begin()) { // time is before first data point - value = m_data.get_value(0)[0]; - } - else if (next_tp == tp_range.end()) { // time is past last data point - value = m_data.get_last_value()[0]; - } - else { // time is in between data points - const auto i = next_tp - tp_range.begin(); - value = linear_interpolation(time, m_data.get_time(i - 1), m_data.get_time(i), - m_data.get_value(i - 1)[0], m_data.get_value(i)[0]); - } - break; + case TimeSeriesFunctorType::LinearInterpolation: + return linear_interpolation(time, m_data)[0]; + default: + assert(false && "Unhandled TimeSeriesFunctorType!"); + return FP(); } - return value; } /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("TimeSeriesFunctor", NVP("type", m_type), NVP("data", m_data)); + return Members("TimeSeriesFunctor").add("type", m_type).add("data", m_data); } private: - Type m_type; ///< Determines what kind of functor this is, e.g. linear interpolation. + TimeSeriesFunctorType m_type; ///< Determines what kind of functor this is, e.g. linear interpolation. TimeSeries m_data; ///< Data used by the functor to compute its values. Its shape depends on type. }; diff --git a/cpp/memilio/utils/random_number_generator.h b/cpp/memilio/utils/random_number_generator.h index bc1b2bbc77..8b22f64e0f 100644 --- a/cpp/memilio/utils/random_number_generator.h +++ b/cpp/memilio/utils/random_number_generator.h @@ -361,8 +361,7 @@ class RandomNumberGenerator : public RandomNumberGeneratorBase using UniformDistribution = DistributionAdapter>; +template ::ParamType>, + void*> = nullptr> +void serialize_internal(IOContext& io, const UniformDistributionParams& p) +{ + auto obj = io.create_object("UniformDistributionParams"); + obj.add_element("a", p.params.a()); + obj.add_element("b", p.params.b()); +} + +template ::ParamType>, + void*> = nullptr> +IOResult deserialize_internal(IOContext& io, Tag) +{ + auto obj = io.expect_object("UniformDistributionParams"); + auto a = obj.expect_element("a", Tag{}); + auto b = obj.expect_element("b", Tag{}); + return apply( + io, + [](auto&& a_, auto&& b_) { + return UniformDistributionParams{a_, b_}; + }, + a, b); +} + /** * adapted poisson_distribution. * @see DistributionAdapter diff --git a/cpp/memilio/utils/time_series.h b/cpp/memilio/utils/time_series.h index 883b879c7d..9b98ec7c02 100644 --- a/cpp/memilio/utils/time_series.h +++ b/cpp/memilio/utils/time_series.h @@ -119,7 +119,9 @@ class TimeSeries }) && "All table entries must have the same size."); // resize data. note that the table entries contain both time and values - m_data.resize(table.front().size(), table.size()); + m_data.resize(table.front().size(), 0); // set colums first so reserve allocates correctly + reserve(table.size()); // reserve needs to happen before setting the number of rows + m_data.resize(Eigen::NoChange, table.size()); // finalize resize by setting the rows // sort table by time std::sort(table.begin(), table.end(), [](auto&& a, auto&& b) { return a[0] < b[0]; @@ -179,7 +181,7 @@ class TimeSeries TimeSeries& operator=(TimeSeries&& other) = default; /// Check if the time is strictly monotonic increasing. - bool is_sorted() + bool is_strictly_monotonic() const { const auto times = get_times(); auto time_itr = times.begin(); diff --git a/cpp/models/abm/infection.h b/cpp/models/abm/infection.h index 4c2ed0499b..8d647159cd 100644 --- a/cpp/models/abm/infection.h +++ b/cpp/models/abm/infection.h @@ -49,8 +49,12 @@ struct ViralLoad { /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("ViralLoad", NVP("start_date", start_date), NVP("end_date", end_date), - NVP("peak", peak), NVP("incline", incline), NVP("decline", decline)); + return Members("ViralLoad") + .add("start_date", start_date) + .add("end_date", end_date) + .add("peak", peak) + .add("incline", incline) + .add("decline", decline); } }; @@ -124,14 +128,17 @@ class Infection /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("Infection", NVP("infection_course", m_infection_course), - NVP("virus_variant", m_virus_variant), NVP("viral_load", m_viral_load), - NVP("log_norm_alpha", m_log_norm_alpha), NVP("log_norm_beta", m_log_norm_beta), - NVP("detected", m_detected)); + return Members("Infection") + .add("infection_course", m_infection_course) + .add("virus_variant", m_virus_variant) + .add("viral_load", m_viral_load) + .add("log_norm_alpha", m_log_norm_alpha) + .add("log_norm_beta", m_log_norm_beta) + .add("detected", m_detected); } private: - friend AutoSerializableFactory; + friend DefaultFactory; Infection() = default; /** diff --git a/cpp/models/abm/location.h b/cpp/models/abm/location.h index 0c6936a526..acfa4ac5ef 100644 --- a/cpp/models/abm/location.h +++ b/cpp/models/abm/location.h @@ -53,7 +53,7 @@ struct GeographicalLocation { /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("GraphicalLocation", NVP("latitude", latitude), NVP("longitude", longitude)); + return Members("GraphicalLocation").add("latitude", latitude).add("longitude", longitude); } }; @@ -84,7 +84,7 @@ struct CellCapacity { /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("CellCapacity", NVP("volume", volume), NVP("persons", persons)); + return Members("CellCapacity").add("volume", volume).add("persons", persons); } }; @@ -104,7 +104,7 @@ struct Cell { /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("Cell", NVP("capacity", m_capacity)); + return Members("Cell").add("capacity", m_capacity); } }; // namespace mio @@ -272,14 +272,17 @@ class Location /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("Location", NVP("id", m_id), NVP("parameters", m_parameters), - NVP("cells", m_cells), NVP("required_mask", m_required_mask), - NVP("npi_active", m_npi_active), - NVP("geographical_location", m_geographical_location)); + return Members("Location") + .add("id", m_id) + .add("parameters", m_parameters) + .add("cells", m_cells) + .add("required_mask", m_required_mask) + .add("npi_active", m_npi_active) + .add("geographical_location", m_geographical_location); } private: - friend AutoSerializableFactory; + friend DefaultFactory; Location() = default; LocationType m_type; ///< Type of the Location. diff --git a/cpp/models/abm/mask.h b/cpp/models/abm/mask.h index f3697d7425..223e56ce6a 100644 --- a/cpp/models/abm/mask.h +++ b/cpp/models/abm/mask.h @@ -76,7 +76,7 @@ class Mask /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("Mask", NVP("mask_type", m_type), NVP("time_used", m_time_used)); + return Members("Mask").add("mask_type", m_type).add("time_used", m_time_used); } private: @@ -87,7 +87,7 @@ class Mask /// @brief Creates an instance of abm::Mask for auto-deserialization. template <> -struct AutoSerializableFactory { +struct DefaultFactory { static abm::Mask create() { return abm::Mask(abm::MaskType::Count); diff --git a/cpp/models/abm/parameters.h b/cpp/models/abm/parameters.h index 1748fc6746..19d03105ff 100644 --- a/cpp/models/abm/parameters.h +++ b/cpp/models/abm/parameters.h @@ -38,35 +38,10 @@ #include #include -#include +#include namespace mio { - -template ::ParamType, T>, void*> = nullptr> -void serialize_internal(IOContext& io, const T& p) -{ - auto obj = io.create_object("UniformDistributionParams"); - obj.add_element("a", p.params.a()); - obj.add_element("b", p.params.b()); -} - -template ::ParamType, T>, void*> = nullptr> -IOResult::ParamType> deserialize_internal(IOContext& io, Tag) -{ - auto obj = io.expect_object("UniformDistributionParams"); - auto a = obj.expect_element("a", Tag{}); - auto b = obj.expect_element("b", Tag{}); - return apply( - io, - [](auto&& a_, auto&& b_) { - return UniformDistribution::ParamType{a_, b_}; - }, - a, b); -} - namespace abm { @@ -205,9 +180,10 @@ struct ViralLoadDistributionsParameters { /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("ViralLoadDistributionsParameters", NVP("viral_load_peak", viral_load_peak), - NVP("viral_load_incline", viral_load_incline), - NVP("viral_load_decline", viral_load_decline)); + return Members("ViralLoadDistributionsParameters") + .add("viral_load_peak", viral_load_peak) + .add("viral_load_incline", viral_load_incline) + .add("viral_load_decline", viral_load_decline); } }; @@ -236,9 +212,9 @@ struct InfectivityDistributionsParameters { /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("InfectivityDistributionsParameters", - NVP("infectivity_alpha", infectivity_alpha), - NVP("infectivity_beta", infectivity_beta)); + return Members("InfectivityDistributionsParameters") + .add("infectivity_alpha", infectivity_alpha) + .add("infectivity_beta", infectivity_beta); } }; @@ -359,9 +335,11 @@ struct TestParameters { /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("TestParameters", NVP("sensitivity", sensitivity), - NVP("specificity", specificity), NVP("required_time", required_time), - NVP("test_type", type)); + return Members("TestParameters") + .add("sensitivity", sensitivity) + .add("specificity", specificity) + .add("required_time", required_time) + .add("test_type", type); } }; diff --git a/cpp/models/abm/person.h b/cpp/models/abm/person.h index 57589d302c..342f13cccb 100755 --- a/cpp/models/abm/person.h +++ b/cpp/models/abm/person.h @@ -391,31 +391,27 @@ class Person /// This method is used by the auto-serialization feature. auto auto_serialize() { - // clang-format off - return make_auto_serialization( - "Person", - NVP("location", m_location), - NVP("location_type", m_location_type), - NVP("assigned_locations", m_assigned_locations), - NVP("vaccinations", m_vaccinations), - NVP("infections", m_infections), - NVP("quarantine_start",m_quarantine_start), - NVP("age_group", m_age), - NVP("time_at_location", m_time_at_location), - NVP("rnd_workgroup", m_random_workgroup), - NVP("rnd_schoolgroup", m_random_schoolgroup), - NVP("rnd_go_to_work_hour", m_random_goto_work_hour), - NVP("rnd_go_to_school_hour", m_random_goto_school_hour), - NVP("mask", m_mask), - NVP("wears_mask", m_wears_mask), - NVP("mask_compliance", m_mask_compliance), - NVP("id", m_person_id), - NVP("cells", m_cells), - NVP("last_transport_mode", m_last_transport_mode), - NVP("rng_counter", m_rng_counter), - NVP("test_results", m_test_results) - ); - // clang-format on + return Members("Person") + .add("location", m_location) + .add("location_type", m_location_type) + .add("assigned_locations", m_assigned_locations) + .add("vaccinations", m_vaccinations) + .add("infections", m_infections) + .add("quarantine_start", m_quarantine_start) + .add("age_group", m_age) + .add("time_at_location", m_time_at_location) + .add("rnd_workgroup", m_random_workgroup) + .add("rnd_schoolgroup", m_random_schoolgroup) + .add("rnd_go_to_work_hour", m_random_goto_work_hour) + .add("rnd_go_to_school_hour", m_random_goto_school_hour) + .add("mask", m_mask) + .add("wears_mask", m_wears_mask) + .add("mask_compliance", m_mask_compliance) + .add("id", m_person_id) + .add("cells", m_cells) + .add("last_transport_mode", m_last_transport_mode) + .add("rng_counter", m_rng_counter) + .add("test_results", m_test_results); } /** @@ -462,7 +458,7 @@ class Person /// @brief Creates an instance of abm::Person for auto-deserialization. template <> -struct AutoSerializableFactory { +struct DefaultFactory { static abm::Person create() { return abm::Person(thread_local_rng(), abm::LocationType::Count, abm::LocationId(), AgeGroup(0), diff --git a/cpp/models/abm/test_type.h b/cpp/models/abm/test_type.h index d9073d443e..823f6f26f9 100644 --- a/cpp/models/abm/test_type.h +++ b/cpp/models/abm/test_type.h @@ -52,7 +52,7 @@ struct TestResult { /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("TestResult", NVP("time_of_testing", time_of_testing), NVP("result", result)); + return Members("TestResult").add("time_of_testing", time_of_testing).add("result", result); } }; diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index 184c37fb4b..0d58222b3a 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -94,37 +94,9 @@ class TestingCriteria */ bool evaluate(const Person& p, TimePoint t) const; - /** - * serialize this. - * @see mio::serialize - */ - template - void serialize(IOContext& io) const - { - auto obj = io.create_object("TestingCriteria"); - obj.add_element("ages", m_ages.to_ulong()); - obj.add_element("infection_states", m_infection_states.to_ulong()); - } - - /** - * deserialize an object of this class. - * @see mio::deserialize - */ - template - static IOResult deserialize(IOContext& io) + auto auto_serialize() { - auto obj = io.expect_object("TestingCriteria"); - auto ages = obj.expect_element("ages", Tag{}); - auto infection_states = obj.expect_element("infection_states", Tag{}); - return apply( - io, - [](auto&& ages_, auto&& infection_states_) { - TestingCriteria c; - c.m_ages = ages_; - c.m_infection_states = infection_states_; - return c; - }, - ages, infection_states); + return Members("TestingCriteria").add("ages", m_ages).add("infection_states", m_infection_states); } private: @@ -181,14 +153,18 @@ class TestingScheme /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("TestingScheme", NVP("criteria", m_testing_criteria), - NVP("validity_period", m_validity_period), NVP("start_date", m_start_date), - NVP("end_date", m_end_date), NVP("test_params", m_test_parameters), - NVP("probability", m_probability), NVP("is_active", m_is_active)); + return Members("TestingScheme") + .add("criteria", m_testing_criteria) + .add("validity_period", m_validity_period) + .add("start_date", m_start_date) + .add("end_date", m_end_date) + .add("test_params", m_test_parameters) + .add("probability", m_probability) + .add("is_active", m_is_active); } private: - friend AutoSerializableFactory; + friend DefaultFactory; TestingScheme() = default; TestingCriteria m_testing_criteria; ///< TestingCriteria of the scheme. @@ -218,7 +194,7 @@ class TestingStrategy /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("LocalStrategy", NVP("type", type), NVP("id", id), NVP("schemes", schemes)); + return Members("LocalStrategy").add("type", type).add("id", id).add("schemes", schemes); } }; @@ -288,7 +264,7 @@ class TestingStrategy /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("TestingStrategy", NVP("schemes", m_location_to_schemes_map)); + return Members("TestingStrategy").add("schemes", m_location_to_schemes_map); } private: diff --git a/cpp/models/abm/time.h b/cpp/models/abm/time.h index 1f3bb2ca62..917b80ecca 100644 --- a/cpp/models/abm/time.h +++ b/cpp/models/abm/time.h @@ -148,7 +148,7 @@ class TimeSpan /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("TimeSpan", NVP("seconds", m_seconds)); + return Members("TimeSpan").add("seconds", m_seconds); } private: @@ -295,7 +295,7 @@ class TimePoint /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("TimePoint", NVP("seconds", m_seconds)); + return Members("TimePoint").add("seconds", m_seconds); } private: diff --git a/cpp/models/abm/trip_list.h b/cpp/models/abm/trip_list.h index ecb01d4f48..d3c9f82221 100644 --- a/cpp/models/abm/trip_list.h +++ b/cpp/models/abm/trip_list.h @@ -24,6 +24,7 @@ #include "abm/mobility_data.h" #include "abm/person_id.h" #include "abm/time.h" +#include "memilio/io/auto_serialize.h" #include namespace mio @@ -89,38 +90,13 @@ struct Trip { (origin == other.origin); } - /** - * serialize this. - * @see mio::serialize - */ - template - void serialize(IOContext& io) const - { - auto obj = io.create_object("Trip"); - obj.add_element("person_id", person_id); - obj.add_element("time", time.seconds()); - obj.add_element("destination", destination); - obj.add_element("origin", origin); - } - - /** - * deserialize an object of this class. - * @see mio::deserialize - */ - template - static IOResult deserialize(IOContext& io) + auto auto_serialize() { - auto obj = io.expect_object("Trip"); - auto person_id = obj.expect_element("person_id", Tag{}); - auto time = obj.expect_element("time", Tag{}); - auto destination_id = obj.expect_element("destination", Tag{}); - auto origin_id = obj.expect_element("origin", Tag{}); - return apply( - io, - [](auto&& person_id_, auto&& time_, auto&& destination_id_, auto&& origin_id_) { - return Trip(person_id_, TimePoint(time_), destination_id_, origin_id_); - }, - person_id, time, destination_id, origin_id); + return Members("Trip") + .add("person_id", person_id) + .add("time", time) + .add("destination", destination) + .add("origin", origin); } }; @@ -195,8 +171,10 @@ class TripList /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("TestingScheme", NVP("trips_weekday", m_trips_weekday), - NVP("trips_weekend", m_trips_weekend), NVP("index", m_current_index)); + return Members("TestingScheme") + .add("trips_weekday", m_trips_weekday) + .add("trips_weekend", m_trips_weekend) + .add("index", m_current_index); } private: @@ -206,6 +184,15 @@ class TripList }; } // namespace abm + +template <> +struct DefaultFactory { + static abm::Trip create() + { + return abm::Trip{abm::PersonId{}, abm::TimePoint{}, abm::LocationId{}}; + } +}; + } // namespace mio #endif diff --git a/cpp/models/abm/vaccine.h b/cpp/models/abm/vaccine.h index 88638a1c57..5f77692b7b 100644 --- a/cpp/models/abm/vaccine.h +++ b/cpp/models/abm/vaccine.h @@ -56,7 +56,7 @@ struct Vaccination { /// This method is used by the auto-serialization feature. auto auto_serialize() { - return make_auto_serialization("Vaccination", NVP("exposure_type", exposure_type), NVP("time", time)); + return Members("Vaccination").add("exposure_type", exposure_type).add("time", time); } ExposureType exposure_type; @@ -67,7 +67,7 @@ struct Vaccination { /// @brief Creates an instance of abm::Vaccination for auto-deserialization. template <> -struct AutoSerializableFactory { +struct DefaultFactory { static abm::Vaccination create() { return abm::Vaccination(abm::ExposureType::Count, abm::TimePoint()); diff --git a/cpp/simulations/abm.cpp b/cpp/simulations/abm.cpp index 1dd0adef18..cc3d1cef5c 100644 --- a/cpp/simulations/abm.cpp +++ b/cpp/simulations/abm.cpp @@ -23,7 +23,6 @@ #include "abm/lockdown_rules.h" #include "memilio/config.h" #include "memilio/io/result_io.h" -#include "memilio/math/interpolation.h" #include "memilio/utils/random_number_generator.h" #include "memilio/utils/uncertain_value.h" @@ -471,7 +470,7 @@ void set_parameters(mio::abm::Parameters params) // Set protection level from high viral load. Information based on: https://doi.org/10.1093/cid/ciaa886 params.get() = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.863}, {1, 0.969}, {7, 0.029}, {10, 0.002}, {14, 0.0014}, {21, 0}}}; //0-4 diff --git a/cpp/simulations/abm_braunschweig.cpp b/cpp/simulations/abm_braunschweig.cpp index a39def895f..5aa486bf2f 100644 --- a/cpp/simulations/abm_braunschweig.cpp +++ b/cpp/simulations/abm_braunschweig.cpp @@ -401,7 +401,7 @@ void set_parameters(mio::abm::Parameters params) // Set protection level from high viral load. Information based on: https://doi.org/10.1093/cid/ciaa886 params.get() = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.863}, {1, 0.969}, {7, 0.029}, {10, 0.002}, {14, 0.0014}, {21, 0}}}; //0-4 @@ -493,7 +493,7 @@ void set_parameters(mio::abm::Parameters params) // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_0_to_4, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.852}, {180, 0.852}, {210, 0.845}, @@ -510,14 +510,14 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_0_to_4, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.91}, {60, 0.92}, {90, 0.88}, {120, 0.84}, {150, 0.81}, {180, 0.88}, {450, 0.5}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_0_to_4, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.967}, {30, 0.975}, {60, 0.977}, @@ -535,7 +535,7 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_0_to_4, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}}; //5-14 @@ -553,7 +553,7 @@ void set_parameters(mio::abm::Parameters params) // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_5_to_14, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.852}, {180, 0.852}, {210, 0.845}, @@ -569,14 +569,14 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_5_to_14, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.91}, {60, 0.92}, {90, 0.88}, {120, 0.84}, {150, 0.81}, {180, 0.88}, {450, 0.5}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_5_to_14, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.967}, {30, 0.975}, {60, 0.977}, @@ -594,7 +594,7 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_5_to_14, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}}; //15-34 @@ -612,7 +612,7 @@ void set_parameters(mio::abm::Parameters params) // Set up personal infection and vaccine protection levels, based on: https://doi.org/10.1038/s41577-021-00550-x, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.852}, {180, 0.852}, {210, 0.845}, @@ -628,13 +628,13 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.89}, {60, 0.84}, {90, 0.78}, {120, 0.68}, {150, 0.57}, {180, 0.39}, {450, 0.1}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.967}, {30, 0.975}, {60, 0.977}, @@ -652,7 +652,7 @@ void set_parameters(mio::abm::Parameters params) // Information is from: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}}; //35-59 @@ -671,7 +671,7 @@ void set_parameters(mio::abm::Parameters params) // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_35_to_59, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.852}, {180, 0.852}, {210, 0.845}, @@ -687,13 +687,13 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_35_to_59, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.89}, {60, 0.84}, {90, 0.78}, {120, 0.68}, {150, 0.57}, {180, 0.39}, {450, 0.1}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_35_to_59, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.967}, {30, 0.975}, {60, 0.977}, @@ -711,7 +711,7 @@ void set_parameters(mio::abm::Parameters params) // Information is from: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_35_to_59, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}}; //60-79 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = @@ -729,7 +729,7 @@ void set_parameters(mio::abm::Parameters params) // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_60_to_79, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.852}, {180, 0.852}, {210, 0.845}, @@ -745,14 +745,14 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_60_to_79, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.87}, {60, 0.85}, {90, 0.78}, {120, 0.67}, {150, 0.61}, {180, 0.50}, {450, 0.1}}}; // Set up personal severe protection levels. // Protection of severe infection of age group 65 + is different from other age group, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_60_to_79, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.967}, {30, 0.975}, {60, 0.977}, @@ -769,7 +769,7 @@ void set_parameters(mio::abm::Parameters params) {360, 0.5}}}; params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_60_to_79, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.91}, {60, 0.86}, {90, 0.91}, {120, 0.94}, {150, 0.95}, {180, 0.90}, {450, 0.5}}}; //80+ @@ -787,7 +787,7 @@ void set_parameters(mio::abm::Parameters params) // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_80_plus, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.852}, {180, 0.852}, {210, 0.845}, @@ -803,14 +803,14 @@ void set_parameters(mio::abm::Parameters params) // Information is from: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_80_plus, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.80}, {60, 0.79}, {90, 0.75}, {120, 0.56}, {150, 0.49}, {180, 0.43}, {450, 0.1}}}; // Set up personal severe protection levels. // Protection of severe infection of age group 65 + is different from other age group, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_0_to_4, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.967}, {30, 0.975}, {60, 0.977}, @@ -828,7 +828,7 @@ void set_parameters(mio::abm::Parameters params) // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_80_plus, mio::abm::VirusVariant::Wildtype}] = { - mio::TimeSeriesFunctor::Type::LinearInterpolation, + mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.84}, {60, 0.88}, {90, 0.89}, {120, 0.86}, {150, 0.85}, {180, 0.83}, {450, 0.5}}}; } diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 905758783d..4a7be57cb2 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -61,20 +61,22 @@ set(TESTSOURCES test_metaprogramming.cpp test_history.cpp test_utils.cpp - abm_helpers.h - abm_helpers.cpp test_ide_seir.cpp test_ide_secir.cpp test_state_age_function.cpp + test_lct_secir.cpp + test_lct_initializer_flows.cpp + test_ad.cpp + abm_helpers.h + abm_helpers.cpp + actions.h distributions_helpers.h distributions_helpers.cpp - actions.h + matchers.cpp matchers.h - temp_file_register.h + random_number_test.h sanitizers.cpp - test_lct_secir.cpp - test_lct_initializer_flows.cpp - test_ad.cpp + temp_file_register.h ) if(MEMILIO_HAS_JSONCPP) diff --git a/cpp/tests/matchers.cpp b/cpp/tests/matchers.cpp new file mode 100644 index 0000000000..9c222af0f4 --- /dev/null +++ b/cpp/tests/matchers.cpp @@ -0,0 +1,64 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: René Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#include "matchers.h" + +#ifdef MEMILIO_HAS_JSONCPP + +void Json::PrintTo(const Value& json, std::ostream* os) +{ + if (json.isObject()) { + // move opening bracket into its own line + *os << "\n"; + } + const static auto js_writer = [] { + StreamWriterBuilder swb; + swb["indentation"] = " "; + return std::unique_ptr(swb.newStreamWriter()); + }(); + js_writer->write(json, os); +} + +std::string json_type_to_string(Json::ValueType t) +{ + using namespace Json; + switch (t) { + case nullValue: + return "Null"; + case intValue: + return "Int"; + case uintValue: + return "UInt"; + case realValue: + return "Real"; + case stringValue: + return "String"; + case booleanValue: + return "Boolean"; + case arrayValue: + return "Array"; + case objectValue: + return "Object"; + default: + assert(false && "Unreachable"); + return ""; + } +} + +#endif // MEMILIO_HAS_JSONCPP diff --git a/cpp/tests/matchers.h b/cpp/tests/matchers.h index 3d89323cfe..4953625f75 100644 --- a/cpp/tests/matchers.h +++ b/cpp/tests/matchers.h @@ -25,6 +25,63 @@ #include "memilio/io/io.h" #include "gmock/gmock.h" +#ifdef MEMILIO_HAS_JSONCPP + +#include "json/json.h" + +namespace Json +{ +void PrintTo(const Value& json, std::ostream* os); +} // namespace Json + +std::string json_type_to_string(Json::ValueType t); + +MATCHER_P(JsonEqual, expected_json, testing::PrintToString(expected_json)) +{ + auto match_rec = [&](auto&& match, const Json::Value& a, const Json::Value& b, std::string name) { + // first check if the types match + if (a.type() != b.type()) { + *result_listener << "type mismatch for " << name << ", expected " << json_type_to_string(a.type()) + << ", actual " << json_type_to_string(b.type()); + return false; + } + // handle object types by recursively matching members + if (a.isObject()) { + for (auto& key : a.getMemberNames()) { + if (!b.isMember(key)) { + *result_listener << "missing key \"" << key << "\" in " << name; + return false; + } + if (!match(match, a[key], b[key], name + "[\"" + key + "\"]")) { + return false; + } + } + } + // handle arrays by recursively matching each item + else if (a.isArray()) { + if (a.size() != b.size()) { + *result_listener << "wrong number of items in " << name << ", expected " << a.size() << ", actual " + << b.size(); + return false; + } + for (Json::ArrayIndex i = 0; i < a.size(); ++i) { + if (!match(match, a[i], b[i], name + "[\"" + std::to_string(i) + "\"]")) { + return false; + } + } + } + // handle value types using Json::Value::operator== + else if (a != b) { + *result_listener << "value mismatch in " << name << ", expected " << testing::PrintToString(a) + << ", actual " << testing::PrintToString(b); + return false; + } + return true; + }; + return match_rec(match_rec, expected_json, arg, "Json::Value"); +} +#endif //MEMILIO_HAS_JSONCPP + /** * @brief overload gtest printer function for eigen matrices. * @note see https://stackoverflow.com/questions/25146997/teach-google-test-how-to-print-eigen-matrix diff --git a/cpp/tests/random_number_test.h b/cpp/tests/random_number_test.h new file mode 100644 index 0000000000..f95307deea --- /dev/null +++ b/cpp/tests/random_number_test.h @@ -0,0 +1,53 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: René Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#include "memilio/utils/random_number_generator.h" + +#include + +class RandomNumberTest : public ::testing::Test +{ +public: + /** + * @brief Draws a uniformly distributed random number. + * @tparam FP A floating point type, defaults to double. + * @param[in] min, max Lower and upper bound to the uniform distribution. + * @return A random value between min and max. + */ + template + double random_number(FP min = FP{-1e+3}, FP max = {1e+3}) + { + return mio::UniformDistribution::get_instance()(m_rng, min, max); + } + + /// @brief Access the random number generator. Should only be used for debugging. + mio::RandomNumberGenerator get_rng() + { + return m_rng; + } + +protected: + void SetUp() override + { + log_rng_seeds(m_rng, mio::LogLevel::warn); + } + +private: + mio::RandomNumberGenerator m_rng{}; ///< Seeded rng used by this test fixture. +}; diff --git a/cpp/tests/test_abm_infection.cpp b/cpp/tests/test_abm_infection.cpp index e0a637a5da..69792b1ce3 100644 --- a/cpp/tests/test_abm_infection.cpp +++ b/cpp/tests/test_abm_infection.cpp @@ -78,10 +78,10 @@ TEST(TestInfection, init) EXPECT_NEAR(infection.get_infectivity(mio::abm::TimePoint(0) + mio::abm::days(3)), 0.2689414213699951, 1e-14); params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_test, - virus_variant_test}] = mio::TimeSeriesFunctor{ - mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.91}, {30, 0.81}}}; - params.get() = mio::TimeSeriesFunctor{ - mio::TimeSeriesFunctor::Type::LinearInterpolation, {{0, 0.91}, {30, 0.81}}}; + virus_variant_test}] = + mio::TimeSeriesFunctor{mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.91}, {30, 0.81}}}; + params.get() = + mio::TimeSeriesFunctor{mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.91}, {30, 0.81}}}; auto infection_w_previous_exp = mio::abm::Infection(rng, mio::abm::VirusVariant::Wildtype, age_group_test, params, mio::abm::TimePoint(0), mio::abm::InfectionState::InfectedSymptoms, @@ -191,20 +191,18 @@ TEST(TestInfection, getPersonalProtectiveFactor) // Test linear interpolation with one node params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = - mio::TimeSeriesFunctor{mio::TimeSeriesFunctor::Type::LinearInterpolation, {{2, 0.91}}}; + mio::TimeSeriesFunctor{mio::TimeSeriesFunctorType::LinearInterpolation, {{2, 0.91}}}; auto t = mio::abm::TimePoint(6 * 24 * 60 * 60); // TODO: Discuss: Assumption of interpolation in TDPF is that the function is constant with value at front/back entry outside of [front, back] time range. This works with one node as well and prints no errors EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.91, eps); params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = - mio::TimeSeriesFunctor{mio::TimeSeriesFunctor::Type::LinearInterpolation, - {{2, 0.91}, {30, 0.81}}}; + mio::TimeSeriesFunctor{mio::TimeSeriesFunctorType::LinearInterpolation, {{2, 0.91}, {30, 0.81}}}; params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = - mio::TimeSeriesFunctor{mio::TimeSeriesFunctor::Type::LinearInterpolation, - {{2, 0.91}, {30, 0.81}}}; - params.get() = mio::TimeSeriesFunctor{ - mio::TimeSeriesFunctor::Type::LinearInterpolation, {{2, 0.91}, {30, 0.81}}}; + mio::TimeSeriesFunctor{mio::TimeSeriesFunctorType::LinearInterpolation, {{2, 0.91}, {30, 0.81}}}; + params.get() = + mio::TimeSeriesFunctor{mio::TimeSeriesFunctorType::LinearInterpolation, {{2, 0.91}, {30, 0.81}}}; // Test Parameter InfectionProtectionFactor and get_protection_factor() t = mio::abm::TimePoint(0) + mio::abm::days(2); diff --git a/cpp/tests/test_abm_serialization.cpp b/cpp/tests/test_abm_serialization.cpp index 75c821c83f..90d4ad4482 100644 --- a/cpp/tests/test_abm_serialization.cpp +++ b/cpp/tests/test_abm_serialization.cpp @@ -1,9 +1,29 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: René Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#include "matchers.h" +#include "abm/config.h" #include "abm/infection_state.h" #include "abm/parameters.h" #include "abm/test_type.h" #include "abm/testing_strategy.h" #include "abm/vaccine.h" -#include "matchers.h" #include "memilio/epidemiology/age_group.h" #include "memilio/io/json_serializer.h" #include "memilio/utils/custom_index_array.h" @@ -13,28 +33,18 @@ #include "models/abm/time.h" #include "models/abm/trip_list.h" #include "models/abm/model.h" -#include "json/config.h" -#include "json/value.h" #ifdef MEMILIO_HAS_JSONCPP -void test_equal_json_representation(const Json::Value& test_json, const Json::Value& reference_json) -{ - // write the resulting json value and the reference value to string to compare their representations. - Json::StreamWriterBuilder swb; - swb["indentation"] = " "; - auto js_writer = std::unique_ptr(swb.newStreamWriter()); - std::stringstream test_str, reference_str; - js_writer->write(reference_json, &reference_str); - js_writer->write(test_json, &test_str); - // we compare strings here, as e.g. Json::Int(5) != Json::Uint(5), but their json representation is the same - EXPECT_EQ(test_str.str(), reference_str.str()); -} +#include "json/value.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" /** - * @brief Test de- and serialization of an object by comparing its json representation. + * @brief Test de- and serialization of an object without equality operator. * - * Test that a json value x representing type T is equal to serialize(deserialize(x)) w.r.t json representation. + * Test that a json value x representing type T is equal to serialize(deserialize(x)). * * Assuming the (de)serialization functions' general behavior is independent of specific values of member variables, * i.e. the function does not contain conditionals (`if (t > 0)`), optionals (`add_optional`/`expect_optional`), etc., @@ -46,7 +56,7 @@ void test_equal_json_representation(const Json::Value& test_json, const Json::Va * @param reference_json A json value representing an instance of T. */ template -void test_json_serialization_by_representation(const Json::Value& reference_json) +void test_json_serialization_without_equality(const Json::Value& reference_json) { // check that the json is deserializable (i.e. a valid representation) auto t_result = mio::deserialize_json(reference_json, mio::Tag()); @@ -56,16 +66,14 @@ void test_json_serialization_by_representation(const Json::Value& reference_json auto json_result = mio::serialize_json(t_result.value()); ASSERT_THAT(print_wrap(json_result), IsSuccess()); - test_equal_json_representation(json_result.value(), reference_json); + EXPECT_THAT(json_result.value(), JsonEqual(reference_json)); } /** - * @brief Test de- and serialization of an object by comparing its json representation and using its equality operator. + * @brief Test de- and serialization of an object using its equality operator. * - * First, test that serializing the reference_object is equal to the reference_json (w.r.t. their representation), - * and that deserializing the reference_json results in an object equal to the reference_object. - * Then, repeat this step using its own results as arguments to (de)serialize, to check that serialization and - * deserialization are inverse functions to each other. + * First, test that serializing the reference_object is equal to the reference_json. + * Then, test that deserializing the reference_json results in an object equal to the reference_object. * * @tparam T The type to test. * @param reference_object An instance of T. @@ -77,68 +85,54 @@ void test_json_serialization_full(const T& reference_object, const Json::Value& // check that the reference type T is serializable auto json_result = mio::serialize_json(reference_object); ASSERT_THAT(print_wrap(json_result), IsSuccess()); + EXPECT_THAT(json_result.value(), JsonEqual(reference_json)); // check that the reference json is deserializable auto t_result = mio::deserialize_json(reference_json, mio::Tag()); ASSERT_THAT(print_wrap(t_result), IsSuccess()); - - // compare both results with other reference values EXPECT_EQ(t_result.value(), reference_object); - test_equal_json_representation(json_result.value(), reference_json); - - // do the same once more using the results from above - auto json_result_2 = mio::serialize_json(t_result.value()); - ASSERT_THAT(print_wrap(json_result_2), IsSuccess()); - auto t_result_2 = mio::deserialize_json(json_result.value(), mio::Tag()); - ASSERT_THAT(print_wrap(t_result_2), IsSuccess()); - - EXPECT_EQ(t_result_2.value(), reference_object); - test_equal_json_representation(json_result_2.value(), reference_json); } TEST(TestAbmSerialization, Trip) { - // Test (de)serialization w.r.t json representation and the types own equality operator. - // See test_json_serialization_full for more detail. + // See test_json_serialization_full for info on this test. - mio::abm::Trip trip(1, mio::abm::TimePoint(0) + mio::abm::hours(2), 3, 4); + mio::abm::Trip trip(1, mio::abm::TimePoint(2), 3, 4); - Json::Value reference_json; // aka x - reference_json["person_id"] = Json::UInt(1); - reference_json["time"] = Json::Int(mio::abm::hours(2).seconds()); - reference_json["destination"] = Json::UInt(3); - reference_json["origin"] = Json::UInt(4); + Json::Value reference_json; + reference_json["person_id"] = Json::UInt(1); + reference_json["time"]["seconds"] = Json::Int(2); + reference_json["destination"] = Json::UInt(3); + reference_json["origin"] = Json::UInt(4); test_json_serialization_full(trip, reference_json); } TEST(TestAbmSerialization, Vaccination) { - // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. - // See test_json_serialization_by_representation for more detail. + // See test_json_serialization_without_equality for info on this test. - Json::Value reference_json; // aka x - reference_json["exposure_type"] = Json::Int(1); - reference_json["time"]["seconds"] = Json::UInt(2); + Json::Value reference_json; + reference_json["exposure_type"] = Json::UInt(1); + reference_json["time"]["seconds"] = Json::Int(2); - test_json_serialization_by_representation(reference_json); + test_json_serialization_without_equality(reference_json); } TEST(TestAbmSerialization, Infection) { - // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. - // See test_json_serialization_by_representation for more detail. + // See test_json_serialization_without_equality for info on this test. unsigned i = 1; // counter s.t. members have different values Json::Value viral_load; viral_load["decline"] = Json::Value((double)i++); - viral_load["end_date"]["seconds"] = Json::UInt(i++); + viral_load["end_date"]["seconds"] = Json::Int(i++); viral_load["incline"] = Json::Value((double)i++); viral_load["peak"] = Json::Value((double)i++); - viral_load["start_date"]["seconds"] = Json::UInt(i++); + viral_load["start_date"]["seconds"] = Json::Int(i++); - Json::Value reference_json; // aka x + Json::Value reference_json; reference_json["infection_course"] = Json::Value(Json::arrayValue); reference_json["virus_variant"] = Json::UInt(0); reference_json["viral_load"] = viral_load; @@ -146,13 +140,12 @@ TEST(TestAbmSerialization, Infection) reference_json["log_norm_beta"] = Json::Value((double)i++); reference_json["detected"] = Json::Value((bool)0); - test_json_serialization_by_representation(reference_json); + test_json_serialization_without_equality(reference_json); } TEST(TestAbmSerialization, TestingScheme) { - // Test (de)serialization w.r.t json representation and the types own equality operator. - // See test_json_serialization_full for more detail. + // See test_json_serialization_full for info on this test. mio::abm::TestingScheme testing_scheme( mio::abm::TestingCriteria({mio::AgeGroup(1)}, {mio::abm::InfectionState(2)}), mio::abm::TimeSpan(3), @@ -160,20 +153,24 @@ TEST(TestAbmSerialization, TestingScheme) mio::abm::TestParameters{{6.0}, {7.0}, mio::abm::TimeSpan{8}, mio::abm::TestType(0)}, 9.0); Json::Value testing_criteria; - testing_criteria["ages"] = Json::UInt(1 << 1); - testing_criteria["infection_states"] = Json::UInt(1 << 2); + std::vector ages_bits(mio::abm::MAX_NUM_AGE_GROUPS, false); + ages_bits[1] = true; + testing_criteria["ages"]["bitset"] = mio::serialize_json(ages_bits).value(); + std::vector inf_st_bits((size_t)mio::abm::InfectionState::Count, false); + inf_st_bits[2] = true; + testing_criteria["infection_states"]["bitset"] = mio::serialize_json(inf_st_bits).value(); Json::Value test_parameters; test_parameters["sensitivity"] = mio::serialize_json(mio::UncertainValue{6.0}).value(); test_parameters["specificity"] = mio::serialize_json(mio::UncertainValue{7.0}).value(); - test_parameters["required_time"]["seconds"] = Json::UInt(8); + test_parameters["required_time"]["seconds"] = Json::Int(8); test_parameters["test_type"] = Json::UInt(0); - Json::Value reference_json; // aka x + Json::Value reference_json; reference_json["criteria"] = testing_criteria; - reference_json["validity_period"]["seconds"] = Json::UInt(3); - reference_json["start_date"]["seconds"] = Json::UInt(4); - reference_json["end_date"]["seconds"] = Json::UInt(5); + reference_json["validity_period"]["seconds"] = Json::Int(3); + reference_json["start_date"]["seconds"] = Json::Int(4); + reference_json["end_date"]["seconds"] = Json::Int(5); reference_json["test_params"] = test_parameters; reference_json["probability"] = Json::Value((double)9); reference_json["is_active"] = Json::Value((bool)0); @@ -183,8 +180,7 @@ TEST(TestAbmSerialization, TestingScheme) TEST(TestAbmSerialization, TestingStrategy) { - // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. - // See test_json_serialization_by_representation for more detail. + // See test_json_serialization_without_equality for info on this test. unsigned i = 1; // counter s.t. members have different values @@ -193,28 +189,26 @@ TEST(TestAbmSerialization, TestingStrategy) local_strategy["schemes"] = Json::Value(Json::arrayValue); local_strategy["type"] = Json::UInt(i++); - Json::Value reference_json; // aka x + Json::Value reference_json; reference_json["schemes"][0] = local_strategy; - test_json_serialization_by_representation(reference_json); + test_json_serialization_without_equality(reference_json); } TEST(TestAbmSerialization, TestResult) { - // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. - // See test_json_serialization_by_representation for more detail. + // See test_json_serialization_without_equality for info on this test. - Json::Value reference_json; // aka x + Json::Value reference_json; reference_json["result"] = Json::Value(false); - reference_json["time_of_testing"]["seconds"] = Json::UInt(1); + reference_json["time_of_testing"]["seconds"] = Json::Int(1); - test_json_serialization_by_representation(reference_json); + test_json_serialization_without_equality(reference_json); } TEST(TestAbmSerialization, Person) { - // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. - // See test_json_serialization_by_representation for more detail. + // See test_json_serialization_without_equality for info on this test. auto json_uint_array = [](std::vector values) { return mio::serialize_json(values).value(); @@ -225,7 +219,7 @@ TEST(TestAbmSerialization, Person) unsigned i = 1; // counter s.t. members have different values - Json::Value reference_json; // aka x + Json::Value reference_json; reference_json["age_group"] = Json::UInt(i++); reference_json["assigned_locations"] = json_uint_array({i++, i++, i++, i++, i++, i++, i++, i++, i++, i++, i++}); reference_json["cells"] = json_uint_array({i++}); @@ -235,11 +229,11 @@ TEST(TestAbmSerialization, Person) reference_json["location"] = Json::UInt(i++); reference_json["location_type"] = Json::UInt(0); reference_json["mask"]["mask_type"] = Json::UInt(0); - reference_json["mask"]["time_used"]["seconds"] = Json::UInt(i++); + reference_json["mask"]["time_used"]["seconds"] = Json::Int(i++); reference_json["mask_compliance"] = json_double_array({(double)i++, (double)i++, (double)i++, (double)i++, (double)i++, (double)i++, (double)i++, (double)i++, (double)i++, (double)i++, (double)i++}); - reference_json["quarantine_start"]["seconds"] = Json::UInt(i++); + reference_json["quarantine_start"]["seconds"] = Json::Int(i++); reference_json["rnd_go_to_school_hour"] = Json::Value((double)i++); reference_json["rnd_go_to_work_hour"] = Json::Value((double)i++); reference_json["rnd_schoolgroup"] = Json::Value((double)i++); @@ -247,23 +241,22 @@ TEST(TestAbmSerialization, Person) reference_json["rng_counter"] = Json::UInt(i++); reference_json["test_results"] = mio::serialize_json(mio::CustomIndexArray{}).value(); - reference_json["time_at_location"]["seconds"] = Json::UInt(i++); + reference_json["time_at_location"]["seconds"] = Json::Int(i++); reference_json["vaccinations"] = Json::Value(Json::arrayValue); reference_json["wears_mask"] = Json::Value(false); - test_json_serialization_by_representation(reference_json); + test_json_serialization_without_equality(reference_json); } TEST(TestAbmSerialization, Location) { - // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. - // See test_json_serialization_by_representation for more detail. + // See test_json_serialization_without_equality for info on this test. unsigned i = 1; // counter s.t. members have different values Json::Value contact_rates = mio::serialize_json(mio::abm::ContactRates::get_default(i++)).value(); - Json::Value reference_json; // aka x + Json::Value reference_json; reference_json["cells"][0]["capacity"]["persons"] = Json::UInt(i++); reference_json["cells"][0]["capacity"]["volume"] = Json::UInt(i++); reference_json["geographical_location"]["latitude"] = Json::Value((double)i++); @@ -275,13 +268,12 @@ TEST(TestAbmSerialization, Location) reference_json["parameters"]["UseLocationCapacityForTransmissions"] = Json::Value(false); reference_json["required_mask"] = Json::UInt(0); - test_json_serialization_by_representation(reference_json); + test_json_serialization_without_equality(reference_json); } TEST(TestAbmSerialization, Model) { - // Test that a json value x is equal to serialize(deserialize(x)) w.r.t json representation. - // See test_json_serialization_by_representation for more detail. + // See test_json_serialization_without_equality for info on this test. auto json_uint_array = [](std::vector values) { return mio::serialize_json(values).value(); @@ -291,7 +283,7 @@ TEST(TestAbmSerialization, Model) Json::Value abm_parameters = mio::serialize_json(mio::abm::Parameters(i++)).value(); - Json::Value reference_json; // aka x + Json::Value reference_json; reference_json["cemetery_id"] = Json::UInt(i++); reference_json["location_types"] = Json::UInt(i++); reference_json["locations"] = Json::Value(Json::arrayValue); @@ -306,7 +298,7 @@ TEST(TestAbmSerialization, Model) reference_json["trip_list"]["trips_weekend"] = Json::Value(Json::arrayValue); reference_json["use_mobility_rules"] = Json::Value(false); - test_json_serialization_by_representation(reference_json); + test_json_serialization_without_equality(reference_json); } #endif diff --git a/cpp/tests/test_math_time_series_functor.cpp b/cpp/tests/test_math_time_series_functor.cpp index f22ed7cfed..b4d5ed0b61 100644 --- a/cpp/tests/test_math_time_series_functor.cpp +++ b/cpp/tests/test_math_time_series_functor.cpp @@ -1,59 +1,77 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: René Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ #include "memilio/math/time_series_functor.h" -#include "memilio/utils/random_number_generator.h" +#include "random_number_test.h" #include -class TestMathTimeSeriesFunctor : public ::testing::Test -{ -public: - using TSF = mio::TimeSeriesFunctor; - - const int num_evals = 1000; - const double min = -1e+3, max = 1e+3; // a reasonably large range for fuzzy_val - - double fuzzy_val(double min_, double max_) - { - return mio::UniformDistribution::get_instance()(m_rng, min_, max_); - } - -protected: - void SetUp() override - { - log_rng_seeds(m_rng, mio::LogLevel::warn); - } - -private: - mio::RandomNumberGenerator m_rng{}; -}; +using TestMathTimeSeriesFunctor = RandomNumberTest; TEST_F(TestMathTimeSeriesFunctor, zero) { - // Test that the Zero-functor always returns zero, using a random evaluation point. + // Test that the default initialized functor always returns zero, using a random evaluation point. + + const int num_evals = 100; // initialize functor using the default ctor - TSF tsf; + mio::TimeSeriesFunctor tsf; // check one deterministic value first to avoid flooding the test output with failed tests ASSERT_EQ(tsf(0.0), 0.0); // verify output - for (int i = 0; i < this->num_evals; i++) { - auto random_t_eval = this->fuzzy_val(this->min, this->max); + for (int i = 0; i < num_evals; i++) { + auto random_t_eval = this->random_number(); EXPECT_EQ(tsf(random_t_eval), 0.0); } } TEST_F(TestMathTimeSeriesFunctor, linearInterpolation) { - // Test that the LinearInterpolation-functor correctly reproduces a (piecewise) linear function, using random + // Test that linear interpolation works for a piecewise linear function. + + // continuous function that is constant 1 for t<0, linear in [0, 1] with slope 2, and constant 3 for t>1 + const auto pcw_lin_fct = [&](double t) { + return 1 + 2 * std::clamp(t, 0.0, 1.0); // .../``` + }; + + mio::TimeSeriesFunctor tsf(mio::TimeSeriesFunctorType::LinearInterpolation, {{0., 1.}, {1., 3.}}); + + // go from -1/4 to 5/4 in steps of size 1/4, with values 1.0, 1.0, 1.5, 2.0, 2.5, 3.0, 3.0 + for (double t = -0.25; t < 1.3; t += 0.25) { + EXPECT_NEAR(tsf(t), pcw_lin_fct(t), 1e-14); + } +} + +TEST_F(TestMathTimeSeriesFunctor, linearInterpolationRandomized) +{ + // Test that the LinearInterpolation-functor correctly reproduces a piecewise linear function, using random // samples. Since the initialization uses unsorted data, this also checks that the data gets sorted + const int num_evals = 1000; - const double t_min = -1, t_max = 1, t_mid = this->fuzzy_val(t_min, t_max); - const double slope1 = this->fuzzy_val(this->min, this->max); - const double slope2 = this->fuzzy_val(this->min, this->max); - const double height = this->fuzzy_val(this->min, this->max); + const double t_min = -1, t_max = 1, t_mid = this->random_number(t_min, t_max); + const double slope1 = this->random_number(); + const double slope2 = this->random_number(); + const double height = this->random_number(); + // continuous function with different slopes between t_min, t_mid and t_max, constant otherwise const auto pcw_lin_fct = [&](double t) { - // continuous function with different slopes between t_min, t_mid and t_max, constant otherwise return height + slope1 * std::clamp(t - t_min, 0.0, t_mid - t_min) + slope2 * std::clamp(t - t_mid, 0.0, t_max - t_mid); }; @@ -63,18 +81,18 @@ TEST_F(TestMathTimeSeriesFunctor, linearInterpolation) {t_max, pcw_lin_fct(t_max)}, {t_min, pcw_lin_fct(t_min)}, {t_mid, pcw_lin_fct(t_mid)}}; // randomly add a few more evaluations in between for (int i = 0; i < 10; i++) { - const double t = this->fuzzy_val(-1.0, 1.0); + const double t = this->random_number(-1.0, 1.0); unsorted_data.push_back({t, pcw_lin_fct(t)}); } // initialize functor - TSF tsf(TSF::Type::LinearInterpolation, unsorted_data); + mio::TimeSeriesFunctor tsf(mio::TimeSeriesFunctorType::LinearInterpolation, unsorted_data); // check one deterministic value first to avoid flooding the test output with failed tests ASSERT_NEAR(tsf(0.5 * (t_max - t_min)), pcw_lin_fct(0.5 * (t_max - t_min)), 1e-10); // verify output - for (int i = 0; i < this->num_evals; i++) { + for (int i = 0; i < num_evals; i++) { // sample in the interval [t_min - (t_max - t_min) / 4, t_max + (t_max - tmin) / 4] - double random_t_eval = this->fuzzy_val(1.25 * t_min - 0.25 * t_max, 1.25 * t_max - 0.25 * t_min); + double random_t_eval = this->random_number(1.25 * t_min - 0.25 * t_max, 1.25 * t_max - 0.25 * t_min); EXPECT_NEAR(tsf(random_t_eval), pcw_lin_fct(random_t_eval), 1e-10) << "i = " << i; } } From cb2783df478861d711ae5a526019481bfbc791e8 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Fri, 16 Aug 2024 15:02:09 +0200 Subject: [PATCH 25/42] avoid creating object twice --- cpp/memilio/io/binary_serializer.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/memilio/io/binary_serializer.h b/cpp/memilio/io/binary_serializer.h index 2366ba56d4..77b29aa7f0 100644 --- a/cpp/memilio/io/binary_serializer.h +++ b/cpp/memilio/io/binary_serializer.h @@ -24,6 +24,7 @@ #include "memilio/io/serializer_base.h" #include "memilio/utils/compiler_diagnostics.h" #include "memilio/utils/metaprogramming.h" +#include #include #include #include @@ -77,6 +78,7 @@ class ByteStream */ bool read(unsigned char* p, size_t s) { + std::cout << m_read_head << " + " << s << " / " << m_buf.size() << "\n"; if (s <= (m_buf.size() - m_read_head)) { auto read_begin = m_buf.begin() + m_read_head; auto read_end = read_begin + s; @@ -278,7 +280,7 @@ class BinarySerializerContext : public SerializerBase "Unexpected type in stream:" + type_result.value() + ". Expected " + type); } } - return BinarySerializerObject(m_stream, m_status, m_flags); + return obj; } /** From 7cb2f6937152ddec009963c94c909fd8c17fa556 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Fri, 16 Aug 2024 15:02:44 +0200 Subject: [PATCH 26/42] fix evaluation order of auto_deserialize --- cpp/memilio/io/auto_serialize.h | 49 +++++++++++++++++---------------- cpp/memilio/io/io.h | 12 ++++++++ 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/cpp/memilio/io/auto_serialize.h b/cpp/memilio/io/auto_serialize.h index ad14e2abef..2d43fbb77f 100644 --- a/cpp/memilio/io/auto_serialize.h +++ b/cpp/memilio/io/auto_serialize.h @@ -62,24 +62,24 @@ struct Members { Members(const char* class_name) : name(class_name) - , name_value_pairs() + , named_refs() { } - Members(const char* class_name, std::tuple...> nvps) + Members(const char* class_name, std::tuple...> named_references) : name(class_name) - , name_value_pairs(nvps) + , named_refs(named_references) { } template [[nodiscard]] Members add(const char* member_name, T& member) { - return Members{name, std::tuple_cat(name_value_pairs, std::tuple(NamedRef{member_name, member}))}; + return Members{name, std::tuple_cat(named_refs, std::tuple(NamedRef{member_name, member}))}; } const char* name; - std::tuple...> name_value_pairs; + std::tuple...> named_refs; }; /** @@ -115,40 +115,43 @@ using auto_serialize_expr_t = decltype(std::declval().auto_serialize()); /// Add a name-value pair to an io object. template -void add_nvp(IOObject& obj, const NamedRef nvp) +void add_named_ref(IOObject& obj, const NamedRef named_ref) { - obj.add_element(nvp.name, nvp.value); + obj.add_element(named_ref.name, named_ref.value); } /// Unpack all name-value pairs from the tuple and add them to a new io object with the given name. template -void auto_serialize_impl(IOContext& io, const char* name, const NamedRef... nvps) +void auto_serialize_impl(IOContext& io, const char* name, const NamedRef... named_refs) { auto obj = io.create_object(name); - (add_nvp(obj, nvps), ...); + (add_named_ref(obj, named_refs), ...); } /// Retrieve a name-value pair from an io object. template -IOResult expect_nvp(IOObject& obj, const NamedRef nvp) +IOResult expect_named_ref(IOObject& obj, const NamedRef named_ref) { - return obj.expect_element(nvp.name, Tag{}); + return obj.expect_element(named_ref.name, Tag{}); } /// Read an io object and its members from the io context using the given names and assign the values to a. template IOResult auto_deserialize_impl(IOContext& io, AutoSerializable& a, const char* name, - NamedRef... nvps) + NamedRef... named_refs) { auto obj = io.expect_object(name); + // we cannot use expect_named_ref directly in apply, as function arguments have no guarantueed order of evaluation + std::tuple...> results{expect_named_ref(obj, named_refs)...}; + return apply( io, - [&a, &nvps...](const Members&... values) { - ((nvps.value = values), ...); + [&a, &named_refs...](const Members&... values) { + ((named_refs.value = values), ...); return a; }, - expect_nvp(obj, nvps)...); + results); } } // namespace details @@ -163,7 +166,7 @@ using has_auto_serialize = is_expression_valid(a).auto_serialize(); // unpack members and serialize std::apply( - [&io, &members](auto... nvps) { - details::auto_serialize_impl(io, members.name, nvps...); + [&io, &members](auto... named_refs) { + details::auto_serialize_impl(io, members.name, named_refs...); }, - members.name_value_pairs); + members.named_refs); } /** * @brief Deserialization implementation for the auto-serialization feature. * Disables itself (SFINAE) if there is no auto_serialize member or if a deserialize meember is present. - * Generates the deserialize method depending on the NVPs given by auto_serialize. + * Generates the deserialize method depending on the NamedRefs given by auto_serialize. * @tparam IOContext A type that models the IOContext concept. * @tparam AutoSerializable A type that can be auto-serialized. * @param io An IO context. @@ -206,10 +209,10 @@ IOResult deserialize_internal(IOContext& io, Tag eval(F f, const IOResult&... rs) * @param f the function that is called with the values contained in `rs` as arguments. * @param rs zero or more IOResults from previous operations. * @return the result of f(rs.value()...) if successful, the first error encountered otherwise. + * @{ */ template details::ApplyResultT apply(IOContext& io, F f, const IOResult&... rs) @@ -500,6 +501,17 @@ details::ApplyResultT apply(IOContext& io, F f, const IOResult&... r return result; } +template +details::ApplyResultT apply(IOContext& io, F f, const std::tuple...>& results) +{ + return std::apply( + [&](auto&&... rs) { + return apply(io, f, rs...); + }, + results); +} +/** @} */ + //utility for (de-)serializing tuple-like objects namespace details { From 36289ceaec86f544a06757ccace0a721c3416b22 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Fri, 16 Aug 2024 15:34:58 +0200 Subject: [PATCH 27/42] cover unhandled type asserts in tests --- cpp/tests/test_math_time_series_functor.cpp | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/cpp/tests/test_math_time_series_functor.cpp b/cpp/tests/test_math_time_series_functor.cpp index b4d5ed0b61..eb2ea85adf 100644 --- a/cpp/tests/test_math_time_series_functor.cpp +++ b/cpp/tests/test_math_time_series_functor.cpp @@ -20,7 +20,8 @@ #include "memilio/math/time_series_functor.h" #include "random_number_test.h" -#include +#include "gtest/gtest.h" +#include using TestMathTimeSeriesFunctor = RandomNumberTest; @@ -96,3 +97,21 @@ TEST_F(TestMathTimeSeriesFunctor, linearInterpolationRandomized) EXPECT_NEAR(tsf(random_t_eval), pcw_lin_fct(random_t_eval), 1e-10) << "i = " << i; } } + +TEST_F(TestMathTimeSeriesFunctor, unhandledTypes) +{ + // check that the functor does not accept unhandled types. + + const auto unhandled_type = (mio::TimeSeriesFunctorType)-1; + + // check constructor assert + EXPECT_DEBUG_DEATH(mio::TimeSeriesFunctor(unhandled_type, mio::TimeSeries(0)), + "Unhandled TimeSeriesFunctorType!"); + + // abuse auto_serialize to set an invalid type + mio::TimeSeriesFunctor functor; + std::get<0>(functor.auto_serialize().named_refs).value = unhandled_type; + + // check assert in functor call + EXPECT_DEBUG_DEATH(functor(0.0), "Unhandled TimeSeriesFunctorType!"); +} From 537edbdb29e6b60144bbf864e81cf739dcf04414 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Fri, 16 Aug 2024 15:59:22 +0200 Subject: [PATCH 28/42] [ci skip] update io readme --- cpp/memilio/io/README.md | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/cpp/memilio/io/README.md b/cpp/memilio/io/README.md index 5a21a4ffbe..4e4e7a36d1 100644 --- a/cpp/memilio/io/README.md +++ b/cpp/memilio/io/README.md @@ -7,19 +7,20 @@ This directory contains utilities for reading and writing data from and to files ### Using serialization In the next sections we will explain how to implement serialization (both for types and formats), here we quickly show -how to use it once it already is implemented for a type. Currently, there is support for the Json and a binary format, -which can be used through the `serialize_json`/`deserialize_json` and `serialize_binary`/`deserialize_binary`, -respectively. For example - +how to use it once it already is implemented for a type. In the following examples, we serialize (write) `Foo` to a +file in Json format, then deserialize (read) the Json again. ```cpp Foo foo{5}; -mio::IOResult js_result = mio::serialize_json(foo); +mio::IOResult io_result = mio::write_json("path/to/foo.json", foo); ``` ```cpp -Json::Value js_value; -js_value["i"] = Json::Int(5); -mio::IOResult foo_result = mio::deserialize_json(js_value, mio::Tag{}); +mio::IOResult io_result = mio::read_json("path/to/foo.json", mio::Tag{}); +if (io_result) { + Foo foo = io_result.value(); +} ``` +There is also support for a binary format. If you want to use a format directly, use the +`serialize_json`/`deserialize_json` and `serialize_binary`/`deserialize_binary` functions. ### Main functions and types From 0b67e1c7499b604ab609e47ae4017e5cd0eefa35 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 27 Aug 2024 12:35:43 +0200 Subject: [PATCH 29/42] rename auto-serialize to default serialize, update documentation per review suggestions --- cpp/memilio/io/README.md | 21 ++- cpp/memilio/io/auto_serialize.h | 193 +++++++++++--------- cpp/memilio/math/time_series_functor.h | 4 +- cpp/memilio/utils/random_number_generator.h | 4 +- cpp/models/abm/infection.h | 8 +- cpp/models/abm/location.h | 16 +- cpp/models/abm/mask.h | 6 +- cpp/models/abm/parameters.h | 12 +- cpp/models/abm/person.h | 6 +- cpp/models/abm/test_type.h | 4 +- cpp/models/abm/testing_strategy.h | 14 +- cpp/models/abm/time.h | 8 +- cpp/models/abm/trip_list.h | 7 +- cpp/models/abm/vaccine.h | 6 +- cpp/tests/test_math_time_series_functor.cpp | 4 +- 15 files changed, 172 insertions(+), 141 deletions(-) diff --git a/cpp/memilio/io/README.md b/cpp/memilio/io/README.md index 4e4e7a36d1..5e1c860460 100644 --- a/cpp/memilio/io/README.md +++ b/cpp/memilio/io/README.md @@ -32,7 +32,7 @@ There is also support for a binary format. If you want to use a format directly, - IOStatus and IOResult: Used for error handling, see section "Error Handling" below. -### Auto-serialization +### Default serialization Before we get into the details of the framework, this feature provides an easy and convenient alternative to the serialize and deserialize functions. To give an example: @@ -40,20 +40,25 @@ serialize and deserialize functions. To give an example: ```cpp struct Foo { int i; - auto auto_serialize() { + auto default_serialize() { return Members("Foo").add("i", i); } }; ``` -The auto-serialization is less flexible than the serialize and deserialize functions and has additional requirements: -- The class must be trivially constructible. +The default serialization is less flexible than the serialize and deserialize functions and has additional +requirements: +- The class must be default constructible. + - A *private* default constructor can be used by marking the struct `DefaultFactory` as a friend. + For the example above, the line `friend DefaultFactory;` would be added to the class definition. - Alternatively, you may provide a specialization of the struct `DefaultFactory`. For more details, - view the struct's documentation. -- Every class member must be added to Members exactly once (though the names and their order are arbitrary). + view the struct's documentation. +- Every class member must be added to `Members` exactly once, and the provided names must be unique. - The members must be passed directly, like in the example. No copies, accessors, etc. -- Every class member itself is both (auto-)(de)serializable and assignable. + - It is recommended, but not required, to add member variables to `Members` in the same order they are declared in + the class, using the variables' names or something very similar. +- Every class member itself must be serializable, deserializable and assignable. -As to the feature set, auto-serialization only supports the `add_element` and `expect_element` operations defined in +As to the feature set, default-serialization only supports the `add_element` and `expect_element` operations defined in the Concepts section below, where each operation's arguments are provided through the `add` function. Note that the value provided to `add` is also used to assign a value during deserialization, hence the class members must be used directly in the function (i.e. as a non-const lvalue reference). diff --git a/cpp/memilio/io/auto_serialize.h b/cpp/memilio/io/auto_serialize.h index 2d43fbb77f..6390b8cff9 100644 --- a/cpp/memilio/io/auto_serialize.h +++ b/cpp/memilio/io/auto_serialize.h @@ -17,8 +17,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MIO_IO_AUTO_SERIALIZE_H_ -#define MIO_IO_AUTO_SERIALIZE_H_ +#ifndef MIO_IO_DEFAULT_SERIALIZE_H_ +#define MIO_IO_DEFAULT_SERIALIZE_H_ #include "memilio/io/io.h" #include "memilio/utils/metaprogramming.h" @@ -33,8 +33,8 @@ namespace mio /** * @brief A pair of name and reference. * - * Used for auto-(de)serialization. - * This object holds a pointer to a name and reference to value. Mind their lifetime! + * Used for default (de)serialization. + * This object holds a char pointer to a name and reference to value. Mind their lifetime! * @tparam ValueType The (non-cv, non-reference) type of the value. */ template @@ -57,61 +57,16 @@ struct NamedRef { } }; -template -struct Members { - - Members(const char* class_name) - : name(class_name) - , named_refs() - { - } - - Members(const char* class_name, std::tuple...> named_references) - : name(class_name) - , named_refs(named_references) - { - } - - template - [[nodiscard]] Members add(const char* member_name, T& member) - { - return Members{name, std::tuple_cat(named_refs, std::tuple(NamedRef{member_name, member}))}; - } - - const char* name; - std::tuple...> named_refs; -}; - -/** - * @brief Creates an instance of T for later initialization. - * - * The default implementation uses the default constructor of T, if available. If there is no default constructor, this - * class can be spezialized to provide the method `static T create()`. If there is a default constructor, but it is - * private, DefaultFactory can be marked as friend instead. - * - * The state of the object retured by `create()` is completely arbitrary, and may be invalid. Make sure to set it to a - * valid state before using it further. - * - * @tparam T The type to create. - */ -template -struct DefaultFactory { - /// @brief Creates a new instance of T. - static T create() - { - return T{}; - } -}; - namespace details { /** - * @brief Helper type to detect whether T has a auto_serialize member function. + * @brief Helper type to detect whether T has a default_serialize member function. + * Use has_default_serialize. * @tparam T Any type. */ template -using auto_serialize_expr_t = decltype(std::declval().auto_serialize()); +using default_serialize_expr_t = decltype(std::declval().default_serialize()); /// Add a name-value pair to an io object. template @@ -122,7 +77,7 @@ void add_named_ref(IOObject& obj, const NamedRef named_ref) /// Unpack all name-value pairs from the tuple and add them to a new io object with the given name. template -void auto_serialize_impl(IOContext& io, const char* name, const NamedRef... named_refs) +void default_serialize_impl(IOContext& io, const char* name, const NamedRef... named_refs) { auto obj = io.create_object(name); (add_named_ref(obj, named_refs), ...); @@ -136,9 +91,9 @@ IOResult expect_named_ref(IOObject& obj, const NamedRef named_re } /// Read an io object and its members from the io context using the given names and assign the values to a. -template -IOResult auto_deserialize_impl(IOContext& io, AutoSerializable& a, const char* name, - NamedRef... named_refs) +template +IOResult default_deserialize_impl(IOContext& io, DefaultSerializable& a, const char* name, + NamedRef... named_refs) { auto obj = io.expect_object(name); @@ -157,64 +112,134 @@ IOResult auto_deserialize_impl(IOContext& io, AutoSerializable } // namespace details /** - * @brief Detect whether T has a auto_serialize member function. + * @brief List of a class's members. + * + * Used for default (de)serialization. + * Holds a char pointer to the class name as well as a tuple of NamedRefs with all added class members. + * @tparam ValueTypes The (non-cv, non-reference) types of member variables. + */ +template +struct Members { + // allow other Members access to the private constructor + template + friend struct Members; + + /** + * @brief Initialize with a class name. Use `add` to specify its member variables. + * @param[in] class_name Name of a class. + */ + Members(const char* class_name) + : name(class_name) + , named_refs() + { + } + + /** + * @brief Add a class member. + * @param[in] member_name + */ + template + [[nodiscard]] Members add(const char* member_name, T& member) + { + return Members{name, std::tuple_cat(named_refs, std::tuple(NamedRef{member_name, member}))}; + } + + const char* name; + std::tuple...> named_refs; + +private: + /** + * @brief Initialize Members directly. Used by the add function. + * @param[in] class_name Name of a class. + * @param[in] named_references Tuple of added class Members. + */ + Members(const char* class_name, std::tuple...> named_references) + : name(class_name) + , named_refs(named_references) + { + } +}; + +/** + * @brief Creates an instance of T for later initialization. + * + * The default implementation uses the default constructor of T, if available. If there is no default constructor, this + * class can be spezialized to provide the method `static T create()`. If there is a default constructor, but it is + * private, DefaultFactory can be marked as friend instead. + * + * The state of the object retured by `create()` is completely arbitrary, and may be invalid. Make sure to set it to a + * valid state before using it further. + * + * @tparam T The type to create. + */ +template +struct DefaultFactory { + /// @brief Creates a new instance of T. + static T create() + { + return T{}; + } +}; + +/** + * @brief Detect whether T has a default_serialize member function. * @tparam T Any type. */ template -using has_auto_serialize = is_expression_valid; +using has_default_serialize = is_expression_valid; /** - * @brief Serialization implementation for the auto-serialization feature. - * Disables itself (SFINAE) if there is no auto_serialize member or if a serialize member is present. - * Generates the serialize method depending on the NamedRefs given by auto_serialize. + * @brief Serialization implementation for the default serialization feature. + * Disables itself (SFINAE) if there is no default_serialize member or if a serialize member is present. + * Generates the serialize method depending on the NamedRefs given by default_serialize. * @tparam IOContext A type that models the IOContext concept. - * @tparam AutoSerializable A type that can be auto-serialized. + * @tparam DefaultSerializable A type that can be default serialized. * @param io An IO context. - * @param a An instance of AutoSerializable to be serialized. + * @param a An instance of DefaultSerializable to be serialized. */ -template < - class IOContext, class AutoSerializable, - std::enable_if_t::value && !has_serialize::value, - AutoSerializable*> = nullptr> -void serialize_internal(IOContext& io, const AutoSerializable& a) +template ::value && + !has_serialize::value, + DefaultSerializable*> = nullptr> +void serialize_internal(IOContext& io, const DefaultSerializable& a) { // Note that the following cons_cast is only safe if we do not modify members. - const auto members = const_cast(a).auto_serialize(); + const auto members = const_cast(a).default_serialize(); // unpack members and serialize std::apply( [&io, &members](auto... named_refs) { - details::auto_serialize_impl(io, members.name, named_refs...); + details::default_serialize_impl(io, members.name, named_refs...); }, members.named_refs); } /** - * @brief Deserialization implementation for the auto-serialization feature. - * Disables itself (SFINAE) if there is no auto_serialize member or if a deserialize meember is present. - * Generates the deserialize method depending on the NamedRefs given by auto_serialize. + * @brief Deserialization implementation for the default serialization feature. + * Disables itself (SFINAE) if there is no default_serialize member or if a deserialize meember is present. + * Generates the deserialize method depending on the NamedRefs given by default_serialize. * @tparam IOContext A type that models the IOContext concept. - * @tparam AutoSerializable A type that can be auto-serialized. + * @tparam DefaultSerializable A type that can be default serialized. * @param io An IO context. - * @param tag Defines the type of the object that is to be deserialized (i.e. AutoSerializble). + * @param tag Defines the type of the object that is to be deserialized (i.e. DefaultSerializble). * @return The restored object if successful, an error otherwise. */ -template ::value && - !has_deserialize::value, - AutoSerializable*> = nullptr> -IOResult deserialize_internal(IOContext& io, Tag tag) +template ::value && + !has_deserialize::value, + DefaultSerializable*> = nullptr> +IOResult deserialize_internal(IOContext& io, Tag tag) { mio::unused(tag); - AutoSerializable a = DefaultFactory::create(); - auto members = a.auto_serialize(); + DefaultSerializable a = DefaultFactory::create(); + auto members = a.default_serialize(); // unpack members and deserialize return std::apply( [&io, &members, &a](auto... named_refs) { - return details::auto_deserialize_impl(io, a, members.name, named_refs...); + return details::default_deserialize_impl(io, a, members.name, named_refs...); }, members.named_refs); } } // namespace mio -#endif // MIO_IO_AUTO_SERIALIZE_H_ +#endif // MIO_IO_DEFAULT_SERIALIZE_H_ diff --git a/cpp/memilio/math/time_series_functor.h b/cpp/memilio/math/time_series_functor.h index 96f9792291..1469bd6da8 100644 --- a/cpp/memilio/math/time_series_functor.h +++ b/cpp/memilio/math/time_series_functor.h @@ -105,8 +105,8 @@ class TimeSeriesFunctor } } - /// This method is used by the auto-serialization feature. - auto auto_serialize() + /// This method is used by the default serialization feature. + auto default_serialize() { return Members("TimeSeriesFunctor").add("type", m_type).add("data", m_data); } diff --git a/cpp/memilio/utils/random_number_generator.h b/cpp/memilio/utils/random_number_generator.h index 8b22f64e0f..909188973a 100644 --- a/cpp/memilio/utils/random_number_generator.h +++ b/cpp/memilio/utils/random_number_generator.h @@ -358,8 +358,8 @@ class RandomNumberGenerator : public RandomNumberGeneratorBase struct DefaultFactory { static abm::Mask create() diff --git a/cpp/models/abm/parameters.h b/cpp/models/abm/parameters.h index 19d03105ff..6b8de8d9f2 100644 --- a/cpp/models/abm/parameters.h +++ b/cpp/models/abm/parameters.h @@ -177,8 +177,8 @@ struct ViralLoadDistributionsParameters { UniformDistribution::ParamType viral_load_incline; UniformDistribution::ParamType viral_load_decline; - /// This method is used by the auto-serialization feature. - auto auto_serialize() + /// This method is used by the default serialization feature. + auto default_serialize() { return Members("ViralLoadDistributionsParameters") .add("viral_load_peak", viral_load_peak) @@ -209,8 +209,8 @@ struct InfectivityDistributionsParameters { UniformDistribution::ParamType infectivity_alpha; UniformDistribution::ParamType infectivity_beta; - /// This method is used by the auto-serialization feature. - auto auto_serialize() + /// This method is used by the default serialization feature. + auto default_serialize() { return Members("InfectivityDistributionsParameters") .add("infectivity_alpha", infectivity_alpha) @@ -332,8 +332,8 @@ struct TestParameters { TimeSpan required_time; TestType type; - /// This method is used by the auto-serialization feature. - auto auto_serialize() + /// This method is used by the default serialization feature. + auto default_serialize() { return Members("TestParameters") .add("sensitivity", sensitivity) diff --git a/cpp/models/abm/person.h b/cpp/models/abm/person.h index 342f13cccb..aff80a0b4c 100755 --- a/cpp/models/abm/person.h +++ b/cpp/models/abm/person.h @@ -388,8 +388,8 @@ class Person */ std::pair get_latest_protection() const; - /// This method is used by the auto-serialization feature. - auto auto_serialize() + /// This method is used by the default serialization feature. + auto default_serialize() { return Members("Person") .add("location", m_location) @@ -456,7 +456,7 @@ class Person } // namespace abm -/// @brief Creates an instance of abm::Person for auto-deserialization. +/// @brief Creates an instance of abm::Person for default serialization. template <> struct DefaultFactory { static abm::Person create() diff --git a/cpp/models/abm/test_type.h b/cpp/models/abm/test_type.h index 823f6f26f9..f3502be43c 100644 --- a/cpp/models/abm/test_type.h +++ b/cpp/models/abm/test_type.h @@ -49,8 +49,8 @@ struct TestResult { TimePoint time_of_testing{std::numeric_limits::min()}; ///< The TimePoint when the Person performs the test. bool result{false}; ///< The test result. - /// This method is used by the auto-serialization feature. - auto auto_serialize() + /// This method is used by the default serialization feature. + auto default_serialize() { return Members("TestResult").add("time_of_testing", time_of_testing).add("result", result); } diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index 0d58222b3a..b3e9cd4fad 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -94,7 +94,7 @@ class TestingCriteria */ bool evaluate(const Person& p, TimePoint t) const; - auto auto_serialize() + auto default_serialize() { return Members("TestingCriteria").add("ages", m_ages).add("infection_states", m_infection_states); } @@ -150,8 +150,8 @@ class TestingScheme */ bool run_scheme(PersonalRandomNumberGenerator& rng, Person& person, TimePoint t) const; - /// This method is used by the auto-serialization feature. - auto auto_serialize() + /// This method is used by the default serialization feature. + auto default_serialize() { return Members("TestingScheme") .add("criteria", m_testing_criteria) @@ -191,8 +191,8 @@ class TestingStrategy LocationId id; std::vector schemes; - /// This method is used by the auto-serialization feature. - auto auto_serialize() + /// This method is used by the default serialization feature. + auto default_serialize() { return Members("LocalStrategy").add("type", type).add("id", id).add("schemes", schemes); } @@ -261,8 +261,8 @@ class TestingStrategy */ bool run_strategy(PersonalRandomNumberGenerator& rng, Person& person, const Location& location, TimePoint t); - /// This method is used by the auto-serialization feature. - auto auto_serialize() + /// This method is used by the default serialization feature. + auto default_serialize() { return Members("TestingStrategy").add("schemes", m_location_to_schemes_map); } diff --git a/cpp/models/abm/time.h b/cpp/models/abm/time.h index 917b80ecca..af9a7f2e51 100644 --- a/cpp/models/abm/time.h +++ b/cpp/models/abm/time.h @@ -145,8 +145,8 @@ class TimeSpan } /**@}*/ - /// This method is used by the auto-serialization feature. - auto auto_serialize() + /// This method is used by the default serialization feature. + auto default_serialize() { return Members("TimeSpan").add("seconds", m_seconds); } @@ -292,8 +292,8 @@ class TimePoint return TimeSpan{m_seconds - p2.seconds()}; } - /// This method is used by the auto-serialization feature. - auto auto_serialize() + /// This method is used by the default serialization feature. + auto default_serialize() { return Members("TimePoint").add("seconds", m_seconds); } diff --git a/cpp/models/abm/trip_list.h b/cpp/models/abm/trip_list.h index d3c9f82221..bcf64b0ffb 100644 --- a/cpp/models/abm/trip_list.h +++ b/cpp/models/abm/trip_list.h @@ -90,7 +90,7 @@ struct Trip { (origin == other.origin); } - auto auto_serialize() + auto default_serialize() { return Members("Trip") .add("person_id", person_id) @@ -168,8 +168,8 @@ class TripList return m_current_index; } - /// This method is used by the auto-serialization feature. - auto auto_serialize() + /// This method is used by the default serialization feature. + auto default_serialize() { return Members("TestingScheme") .add("trips_weekday", m_trips_weekday) @@ -185,6 +185,7 @@ class TripList } // namespace abm +/// @brief Creates an instance of abm::Trip for default serialization. template <> struct DefaultFactory { static abm::Trip create() diff --git a/cpp/models/abm/vaccine.h b/cpp/models/abm/vaccine.h index 5f77692b7b..54bca108d8 100644 --- a/cpp/models/abm/vaccine.h +++ b/cpp/models/abm/vaccine.h @@ -53,8 +53,8 @@ struct Vaccination { { } - /// This method is used by the auto-serialization feature. - auto auto_serialize() + /// This method is used by the default serialization feature. + auto default_serialize() { return Members("Vaccination").add("exposure_type", exposure_type).add("time", time); } @@ -65,7 +65,7 @@ struct Vaccination { } // namespace abm -/// @brief Creates an instance of abm::Vaccination for auto-deserialization. +/// @brief Creates an instance of abm::Vaccination for default serialization. template <> struct DefaultFactory { static abm::Vaccination create() diff --git a/cpp/tests/test_math_time_series_functor.cpp b/cpp/tests/test_math_time_series_functor.cpp index eb2ea85adf..951775c0e2 100644 --- a/cpp/tests/test_math_time_series_functor.cpp +++ b/cpp/tests/test_math_time_series_functor.cpp @@ -108,9 +108,9 @@ TEST_F(TestMathTimeSeriesFunctor, unhandledTypes) EXPECT_DEBUG_DEATH(mio::TimeSeriesFunctor(unhandled_type, mio::TimeSeries(0)), "Unhandled TimeSeriesFunctorType!"); - // abuse auto_serialize to set an invalid type + // abuse default_serialize to set an invalid type mio::TimeSeriesFunctor functor; - std::get<0>(functor.auto_serialize().named_refs).value = unhandled_type; + std::get<0>(functor.default_serialize().named_refs).value = unhandled_type; // check assert in functor call EXPECT_DEBUG_DEATH(functor(0.0), "Unhandled TimeSeriesFunctorType!"); From f2e158c8a46d837afab2ecd94b04403a661a41db Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 27 Aug 2024 12:51:59 +0200 Subject: [PATCH 30/42] rename auto_serialization.h/.cpp --- cpp/memilio/CMakeLists.txt | 4 ++-- cpp/memilio/io/{auto_serialize.cpp => default_serialize.cpp} | 2 +- cpp/memilio/io/{auto_serialize.h => default_serialize.h} | 0 cpp/memilio/math/time_series_functor.h | 2 +- cpp/memilio/utils/random_number_generator.h | 2 +- cpp/models/abm/infection.h | 2 +- cpp/models/abm/location.h | 2 +- cpp/models/abm/mask.h | 2 +- cpp/models/abm/parameters.h | 2 +- cpp/models/abm/person.h | 2 +- cpp/models/abm/test_type.h | 2 +- cpp/models/abm/testing_strategy.h | 2 +- cpp/models/abm/time.h | 2 +- cpp/models/abm/trip_list.h | 2 +- cpp/models/abm/vaccine.h | 2 +- 15 files changed, 15 insertions(+), 15 deletions(-) rename cpp/memilio/io/{auto_serialize.cpp => default_serialize.cpp} (94%) rename cpp/memilio/io/{auto_serialize.h => default_serialize.h} (100%) diff --git a/cpp/memilio/CMakeLists.txt b/cpp/memilio/CMakeLists.txt index 4e121f5d84..504b542620 100644 --- a/cpp/memilio/CMakeLists.txt +++ b/cpp/memilio/CMakeLists.txt @@ -26,8 +26,8 @@ add_library(memilio compartments/simulation.h compartments/flow_simulation.h compartments/parameter_studies.h - io/auto_serialize.h - io/auto_serialize.cpp + io/default_serialize.h + io/default_serialize.cpp io/io.h io/io.cpp io/hdf5_cpp.h diff --git a/cpp/memilio/io/auto_serialize.cpp b/cpp/memilio/io/default_serialize.cpp similarity index 94% rename from cpp/memilio/io/auto_serialize.cpp rename to cpp/memilio/io/default_serialize.cpp index 054a77c1b3..74bfe2acd5 100644 --- a/cpp/memilio/io/auto_serialize.cpp +++ b/cpp/memilio/io/default_serialize.cpp @@ -17,4 +17,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "memilio/io/auto_serialize.h" +#include "memilio/io/default_serialize.h" diff --git a/cpp/memilio/io/auto_serialize.h b/cpp/memilio/io/default_serialize.h similarity index 100% rename from cpp/memilio/io/auto_serialize.h rename to cpp/memilio/io/default_serialize.h diff --git a/cpp/memilio/math/time_series_functor.h b/cpp/memilio/math/time_series_functor.h index 1469bd6da8..a50982ba10 100644 --- a/cpp/memilio/math/time_series_functor.h +++ b/cpp/memilio/math/time_series_functor.h @@ -20,7 +20,7 @@ #ifndef MIO_MATH_TIME_SERIES_FUNCTOR_H #define MIO_MATH_TIME_SERIES_FUNCTOR_H -#include "memilio/io/auto_serialize.h" +#include "memilio/io/default_serialize.h" #include "memilio/math/interpolation.h" #include "memilio/utils/time_series.h" diff --git a/cpp/memilio/utils/random_number_generator.h b/cpp/memilio/utils/random_number_generator.h index 909188973a..96456dade4 100644 --- a/cpp/memilio/utils/random_number_generator.h +++ b/cpp/memilio/utils/random_number_generator.h @@ -21,7 +21,7 @@ #ifndef MIO_RANDOM_NUMBER_GENERATOR_H #define MIO_RANDOM_NUMBER_GENERATOR_H -#include "memilio/io/auto_serialize.h" +#include "memilio/io/default_serialize.h" #include "memilio/utils/compiler_diagnostics.h" #include "memilio/utils/logging.h" #include "memilio/utils/miompi.h" diff --git a/cpp/models/abm/infection.h b/cpp/models/abm/infection.h index 94896ec562..64d9ead191 100644 --- a/cpp/models/abm/infection.h +++ b/cpp/models/abm/infection.h @@ -21,7 +21,7 @@ #define MIO_ABM_INFECTION_H #include "abm/personal_rng.h" -#include "memilio/io/auto_serialize.h" +#include "memilio/io/default_serialize.h" #include "abm/time.h" #include "abm/infection_state.h" #include "abm/virus_variant.h" diff --git a/cpp/models/abm/location.h b/cpp/models/abm/location.h index c4cfc13caa..1cf03e90aa 100644 --- a/cpp/models/abm/location.h +++ b/cpp/models/abm/location.h @@ -25,7 +25,7 @@ #include "abm/parameters.h" #include "abm/location_type.h" -#include "memilio/io/auto_serialize.h" +#include "memilio/io/default_serialize.h" #include "boost/atomic/atomic.hpp" namespace mio diff --git a/cpp/models/abm/mask.h b/cpp/models/abm/mask.h index 4b601e3e1e..4c3b3b74d3 100644 --- a/cpp/models/abm/mask.h +++ b/cpp/models/abm/mask.h @@ -23,7 +23,7 @@ #include "abm/mask_type.h" #include "abm/time.h" -#include "memilio/io/auto_serialize.h" +#include "memilio/io/default_serialize.h" namespace mio { diff --git a/cpp/models/abm/parameters.h b/cpp/models/abm/parameters.h index 6b8de8d9f2..bbbd29c30b 100644 --- a/cpp/models/abm/parameters.h +++ b/cpp/models/abm/parameters.h @@ -26,7 +26,7 @@ #include "abm/vaccine.h" #include "abm/test_type.h" #include "memilio/config.h" -#include "memilio/io/auto_serialize.h" +#include "memilio/io/default_serialize.h" #include "memilio/io/io.h" #include "memilio/math/time_series_functor.h" #include "memilio/utils/custom_index_array.h" diff --git a/cpp/models/abm/person.h b/cpp/models/abm/person.h index aff80a0b4c..e1267bba4f 100755 --- a/cpp/models/abm/person.h +++ b/cpp/models/abm/person.h @@ -28,7 +28,7 @@ #include "abm/parameters.h" #include "abm/person_id.h" #include "abm/personal_rng.h" -#include "memilio/io/auto_serialize.h" +#include "memilio/io/default_serialize.h" #include "abm/time.h" #include "abm/test_type.h" #include "abm/vaccine.h" diff --git a/cpp/models/abm/test_type.h b/cpp/models/abm/test_type.h index f3502be43c..deaaaecf1c 100644 --- a/cpp/models/abm/test_type.h +++ b/cpp/models/abm/test_type.h @@ -21,7 +21,7 @@ #define MIO_ABM_TEST_TYPE_H #include "abm/time.h" -#include "memilio/io/auto_serialize.h" +#include "memilio/io/default_serialize.h" #include diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index b3e9cd4fad..359bb00ce4 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -27,7 +27,7 @@ #include "abm/person.h" #include "abm/location.h" #include "abm/time.h" -#include "memilio/io/auto_serialize.h" +#include "memilio/io/default_serialize.h" #include #include diff --git a/cpp/models/abm/time.h b/cpp/models/abm/time.h index af9a7f2e51..f2d484512d 100644 --- a/cpp/models/abm/time.h +++ b/cpp/models/abm/time.h @@ -20,7 +20,7 @@ #ifndef MIO_ABM_TIME_H #define MIO_ABM_TIME_H -#include "memilio/io/auto_serialize.h" +#include "memilio/io/default_serialize.h" namespace mio { diff --git a/cpp/models/abm/trip_list.h b/cpp/models/abm/trip_list.h index bcf64b0ffb..b96861ffef 100644 --- a/cpp/models/abm/trip_list.h +++ b/cpp/models/abm/trip_list.h @@ -24,7 +24,7 @@ #include "abm/mobility_data.h" #include "abm/person_id.h" #include "abm/time.h" -#include "memilio/io/auto_serialize.h" +#include "memilio/io/default_serialize.h" #include namespace mio diff --git a/cpp/models/abm/vaccine.h b/cpp/models/abm/vaccine.h index 54bca108d8..b04777029a 100644 --- a/cpp/models/abm/vaccine.h +++ b/cpp/models/abm/vaccine.h @@ -20,7 +20,7 @@ #ifndef MIO_ABM_VACCINE_H #define MIO_ABM_VACCINE_H -#include "memilio/io/auto_serialize.h" +#include "memilio/io/default_serialize.h" #include "abm/time.h" #include From 2f5bd391c4a1b1fe8d4bacebd60552141e8d14d6 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 27 Aug 2024 12:59:14 +0200 Subject: [PATCH 31/42] use uniform_real_distribution directly to not get accidentally mocked --- cpp/tests/random_number_test.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/tests/random_number_test.h b/cpp/tests/random_number_test.h index f95307deea..8d5fdfbacc 100644 --- a/cpp/tests/random_number_test.h +++ b/cpp/tests/random_number_test.h @@ -21,6 +21,8 @@ #include +#include + class RandomNumberTest : public ::testing::Test { public: @@ -33,7 +35,7 @@ class RandomNumberTest : public ::testing::Test template double random_number(FP min = FP{-1e+3}, FP max = {1e+3}) { - return mio::UniformDistribution::get_instance()(m_rng, min, max); + return std::uniform_real_distribution(min, max)(m_rng); } /// @brief Access the random number generator. Should only be used for debugging. From 0f3ac784939f3b51595f9388a59001f9a0e1b689 Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 27 Aug 2024 14:36:54 +0200 Subject: [PATCH 32/42] simplify abm serialization test --- cpp/tests/test_abm_serialization.cpp | 91 ++++++++++------------------ 1 file changed, 32 insertions(+), 59 deletions(-) diff --git a/cpp/tests/test_abm_serialization.cpp b/cpp/tests/test_abm_serialization.cpp index 90d4ad4482..9d175fb772 100644 --- a/cpp/tests/test_abm_serialization.cpp +++ b/cpp/tests/test_abm_serialization.cpp @@ -56,7 +56,7 @@ * @param reference_json A json value representing an instance of T. */ template -void test_json_serialization_without_equality(const Json::Value& reference_json) +void test_json_serialization(const Json::Value& reference_json) { // check that the json is deserializable (i.e. a valid representation) auto t_result = mio::deserialize_json(reference_json, mio::Tag()); @@ -69,59 +69,35 @@ void test_json_serialization_without_equality(const Json::Value& reference_json) EXPECT_THAT(json_result.value(), JsonEqual(reference_json)); } -/** - * @brief Test de- and serialization of an object using its equality operator. - * - * First, test that serializing the reference_object is equal to the reference_json. - * Then, test that deserializing the reference_json results in an object equal to the reference_object. - * - * @tparam T The type to test. - * @param reference_object An instance of T. - * @param reference_json A json value representing reference_object. - */ -template -void test_json_serialization_full(const T& reference_object, const Json::Value& reference_json) -{ - // check that the reference type T is serializable - auto json_result = mio::serialize_json(reference_object); - ASSERT_THAT(print_wrap(json_result), IsSuccess()); - EXPECT_THAT(json_result.value(), JsonEqual(reference_json)); - - // check that the reference json is deserializable - auto t_result = mio::deserialize_json(reference_json, mio::Tag()); - ASSERT_THAT(print_wrap(t_result), IsSuccess()); - EXPECT_EQ(t_result.value(), reference_object); -} - TEST(TestAbmSerialization, Trip) { // See test_json_serialization_full for info on this test. - mio::abm::Trip trip(1, mio::abm::TimePoint(2), 3, 4); + unsigned i = 1; // counter s.t. members have different values Json::Value reference_json; - reference_json["person_id"] = Json::UInt(1); - reference_json["time"]["seconds"] = Json::Int(2); - reference_json["destination"] = Json::UInt(3); - reference_json["origin"] = Json::UInt(4); + reference_json["person_id"] = Json::UInt(i++); + reference_json["time"]["seconds"] = Json::Int(i++); + reference_json["destination"] = Json::UInt(i++); + reference_json["origin"] = Json::UInt(i++); - test_json_serialization_full(trip, reference_json); + test_json_serialization(reference_json); } TEST(TestAbmSerialization, Vaccination) { - // See test_json_serialization_without_equality for info on this test. + // See test_json_serialization for info on this test. Json::Value reference_json; reference_json["exposure_type"] = Json::UInt(1); reference_json["time"]["seconds"] = Json::Int(2); - test_json_serialization_without_equality(reference_json); + test_json_serialization(reference_json); } TEST(TestAbmSerialization, Infection) { - // See test_json_serialization_without_equality for info on this test. + // See test_json_serialization for info on this test. unsigned i = 1; // counter s.t. members have different values @@ -140,47 +116,44 @@ TEST(TestAbmSerialization, Infection) reference_json["log_norm_beta"] = Json::Value((double)i++); reference_json["detected"] = Json::Value((bool)0); - test_json_serialization_without_equality(reference_json); + test_json_serialization(reference_json); } TEST(TestAbmSerialization, TestingScheme) { // See test_json_serialization_full for info on this test. - mio::abm::TestingScheme testing_scheme( - mio::abm::TestingCriteria({mio::AgeGroup(1)}, {mio::abm::InfectionState(2)}), mio::abm::TimeSpan(3), - mio::abm::TimePoint(4), mio::abm::TimePoint(5), - mio::abm::TestParameters{{6.0}, {7.0}, mio::abm::TimeSpan{8}, mio::abm::TestType(0)}, 9.0); + unsigned i = 1; // counter s.t. members have different values Json::Value testing_criteria; std::vector ages_bits(mio::abm::MAX_NUM_AGE_GROUPS, false); - ages_bits[1] = true; + ages_bits[i++] = true; testing_criteria["ages"]["bitset"] = mio::serialize_json(ages_bits).value(); std::vector inf_st_bits((size_t)mio::abm::InfectionState::Count, false); - inf_st_bits[2] = true; + inf_st_bits[i++] = true; testing_criteria["infection_states"]["bitset"] = mio::serialize_json(inf_st_bits).value(); Json::Value test_parameters; - test_parameters["sensitivity"] = mio::serialize_json(mio::UncertainValue{6.0}).value(); - test_parameters["specificity"] = mio::serialize_json(mio::UncertainValue{7.0}).value(); - test_parameters["required_time"]["seconds"] = Json::Int(8); + test_parameters["sensitivity"] = mio::serialize_json(mio::UncertainValue{(double)i++}).value(); + test_parameters["specificity"] = mio::serialize_json(mio::UncertainValue{(double)i++}).value(); + test_parameters["required_time"]["seconds"] = Json::Int(i++); test_parameters["test_type"] = Json::UInt(0); Json::Value reference_json; reference_json["criteria"] = testing_criteria; - reference_json["validity_period"]["seconds"] = Json::Int(3); - reference_json["start_date"]["seconds"] = Json::Int(4); - reference_json["end_date"]["seconds"] = Json::Int(5); + reference_json["validity_period"]["seconds"] = Json::Int(i++); + reference_json["start_date"]["seconds"] = Json::Int(i++); + reference_json["end_date"]["seconds"] = Json::Int(i++); reference_json["test_params"] = test_parameters; - reference_json["probability"] = Json::Value((double)9); + reference_json["probability"] = Json::Value((double)i++); reference_json["is_active"] = Json::Value((bool)0); - test_json_serialization_full(testing_scheme, reference_json); + test_json_serialization(reference_json); } TEST(TestAbmSerialization, TestingStrategy) { - // See test_json_serialization_without_equality for info on this test. + // See test_json_serialization for info on this test. unsigned i = 1; // counter s.t. members have different values @@ -192,23 +165,23 @@ TEST(TestAbmSerialization, TestingStrategy) Json::Value reference_json; reference_json["schemes"][0] = local_strategy; - test_json_serialization_without_equality(reference_json); + test_json_serialization(reference_json); } TEST(TestAbmSerialization, TestResult) { - // See test_json_serialization_without_equality for info on this test. + // See test_json_serialization for info on this test. Json::Value reference_json; reference_json["result"] = Json::Value(false); reference_json["time_of_testing"]["seconds"] = Json::Int(1); - test_json_serialization_without_equality(reference_json); + test_json_serialization(reference_json); } TEST(TestAbmSerialization, Person) { - // See test_json_serialization_without_equality for info on this test. + // See test_json_serialization for info on this test. auto json_uint_array = [](std::vector values) { return mio::serialize_json(values).value(); @@ -245,12 +218,12 @@ TEST(TestAbmSerialization, Person) reference_json["vaccinations"] = Json::Value(Json::arrayValue); reference_json["wears_mask"] = Json::Value(false); - test_json_serialization_without_equality(reference_json); + test_json_serialization(reference_json); } TEST(TestAbmSerialization, Location) { - // See test_json_serialization_without_equality for info on this test. + // See test_json_serialization for info on this test. unsigned i = 1; // counter s.t. members have different values @@ -268,12 +241,12 @@ TEST(TestAbmSerialization, Location) reference_json["parameters"]["UseLocationCapacityForTransmissions"] = Json::Value(false); reference_json["required_mask"] = Json::UInt(0); - test_json_serialization_without_equality(reference_json); + test_json_serialization(reference_json); } TEST(TestAbmSerialization, Model) { - // See test_json_serialization_without_equality for info on this test. + // See test_json_serialization for info on this test. auto json_uint_array = [](std::vector values) { return mio::serialize_json(values).value(); @@ -298,7 +271,7 @@ TEST(TestAbmSerialization, Model) reference_json["trip_list"]["trips_weekend"] = Json::Value(Json::arrayValue); reference_json["use_mobility_rules"] = Json::Value(false); - test_json_serialization_without_equality(reference_json); + test_json_serialization(reference_json); } #endif From 383710249f5aeb61524caf6c279b4b3a810d57aa Mon Sep 17 00:00:00 2001 From: reneSchm <49305466+reneSchm@users.noreply.github.com> Date: Tue, 27 Aug 2024 14:37:44 +0200 Subject: [PATCH 33/42] add threadsafe flag to death tests --- cpp/tests/test_io_cli.cpp | 8 +++++++- cpp/tests/test_math_time_series_functor.cpp | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/cpp/tests/test_io_cli.cpp b/cpp/tests/test_io_cli.cpp index dc376a0f88..5301a82cce 100644 --- a/cpp/tests/test_io_cli.cpp +++ b/cpp/tests/test_io_cli.cpp @@ -151,6 +151,8 @@ using Params = mio::ParameterSet; // using BadParams = mio::ParameterSet; TEST(TestCLI, test_option_verifier) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + EXPECT_DEBUG_DEATH(mio::details::cli::verify_options(mio::ParameterSet()), ".*Options may not have duplicate fields\\. \\(field required\\)"); EXPECT_DEBUG_DEATH(mio::details::cli::verify_options(mio::ParameterSet()), @@ -212,6 +214,8 @@ TEST(TestCLI, test_set_param) TEST(TestCLI, test_write_help) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + std::stringstream ss; const std::string help = "Usage: TestSuite