diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index e6235e3b4f..5b192a4f5e 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -76,8 +76,8 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug" OR CMAKE_BUILD_TYPE STREQUAL "DEBUG") message(STATUS "Coverage enabled") include(CodeCoverage) append_coverage_compiler_flags() - # In addition to standard flags, disable elision and inlining to prevent e.g. closing brackets being marked as uncovered. - + # In addition to standard flags, disable elision and inlining to prevent e.g. closing brackets being marked as + # uncovered. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-elide-constructors -fno-default-inline") setup_target_for_coverage_lcov( NAME coverage diff --git a/cpp/memilio/CMakeLists.txt b/cpp/memilio/CMakeLists.txt index ce05d3948d..e8b58aded5 100644 --- a/cpp/memilio/CMakeLists.txt +++ b/cpp/memilio/CMakeLists.txt @@ -27,6 +27,8 @@ add_library(memilio compartments/simulation.h compartments/flow_simulation.h compartments/parameter_studies.h + io/default_serialize.h + io/default_serialize.cpp io/io.h io/io.cpp io/hdf5_cpp.h @@ -57,6 +59,8 @@ add_library(memilio math/matrix_shape.cpp math/interpolation.h math/interpolation.cpp + math/time_series_functor.h + math/time_series_functor.cpp mobility/metapopulation_mobility_instant.h mobility/metapopulation_mobility_instant.cpp mobility/metapopulation_mobility_stochastic.h diff --git a/cpp/memilio/io/README.md b/cpp/memilio/io/README.md index fb23057493..99055e5de6 100644 --- a/cpp/memilio/io/README.md +++ b/cpp/memilio/io/README.md @@ -4,7 +4,25 @@ This directory contains utilities for reading and writing data from and to files ## The Serialization framework -## Main functions and types +### Using serialization + +In the next sections we will explain how to implement serialization (both for types and formats), here we quickly show +how to use it once it already is implemented for a type. In the following examples, we serialize (write) `Foo` to a +file in Json format, then deserialize (read) the Json again. +```cpp +Foo foo{5}; +mio::IOResult io_result = mio::write_json("path/to/foo.json", foo); +``` +```cpp +mio::IOResult io_result = mio::read_json("path/to/foo.json", mio::Tag{}); +if (io_result) { + Foo foo = io_result.value(); +} +``` +There is also support for a binary format. If you want to use a format directly, use the +`serialize_json`/`deserialize_json` and `serialize_binary`/`deserialize_binary` functions. + +### Main functions and types - functions serialize and deserialize: Main entry points to the framework to write and read values, respectively. The functions expect an IOContext @@ -14,7 +32,38 @@ This directory contains utilities for reading and writing data from and to files - IOStatus and IOResult: Used for error handling, see section "Error Handling" below. -## Concepts +### Default serialization + +Before we get into the details of the framework, this feature provides an easy and convenient alternative to the +serialize and deserialize functions. To give an example: + +```cpp +struct Foo { + int i; + auto default_serialize() { + return Members("Foo").add("i", i); + } +}; +``` +The default serialization is less flexible than the serialize and deserialize functions and has additional +requirements: +- The class must be default constructible. + - If there is a default constructor that is *private*, it can still be used by marking the struct `DefaultFactory` as + a friend. For the example above, the line `friend DefaultFactory;` would be added to the class definition. + - Alternatively, you may provide a specialization of the struct `DefaultFactory`. For more details, + view the struct's documentation. +- Every class member must be added to `Members` exactly once, and the provided names must be unique. + - The members must be passed directly, like in the example. No copies, accessors, etc. + - It is recommended, but not required, to add member variables to `Members` in the same order they are declared in + the class, using the variables' names or something very similar. +- Every class member itself must be serializable, deserializable and assignable. + +As to the feature set, default-serialization only supports the `add_element` and `expect_element` operations defined in +the Concepts section below, where each operation's arguments are provided through the `add` function. Note that the +value provided to `add` is also used to assign a value during deserialization, hence the class members must be used +directly in the function (i.e. as a non-const lvalue reference). + +### Concepts 1. IOContext Stores data that describes serialized objects of any type in some unspecified format and provides structured @@ -66,7 +115,7 @@ for an IOObject `obj`: value or it may be empty. Otherwise returns an error. Note that for some formats a wrong key is indistinguishable from an empty optional, so make sure to provide the correct key. -## Error handling +### Error handling Errors are handled by returning error codes. The type IOStatus contains an error code and an optional string with additional information. The type IOResult contains either a value or an IOStatus that describes an error. Operations that can fail return @@ -78,7 +127,7 @@ inspected, so `expect_...` operations return an IOResult. The `apply` utility fu of multiple `expect_...` operations and use the values if all are succesful. See the documentation of `IOStatus`, `IOResult` and `apply` below for more details. -## Adding a new data type to be serialized +### Adding a new data type to be serialized Serialization of a new type T can be customized by providing _either_ member functions `serialize` and `deserialize` _or_ free functions `serialize_internal` and `deserialize_internal`. diff --git a/cpp/memilio/io/binary_serializer.h b/cpp/memilio/io/binary_serializer.h index 2366ba56d4..79bab26871 100644 --- a/cpp/memilio/io/binary_serializer.h +++ b/cpp/memilio/io/binary_serializer.h @@ -278,7 +278,7 @@ class BinarySerializerContext : public SerializerBase "Unexpected type in stream:" + type_result.value() + ". Expected " + type); } } - return BinarySerializerObject(m_stream, m_status, m_flags); + return obj; } /** diff --git a/cpp/memilio/io/default_serialize.cpp b/cpp/memilio/io/default_serialize.cpp new file mode 100644 index 0000000000..74bfe2acd5 --- /dev/null +++ b/cpp/memilio/io/default_serialize.cpp @@ -0,0 +1,20 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: Rene Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#include "memilio/io/default_serialize.h" diff --git a/cpp/memilio/io/default_serialize.h b/cpp/memilio/io/default_serialize.h new file mode 100644 index 0000000000..5100482551 --- /dev/null +++ b/cpp/memilio/io/default_serialize.h @@ -0,0 +1,254 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: Rene Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef MIO_IO_DEFAULT_SERIALIZE_H_ +#define MIO_IO_DEFAULT_SERIALIZE_H_ + +#include "memilio/io/io.h" +#include "memilio/utils/metaprogramming.h" + +#include +#include +#include + +namespace mio +{ + +/** + * @brief A pair of name and reference. + * + * Used for default (de)serialization. + * This object holds a char pointer to a name and reference to value. Mind their lifetime! + * @tparam ValueType The (non-cv, non-reference) type of the value. + */ +template +struct NamedRef { + using Reference = ValueType&; + + const char* name; + Reference value; + + /** + * @brief Create a named reference. + * + * @param n A string literal. + * @param v A non-const lvalue reference to the value. + */ + explicit NamedRef(const char* n, Reference v) + : name(n) + , value(v) + { + } +}; + +namespace details +{ + +/** + * @brief Helper type to detect whether T has a default_serialize member function. + * Use has_default_serialize. + * @tparam T Any type. + */ +template +using default_serialize_expr_t = decltype(std::declval().default_serialize()); + +/// Add a name-value pair to an io object. +template +void add_named_ref(IOObject& obj, const NamedRef named_ref) +{ + obj.add_element(named_ref.name, named_ref.value); +} + +/// Unpack all name-value pairs from the tuple and add them to a new io object with the given name. +template +void default_serialize_impl(IOContext& io, const char* name, const NamedRef... named_refs) +{ + auto obj = io.create_object(name); + (add_named_ref(obj, named_refs), ...); +} + +/// Retrieve a name-value pair from an io object. +template +IOResult expect_named_ref(IOObject& obj, const NamedRef named_ref) +{ + return obj.expect_element(named_ref.name, Tag{}); +} + +/// Read an io object and its members from the io context using the given names and assign the values to a. +template +IOResult default_deserialize_impl(IOContext& io, DefaultSerializable& a, const char* name, + NamedRef... named_refs) +{ + auto obj = io.expect_object(name); + + // we cannot use expect_named_ref directly in apply, as function arguments have no guarantueed order of evaluation + std::tuple...> results{expect_named_ref(obj, named_refs)...}; + + return apply( + io, + [&a, &named_refs...](const Members&... result_values) { + // if all results are successfully deserialized, they are unpacked into result_values + // then all class variables are overwritten (via the named_refs) with these values + ((named_refs.value = result_values), ...); + return a; + }, + results); +} + +} // namespace details + +/** + * @brief List of a class's members. + * + * Used for default (de)serialization. + * Holds a char pointer to the class name as well as a tuple of NamedRefs with all added class members. + * Initially, the template parameter pack should be left empty. It will be filled by calling Members::add. + * @tparam ValueTypes The (non-cv, non-reference) types of member variables. + */ +template +struct Members { + // allow other Members access to the private constructor + template + friend struct Members; + + /** + * @brief Initialize Members with a class name. Use the member function `add` to specify the class's variables. + * @param[in] class_name Name of a class. + */ + Members(const char* class_name) + : name(class_name) + , named_refs() + { + } + + /** + * @brief Add a class member. + * + * Use this function consecutively for all members, e.g. `Members("class").add("a", a).add("b", b).add...`. + * + * @param[in] member_name The name used for serialization. Should be the same as or similar to the class member. + * For example, a good option a private class member `m_time` is simply `"time"`. + * @param[in] member A class member. Always pass this variable directly, do not use getters or accessors. + * @return A Members object with all previous class members and the newly added one. + */ + template + [[nodiscard]] Members add(const char* member_name, T& member) + { + return Members{name, std::tuple_cat(named_refs, std::tuple(NamedRef{member_name, member}))}; + } + + const char* name; ///< Name of the class. + std::tuple...> named_refs; ///< Names and references to members of the class. + +private: + /** + * @brief Initialize Members directly. Used by the add function. + * @param[in] class_name Name of a class. + * @param[in] named_references Tuple of added class Members. + */ + Members(const char* class_name, std::tuple...> named_references) + : name(class_name) + , named_refs(named_references) + { + } +}; + +/** + * @brief Creates an instance of T for later initialization. + * + * The default implementation uses the default constructor of T, if available. If there is no default constructor, this + * class can be spezialized to provide the method `static T create()`. If there is a default constructor, but it is + * private, DefaultFactory can be marked as friend instead. + * + * The state of the object retured by `create()` is completely arbitrary, and may be invalid. Make sure to set it to a + * valid state before using it further. + * + * @tparam T The type to create. + */ +template +struct DefaultFactory { + /// @brief Creates a new instance of T. + static T create() + { + return T{}; + } +}; + +/** + * @brief Detect whether T has a default_serialize member function. + * @tparam T Any type. + */ +template +using has_default_serialize = is_expression_valid; + +/** + * @brief Serialization implementation for the default serialization feature. + * Disables itself (SFINAE) if there is no default_serialize member or if a serialize member is present. + * Generates the serialize method depending on the NamedRefs given by default_serialize. + * @tparam IOContext A type that models the IOContext concept. + * @tparam DefaultSerializable A type that can be default serialized. + * @param io An IO context. + * @param a An instance of DefaultSerializable to be serialized. + */ +template ::value && + !has_serialize::value, + DefaultSerializable*> = nullptr> +void serialize_internal(IOContext& io, const DefaultSerializable& a) +{ + // Note that the following cons_cast is only safe if we do not modify members. + const auto members = const_cast(a).default_serialize(); + // unpack members and serialize + std::apply( + [&io, &members](auto... named_refs) { + details::default_serialize_impl(io, members.name, named_refs...); + }, + members.named_refs); +} + +/** + * @brief Deserialization implementation for the default serialization feature. + * Disables itself (SFINAE) if there is no default_serialize member or if a deserialize meember is present. + * Generates the deserialize method depending on the NamedRefs given by default_serialize. + * @tparam IOContext A type that models the IOContext concept. + * @tparam DefaultSerializable A type that can be default serialized. + * @param io An IO context. + * @param tag Defines the type of the object that is to be deserialized (i.e. DefaultSerializble). + * @return The restored object if successful, an error otherwise. + */ +template ::value && + !has_deserialize::value, + DefaultSerializable*> = nullptr> +IOResult deserialize_internal(IOContext& io, Tag tag) +{ + mio::unused(tag); + DefaultSerializable a = DefaultFactory::create(); + auto members = a.default_serialize(); + // unpack members and deserialize + return std::apply( + [&io, &members, &a](auto... named_refs) { + return details::default_deserialize_impl(io, a, members.name, named_refs...); + }, + members.named_refs); +} + +} // namespace mio + +#endif // MIO_IO_DEFAULT_SERIALIZE_H_ diff --git a/cpp/memilio/io/io.h b/cpp/memilio/io/io.h index 7e8e9dbf82..fe3e651288 100644 --- a/cpp/memilio/io/io.h +++ b/cpp/memilio/io/io.h @@ -27,6 +27,9 @@ #include "boost/outcome/result.hpp" #include "boost/outcome/try.hpp" #include "boost/optional.hpp" + +#include +#include #include #include @@ -467,6 +470,7 @@ ApplyResultT eval(F f, const IOResult&... rs) * @param f the function that is called with the values contained in `rs` as arguments. * @param rs zero or more IOResults from previous operations. * @return the result of f(rs.value()...) if successful, the first error encountered otherwise. + * @{ */ template details::ApplyResultT apply(IOContext& io, F f, const IOResult&... rs) @@ -482,7 +486,7 @@ details::ApplyResultT apply(IOContext& io, F f, const IOResult&... r IOStatus status[] = {(rs ? IOStatus{} : rs.error())...}; auto iter_err = std::find_if(std::begin(status), std::end(status), [](auto& s) { return s.is_error(); - }); + }); //evaluate f if all succesful auto result = @@ -497,6 +501,17 @@ details::ApplyResultT apply(IOContext& io, F f, const IOResult&... r return result; } +template +details::ApplyResultT apply(IOContext& io, F f, const std::tuple...>& results) +{ + return std::apply( + [&](auto&&... rs) { + return apply(io, f, rs...); + }, + results); +} +/** @} */ + //utility for (de-)serializing tuple-like objects namespace details { @@ -632,6 +647,56 @@ IOResult deserialize_internal(IOContext& io, Tag /*tag*/) rows, cols, elements); } +/** + * @brief Serialize an std::bitset. + * @tparam IOContext A type that models the IOContext concept. + * @tparam N The size of the bitset. + * @param io An IO context. + * @param bitset A bitset to be serialized. + */ +template +void serialize_internal(IOContext& io, const std::bitset bitset) +{ + std::array bits; + for (size_t i = 0; i < N; i++) { + bits[i] = bitset[i]; + } + auto obj = io.create_object("BitSet"); + obj.add_list("bitset", bits.begin(), bits.end()); +} + +/** + * @brief Deserialize an std::bitset. + * @tparam IOContext A type that models the IOContext concept. + * @tparam N The size of the bitset. + * @param io An IO context. + * @param tag Defines the type of the object that is to be deserialized. + * @return The restored object if successful, an error otherwise. + */ +template +IOResult> deserialize_internal(IOContext& io, Tag> tag) +{ + mio::unused(tag); + auto obj = io.expect_object("BitSet"); + auto bits = obj.expect_list("bitset", Tag{}); + + return apply( + io, + [](auto&& bits_) -> IOResult> { + if (bits_.size() != N) { + return failure(StatusCode::InvalidValue, + "Incorrent number of booleans to deserialize bitset. Expected " + std::to_string(N) + + ", got " + std::to_string(bits_.size()) + "."); + } + std::bitset bitset; + for (size_t i = 0; i < N; i++) { + bitset[i] = bits_[i]; + } + return bitset; + }, + bits); +} + /** * serialize an enum value as its underlying type. * @tparam IOContext a type that models the IOContext concept. diff --git a/cpp/memilio/math/interpolation.h b/cpp/memilio/math/interpolation.h index ad08eed373..a000607ab4 100644 --- a/cpp/memilio/math/interpolation.h +++ b/cpp/memilio/math/interpolation.h @@ -17,11 +17,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef INTERPOLATION_H_ -#define INTERPOLATION_H_ +#ifndef MIO_MATH_INTERPOLATION_H_ +#define MIO_MATH_INTERPOLATION_H_ +#include "memilio/utils/logging.h" +#include "memilio/utils/time_series.h" + +#include #include #include -#include "memilio/utils/logging.h" namespace mio { @@ -34,21 +37,52 @@ namespace mio * @param[in] x_2 Right node of interpolation. * @param[in] y_1 Value at left node. * @param[in] y_2 Value at right node. - * @param[out] unnamed Interpolation result. + * @return Interpolation result. */ template auto linear_interpolation(const X& x_eval, const X& x_1, const X& x_2, const V& y1, const V& y2) { - auto weight = (x_eval - x_1) / (x_2 - x_1); + const auto weight = (x_eval - x_1) / (x_2 - x_1); return y1 + weight * (y2 - y1); } +/** + * @brief Linear interpolation of a TimeSeries. + * Assumes that the times in the time series are monotonic increasing. If the times are *strictly* monotonic, + * this function is continuous in time. + * If the given interpolation time is outside of the provided time points, this function assumes a constant value of + * the first/last time point. + * @param[in] time The time at which to evaluate. + * @param[in] data Time points to interpolate. At least one is required. + * @return Interpolation result. + */ +template +typename TimeSeries::Vector linear_interpolation(FP time, const TimeSeries& data) +{ + assert(data.get_num_time_points() > 0 && "Interpolation requires at least one time point."); + const auto tp_range = data.get_times(); + // find next time point in data (strictly) after time + const auto next_tp = std::upper_bound(tp_range.begin(), tp_range.end(), time); + // interpolate in between values if possible, otherwise return first/last value + if (next_tp == tp_range.begin()) { // time is before first data point + return data.get_value(0); + } + else if (next_tp == tp_range.end()) { // time is past or equal to last data point + return data.get_last_value(); + } + else { // time is in between data points + const auto i = next_tp - tp_range.begin(); // index of strict upper bound + return linear_interpolation(time, data.get_time(i - 1), data.get_time(i), data.get_value(i - 1), + data.get_value(i)); + } +} + /** * @brief Linear interpolation between two points of a dataset, which is represented by a vector of pairs of node and value. * Return 0 if there is less than two points in the dataset. * @param[in] vector Vector of pairs of node and value. * @param[in] x_eval Location to evaluate interpolation. - * @param[out] unnamed Interpolation result. + * @return Interpolation result. */ template Y linear_interpolation_of_data_set(std::vector> vector, const X& x_eval) @@ -80,4 +114,4 @@ Y linear_interpolation_of_data_set(std::vector> vector, const X& } // namespace mio -#endif +#endif // MIO_MATH_INTERPOLATION_H_ diff --git a/cpp/memilio/math/time_series_functor.cpp b/cpp/memilio/math/time_series_functor.cpp new file mode 100644 index 0000000000..4f0e16772f --- /dev/null +++ b/cpp/memilio/math/time_series_functor.cpp @@ -0,0 +1,20 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: Rene Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#include "memilio/math/time_series_functor.h" diff --git a/cpp/memilio/math/time_series_functor.h b/cpp/memilio/math/time_series_functor.h new file mode 100644 index 0000000000..3a08c86a15 --- /dev/null +++ b/cpp/memilio/math/time_series_functor.h @@ -0,0 +1,121 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: Rene Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef MIO_MATH_TIME_SERIES_FUNCTOR_H +#define MIO_MATH_TIME_SERIES_FUNCTOR_H + +#include "memilio/io/default_serialize.h" +#include "memilio/math/interpolation.h" +#include "memilio/utils/time_series.h" + +#include +#include + +namespace mio +{ + +/** + * @brief Type of a TimeSeriesFunctor. + * The available types are: + * - LinearInterpolation: + * - Requires at least one time point with exactly one value each. Time must be strictly monotic increasing. + * - Linearly interpolates between data points. Stays constant outside of provided data with first/last value. + */ +enum class TimeSeriesFunctorType +{ + LinearInterpolation, +}; + +template +class TimeSeriesFunctor +{ +public: + /** + * @brief Creates a functor using the given data. + * Note the data requirements of the given functor type. + * @param type The type of the functor. + * @param table A list of time points, passed to the TimeSeries constructor. + */ + TimeSeriesFunctor(TimeSeriesFunctorType type, const TimeSeries& data) + : m_type(type) + , m_data(data) + { + // data shape checks and preprocessing + switch (m_type) { + case TimeSeriesFunctorType::LinearInterpolation: + // make sure data has the correct shape, i.e. a list of (time, value) pairs + assert(m_data.get_num_time_points() > 0 && "Need at least one time point for LinearInterpolation."); + assert(m_data.get_num_elements() == 1 && "LinearInterpolation requires exactly one value per time point."); + assert(m_data.is_strictly_monotonic()); + break; + default: + assert(false && "Unhandled TimeSeriesFunctorType!"); + break; + } + } + + /** + * @brief Creates a functor using the given table. + * Note the data requirements of the given functor type. + * @param type The type of the functor. + * @param table A list of time points, passed to the TimeSeries constructor. + */ + TimeSeriesFunctor(TimeSeriesFunctorType type, std::vector>&& table) + : TimeSeriesFunctor(type, TimeSeries{table}) + { + } + + /** + * @brief Creates a Zero functor. + */ + TimeSeriesFunctor() + : TimeSeriesFunctor(TimeSeriesFunctorType::LinearInterpolation, {{FP(0.0), FP(0.0)}}) + { + } + + /** + * @brief Function returning a scalar value. + * @param time A scalar value. + * @return A scalar value computed from data, depending on the functor's type. + */ + FP operator()(FP time) const + { + switch (m_type) { + case TimeSeriesFunctorType::LinearInterpolation: + return linear_interpolation(time, m_data)[0]; + default: + assert(false && "Unhandled TimeSeriesFunctorType!"); + return FP(); + } + } + + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("TimeSeriesFunctor").add("type", m_type).add("data", m_data); + } + +private: + TimeSeriesFunctorType m_type; ///< Determines what kind of functor this is, e.g. linear interpolation. + TimeSeries m_data; ///< Data used by the functor to compute its values. Its shape depends on type. +}; + +} // namespace mio + +#endif diff --git a/cpp/memilio/utils/random_number_generator.h b/cpp/memilio/utils/random_number_generator.h index 2ab4cadd48..96456dade4 100644 --- a/cpp/memilio/utils/random_number_generator.h +++ b/cpp/memilio/utils/random_number_generator.h @@ -21,6 +21,7 @@ #ifndef MIO_RANDOM_NUMBER_GENERATOR_H #define MIO_RANDOM_NUMBER_GENERATOR_H +#include "memilio/io/default_serialize.h" #include "memilio/utils/compiler_diagnostics.h" #include "memilio/utils/logging.h" #include "memilio/utils/miompi.h" @@ -357,6 +358,12 @@ class RandomNumberGenerator : public RandomNumberGeneratorBase m_key; Counter m_counter; @@ -669,6 +676,34 @@ using UniformIntDistribution = DistributionAdapter using UniformDistribution = DistributionAdapter>; +template ::ParamType>, + void*> = nullptr> +void serialize_internal(IOContext& io, const UniformDistributionParams& p) +{ + auto obj = io.create_object("UniformDistributionParams"); + obj.add_element("a", p.params.a()); + obj.add_element("b", p.params.b()); +} + +template ::ParamType>, + void*> = nullptr> +IOResult deserialize_internal(IOContext& io, Tag) +{ + auto obj = io.expect_object("UniformDistributionParams"); + auto a = obj.expect_element("a", Tag{}); + auto b = obj.expect_element("b", Tag{}); + return apply( + io, + [](auto&& a_, auto&& b_) { + return UniformDistributionParams{a_, b_}; + }, + a, b); +} + /** * adapted poisson_distribution. * @see DistributionAdapter diff --git a/cpp/memilio/utils/time_series.h b/cpp/memilio/utils/time_series.h index 0016e38d71..285da889e1 100644 --- a/cpp/memilio/utils/time_series.h +++ b/cpp/memilio/utils/time_series.h @@ -21,10 +21,10 @@ #define EPI_TIME_SERIES_H #include "memilio/io/io.h" -#include "memilio/math/eigen.h" #include "memilio/utils/stl_util.h" #include "memilio/math/floating_point.h" +#include #include #include #include @@ -102,6 +102,38 @@ class TimeSeries col.tail(expr.rows()) = expr; } + /** + * @brief Initialize a TimeSeries with a table. + * @param table Consists of a list of time points, each of the form (time, value_0, value_1, ..., value_n) for + * some fixed n >= 0. + */ + TimeSeries(std::vector> table) + : m_data() // resized in body + , m_num_time_points(table.size()) + { + // check table sizes + assert(table.size() > 0 && "At least one entry is required to determine the number of elements."); + assert(std::all_of(table.begin(), table.end(), + [&table](auto&& a) { + return a.size() == table.front().size(); + }) && + "All table entries must have the same size."); + // resize data. note that the table entries contain both time and values + m_data.resize(table.front().size(), 0); // set columns first so reserve allocates correctly + reserve(table.size()); // reserve needs to happen before setting the number of rows + m_data.resize(Eigen::NoChange, table.size()); // finalize resize by setting the rows + // sort table by time + std::sort(table.begin(), table.end(), [](auto&& a, auto&& b) { + return a[0] < b[0]; + }); + // assign table to data + for (Eigen::Index tp = 0; tp < m_data.cols(); tp++) { + for (Eigen::Index i = 0; i < m_data.rows(); i++) { + m_data(i, tp) = table[tp][i]; + } + } + } + /** copy ctor */ TimeSeries(const TimeSeries& other) : m_data(other.get_num_elements() + 1, details::next_pow2(other.m_num_time_points)) @@ -148,6 +180,16 @@ class TimeSeries TimeSeries(TimeSeries&& other) = default; TimeSeries& operator=(TimeSeries&& other) = default; + /// Check if the time is strictly monotonic increasing. + bool is_strictly_monotonic() const + { + const auto times = get_times(); + auto time_itr = times.begin(); + return std::all_of(++times.begin(), times.end(), [&](const auto& t) { + return *(time_itr++) < t; + }); + } + /** * number of time points in the series */ diff --git a/cpp/models/abm/infection.h b/cpp/models/abm/infection.h index ca9dae9bdd..64d9ead191 100644 --- a/cpp/models/abm/infection.h +++ b/cpp/models/abm/infection.h @@ -21,6 +21,7 @@ #define MIO_ABM_INFECTION_H #include "abm/personal_rng.h" +#include "memilio/io/default_serialize.h" #include "abm/time.h" #include "abm/infection_state.h" #include "abm/virus_variant.h" @@ -44,11 +45,21 @@ struct ViralLoad { ScalarType peak; ///< Peak amplitude of the ViralLoad. ScalarType incline; ///< Incline of the ViralLoad during incline phase in log_10 scale per day (always positive). ScalarType decline; ///< Decline of the ViralLoad during decline phase in log_10 scale per day (always negative). + + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("ViralLoad") + .add("start_date", start_date) + .add("end_date", end_date) + .add("peak", peak) + .add("incline", incline) + .add("decline", decline); + } }; class Infection { - public: /** * @brief Create an Infection for a single Person. @@ -114,7 +125,22 @@ class Infection */ TimePoint get_start_date() const; + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("Infection") + .add("infection_course", m_infection_course) + .add("virus_variant", m_virus_variant) + .add("viral_load", m_viral_load) + .add("log_norm_alpha", m_log_norm_alpha) + .add("log_norm_beta", m_log_norm_beta) + .add("detected", m_detected); + } + private: + friend DefaultFactory; + Infection() = default; + /** * @brief Determine ViralLoad course and Infection course based on init_state. * Calls draw_infection_course_backward for all #InfectionState%s prior and draw_infection_course_forward for all diff --git a/cpp/models/abm/location.h b/cpp/models/abm/location.h index 921dbd148f..4475504a93 100644 --- a/cpp/models/abm/location.h +++ b/cpp/models/abm/location.h @@ -25,6 +25,7 @@ #include "abm/parameters.h" #include "abm/location_type.h" +#include "memilio/io/default_serialize.h" #include "boost/atomic/atomic.hpp" namespace mio @@ -48,6 +49,12 @@ struct GeographicalLocation { { return !(latitude == other.latitude && longitude == other.longitude); } + + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("GraphicalLocation").add("latitude", latitude).add("longitude", longitude); + } }; struct CellIndex : public mio::Index { @@ -73,6 +80,12 @@ struct CellCapacity { } uint32_t volume; ///< Volume of the Cell. uint32_t persons; ///< Maximal number of Person%s at the Cell. + + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("CellCapacity").add("volume", volume).add("persons", persons); + } }; /** @@ -87,6 +100,12 @@ struct Cell { * @return The relative cell size for the Cell. */ ScalarType compute_space_per_person_relative() const; + + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("Cell").add("capacity", m_capacity); + } }; // namespace mio /** @@ -222,36 +241,6 @@ class Location return m_required_mask != MaskType::None; } - /** - * serialize this. - * @see mio::serialize - */ - template - void serialize(IOContext& io) const - { - auto obj = io.create_object("Location"); - obj.add_element("index", m_id); - obj.add_element("type", m_type); - } - - /** - * deserialize an object of this class. - * @see mio::deserialize - */ - template - static IOResult deserialize(IOContext& io) - { - auto obj = io.expect_object("Location"); - auto index = obj.expect_element("index", Tag{}); - auto type = obj.expect_element("type", Tag{}); - return apply( - io, - [](auto&& index_, auto&& type_) { - return Location{type_, index_}; - }, - index, type); - } - /** * @brief Get the geographical location of the Location. * @return The geographical location of the Location. @@ -270,7 +259,22 @@ class Location m_geographical_location = location; } + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("Location") + .add("type", m_type) + .add("id", m_id) + .add("parameters", m_parameters) + .add("cells", m_cells) + .add("required_mask", m_required_mask) + .add("geographical_location", m_geographical_location); + } + private: + friend DefaultFactory; + Location() = default; + LocationType m_type; ///< Type of the Location. LocationId m_id; ///< Unique identifier for the Location in the Model owning it. LocalInfectionParameters m_parameters; ///< Infection parameters for the Location. diff --git a/cpp/models/abm/mask.h b/cpp/models/abm/mask.h index fb05bba051..65d5d0b0e5 100644 --- a/cpp/models/abm/mask.h +++ b/cpp/models/abm/mask.h @@ -23,6 +23,7 @@ #include "abm/mask_type.h" #include "abm/time.h" +#include "memilio/io/default_serialize.h" namespace mio { @@ -63,11 +64,27 @@ class Mask */ void change_mask(MaskType new_mask_type, TimePoint t); + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("Mask").add("mask_type", m_type).add("time_first_used", m_time_first_usage); + } + private: MaskType m_type; ///< Type of the Mask. TimePoint m_time_first_usage; ///< TimePoint of the Mask's initial usage. }; } // namespace abm + +/// @brief Creates an instance of abm::Mask for default serialization. +template <> +struct DefaultFactory { + static abm::Mask create() + { + return abm::Mask(abm::MaskType::Count, abm::TimePoint()); + } +}; + } // namespace mio #endif diff --git a/cpp/models/abm/model.h b/cpp/models/abm/model.h index ad3ff322b6..ea9f49c90c 100644 --- a/cpp/models/abm/model.h +++ b/cpp/models/abm/model.h @@ -112,22 +112,16 @@ class Model void serialize(IOContext& io) const { auto obj = io.create_object("Model"); - obj.add_element("num_agegroups", parameters.get_num_groups()); - std::vector trips; - TripList trip_list = get_trip_list(); - for (size_t i = 0; i < trip_list.num_trips(false); i++) { - trips.push_back(trip_list.get_next_trip(false)); - trip_list.increase_index(); - } - trip_list.reset_index(); - for (size_t i = 0; i < trip_list.num_trips(true); i++) { - trips.push_back(trip_list.get_next_trip(true)); - trip_list.increase_index(); - } - obj.add_list("trips", trips.begin(), trips.end()); - obj.add_list("locations", get_locations().begin(), get_locations().end()); + obj.add_element("parameters", parameters); + // skip caches, they are rebuild by the deserialized model obj.add_list("persons", get_persons().begin(), get_persons().end()); + obj.add_list("locations", get_locations().begin(), get_locations().end()); + obj.add_element("location_types", m_has_locations.to_ulong()); + obj.add_element("testing_strategy", m_testing_strategy); + obj.add_element("trip_list", m_trip_list); obj.add_element("use_mobility_rules", m_use_mobility_rules); + obj.add_element("cemetery_id", m_cemetery_id); + obj.add_element("rng", m_rng); } /** @@ -137,18 +131,30 @@ class Model template static IOResult deserialize(IOContext& io) { - auto obj = io.expect_object("Model"); - auto size = obj.expect_element("num_agegroups", Tag{}); - auto locations = obj.expect_list("locations", Tag{}); - auto trip_list = obj.expect_list("trips", Tag{}); - auto persons = obj.expect_list("persons", Tag{}); + auto obj = io.expect_object("Model"); + auto params = obj.expect_element("parameters", Tag{}); + auto persons = obj.expect_list("persons", Tag{}); + auto locations = obj.expect_list("locations", Tag{}); + auto location_types = obj.expect_element("location_types", Tag{}); + auto trip_list = obj.expect_element("trip_list", Tag{}); auto use_mobility_rules = obj.expect_element("use_mobility_rules", Tag{}); + auto cemetery_id = obj.expect_element("cemetery_id", Tag{}); + auto rng = obj.expect_element("rng", Tag{}); return apply( io, - [](auto&& size_, auto&& locations_, auto&& trip_list_, auto&& persons_, auto&& use_mobility_rule_) { - return Model{size_, locations_, trip_list_, persons_, use_mobility_rule_}; + [](auto&& params_, auto&& persons_, auto&& locations_, auto&& location_types_, auto&& trip_list_, + auto&& use_mobility_rules_, auto&& cemetery_id_, auto&& rng_) { + Model model{params_}; + model.m_persons.assign(persons_.cbegin(), persons_.cend()); + model.m_locations.assign(locations_.cbegin(), locations_.cend()); + model.m_has_locations = location_types_; + model.m_trip_list = trip_list_; + model.m_use_mobility_rules = use_mobility_rules_; + model.m_cemetery_id = cemetery_id_; + model.m_rng = rng_; + return model; }, - size, locations, trip_list, persons, use_mobility_rules); + params, persons, locations, location_types, trip_list, use_mobility_rules, cemetery_id, rng); } /** @@ -377,8 +383,9 @@ class Model inline void change_location(PersonId person, LocationId destination, TransportMode mode = TransportMode::Unknown, const std::vector& cells = {0}) { - LocationId origin = get_location(person).get_id(); - const bool has_changed_location = mio::abm::change_location(get_person(person), get_location(destination), mode, cells); + LocationId origin = get_location(person).get_id(); + const bool has_changed_location = + mio::abm::change_location(get_person(person), get_location(destination), mode, cells); // if the person has changed location, invalidate exposure caches but keep population caches valid if (has_changed_location) { m_are_exposure_caches_valid = false; diff --git a/cpp/models/abm/parameters.h b/cpp/models/abm/parameters.h index 524de6fd07..a819b55e17 100644 --- a/cpp/models/abm/parameters.h +++ b/cpp/models/abm/parameters.h @@ -25,13 +25,20 @@ #include "abm/virus_variant.h" #include "abm/vaccine.h" #include "abm/test_type.h" +#include "memilio/config.h" +#include "memilio/io/default_serialize.h" +#include "memilio/io/io.h" +#include "memilio/math/time_series_functor.h" #include "memilio/utils/custom_index_array.h" #include "memilio/utils/uncertain_value.h" #include "memilio/utils/parameter_set.h" #include "memilio/epidemiology/age_group.h" #include "memilio/epidemiology/damping.h" #include "memilio/epidemiology/contact_matrix.h" + +#include #include +#include namespace mio { @@ -169,6 +176,15 @@ struct ViralLoadDistributionsParameters { UniformDistribution::ParamType viral_load_peak; UniformDistribution::ParamType viral_load_incline; UniformDistribution::ParamType viral_load_decline; + + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("ViralLoadDistributionsParameters") + .add("viral_load_peak", viral_load_peak) + .add("viral_load_incline", viral_load_incline) + .add("viral_load_decline", viral_load_decline); + } }; struct ViralLoadDistributions { @@ -192,6 +208,14 @@ struct ViralLoadDistributions { struct InfectivityDistributionsParameters { UniformDistribution::ParamType infectivity_alpha; UniformDistribution::ParamType infectivity_beta; + + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("InfectivityDistributionsParameters") + .add("infectivity_alpha", infectivity_alpha) + .add("infectivity_beta", infectivity_beta); + } }; struct InfectivityDistributions { @@ -257,19 +281,15 @@ struct AerosolTransmissionRates { } }; -using InputFunctionForProtectionLevel = std::function; - /** * @brief Personal protection factor against #Infection% after #Infection and #Vaccination, which depends on #ExposureType, * #AgeGroup and #VirusVariant. Its value is between 0 and 1. */ struct InfectionProtectionFactor { - using Type = CustomIndexArray; + using Type = CustomIndexArray, ExposureType, AgeGroup, VirusVariant>; static auto get_default(AgeGroup size) { - return Type({ExposureType::Count, size, VirusVariant::Count}, [](ScalarType /*days*/) -> ScalarType { - return 0; - }); + return Type({ExposureType::Count, size, VirusVariant::Count}, TimeSeriesFunctor()); } static std::string name() { @@ -282,12 +302,10 @@ struct InfectionProtectionFactor { * #AgeGroup and #VirusVariant. Its value is between 0 and 1. */ struct SeverityProtectionFactor { - using Type = CustomIndexArray; + using Type = CustomIndexArray, ExposureType, AgeGroup, VirusVariant>; static auto get_default(AgeGroup size) { - return Type({ExposureType::Count, size, VirusVariant::Count}, [](ScalarType /*days*/) -> ScalarType { - return 0; - }); + return Type({ExposureType::Count, size, VirusVariant::Count}, TimeSeriesFunctor()); } static std::string name() { @@ -299,12 +317,10 @@ struct SeverityProtectionFactor { * @brief Personal protective factor against high viral load. Its value is between 0 and 1. */ struct HighViralLoadProtectionFactor { - using Type = InputFunctionForProtectionLevel; + using Type = TimeSeriesFunctor; static auto get_default() { - return Type([](ScalarType /*days*/) -> ScalarType { - return 0; - }); + return Type(); } static std::string name() { @@ -321,34 +337,14 @@ struct TestParameters { TimeSpan required_time; TestType type; - /** - * serialize this. - * @see mio::serialize - */ - template - void serialize(IOContext& io) const - { - auto obj = io.create_object("TestParameters"); - obj.add_element("Sensitivity", sensitivity); - obj.add_element("Specificity", specificity); - } - - /** - * deserialize an object of this class. - * @see mio::deserialize - */ - template - static IOResult deserialize(IOContext& io) + /// This method is used by the default serialization feature. + auto default_serialize() { - auto obj = io.expect_object("TestParameters"); - auto sens = obj.expect_element("Sensitivity", mio::Tag>{}); - auto spec = obj.expect_element("Specificity", mio::Tag>{}); - return apply( - io, - [](auto&& sens_, auto&& spec_) { - return TestParameters{sens_, spec_}; - }, - sens, spec); + return Members("TestParameters") + .add("sensitivity", sensitivity) + .add("specificity", specificity) + .add("required_time", required_time) + .add("test_type", type); } }; @@ -622,6 +618,14 @@ class Parameters : public ParametersBase { } +private: + Parameters(ParametersBase&& base) + : ParametersBase(std::move(base)) + , m_num_groups(this->get().size().get()) + { + } + +public: /** * @brief Get the number of the age groups. */ @@ -772,6 +776,17 @@ class Parameters : public ParametersBase return false; } + /** + * deserialize an object of this class. + * @see epi::deserialize + */ + template + static IOResult deserialize(IOContext& io) + { + BOOST_OUTCOME_TRY(auto&& base, ParametersBase::deserialize(io)); + return success(Parameters(std::move(base))); + } + private: size_t m_num_groups; }; diff --git a/cpp/models/abm/person.h b/cpp/models/abm/person.h index b952361fe0..fefdf72d0e 100755 --- a/cpp/models/abm/person.h +++ b/cpp/models/abm/person.h @@ -28,6 +28,7 @@ #include "abm/parameters.h" #include "abm/person_id.h" #include "abm/personal_rng.h" +#include "memilio/io/default_serialize.h" #include "abm/time.h" #include "abm/test_type.h" #include "abm/vaccine.h" @@ -378,36 +379,29 @@ class Person */ std::pair get_latest_protection() const; - /** - * serialize this. - * @see mio::serialize - */ - template - void serialize(IOContext& io) const - { - auto obj = io.create_object("Person"); - obj.add_element("Location", m_location); - obj.add_element("age", m_age); - obj.add_element("id", m_person_id); - } - - /** - * deserialize an object of this class. - * @see mio::deserialize - */ - template - static IOResult deserialize(IOContext& io) + /// This method is used by the default serialization feature. + auto default_serialize() { - auto obj = io.expect_object("Person"); - auto loc = obj.expect_element("Location", mio::Tag{}); - auto age = obj.expect_element("age", Tag{}); - auto id = obj.expect_element("id", Tag{}); - return apply( - io, - [](auto&& loc_, auto&& age_, auto&& id_) { - return Person{mio::RandomNumberGenerator(), loc_, AgeGroup(age_), id_}; - }, - loc, age, id); + return Members("Person") + .add("location", m_location) + .add("location_type", m_location_type) + .add("assigned_locations", m_assigned_locations) + .add("vaccinations", m_vaccinations) + .add("infections", m_infections) + .add("home_isolation_start", m_home_isolation_start) + .add("age_group", m_age) + .add("time_at_location", m_time_at_location) + .add("rnd_workgroup", m_random_workgroup) + .add("rnd_schoolgroup", m_random_schoolgroup) + .add("rnd_go_to_work_hour", m_random_goto_work_hour) + .add("rnd_go_to_school_hour", m_random_goto_school_hour) + .add("mask", m_mask) + .add("compliance", m_compliance) + .add("id", m_person_id) + .add("cells", m_cells) + .add("last_transport_mode", m_last_transport_mode) + .add("rng_counter", m_rng_counter) + .add("test_results", m_test_results); } /** @@ -451,6 +445,17 @@ class Person }; } // namespace abm + +/// @brief Creates an instance of abm::Person for default serialization. +template <> +struct DefaultFactory { + static abm::Person create() + { + return abm::Person(thread_local_rng(), abm::LocationType::Count, abm::LocationId(), AgeGroup(0), + abm::PersonId()); + } +}; + } // namespace mio #endif diff --git a/cpp/models/abm/test_type.h b/cpp/models/abm/test_type.h index ff9bb7cf09..70e1e6f65e 100644 --- a/cpp/models/abm/test_type.h +++ b/cpp/models/abm/test_type.h @@ -17,12 +17,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #ifndef MIO_ABM_TEST_TYPE_H #define MIO_ABM_TEST_TYPE_H +#include "abm/time.h" +#include "memilio/io/default_serialize.h" + #include -#include + namespace mio { namespace abm @@ -46,6 +48,12 @@ enum class TestType : std::uint32_t struct TestResult { TimePoint time_of_testing{std::numeric_limits::min()}; ///< The TimePoint when the Person performs the test. bool result{false}; ///< The test result. + + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("TestResult").add("time_of_testing", time_of_testing).add("result", result); + } }; } // namespace abm diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index 00578e7379..359bb00ce4 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -27,6 +27,8 @@ #include "abm/person.h" #include "abm/location.h" #include "abm/time.h" +#include "memilio/io/default_serialize.h" + #include #include @@ -92,6 +94,11 @@ class TestingCriteria */ bool evaluate(const Person& p, TimePoint t) const; + auto default_serialize() + { + return Members("TestingCriteria").add("ages", m_ages).add("infection_states", m_infection_states); + } + private: std::bitset m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested. std::bitset<(size_t)InfectionState::Count> @@ -143,7 +150,23 @@ class TestingScheme */ bool run_scheme(PersonalRandomNumberGenerator& rng, Person& person, TimePoint t) const; + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("TestingScheme") + .add("criteria", m_testing_criteria) + .add("validity_period", m_validity_period) + .add("start_date", m_start_date) + .add("end_date", m_end_date) + .add("test_params", m_test_parameters) + .add("probability", m_probability) + .add("is_active", m_is_active); + } + private: + friend DefaultFactory; + TestingScheme() = default; + TestingCriteria m_testing_criteria; ///< TestingCriteria of the scheme. TimeSpan m_validity_period; ///< The valid TimeSpan of the test. TimePoint m_start_date; ///< Starting date of the scheme. @@ -167,6 +190,12 @@ class TestingStrategy LocationType type; LocationId id; std::vector schemes; + + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("LocalStrategy").add("type", type).add("id", id).add("schemes", schemes); + } }; /** @@ -232,6 +261,12 @@ class TestingStrategy */ bool run_strategy(PersonalRandomNumberGenerator& rng, Person& person, const Location& location, TimePoint t); + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("TestingStrategy").add("schemes", m_location_to_schemes_map); + } + private: std::vector m_location_to_schemes_map; ///< Set of schemes that are checked for testing. }; diff --git a/cpp/models/abm/time.h b/cpp/models/abm/time.h index 45ad5495d4..f2d484512d 100644 --- a/cpp/models/abm/time.h +++ b/cpp/models/abm/time.h @@ -20,6 +20,8 @@ #ifndef MIO_ABM_TIME_H #define MIO_ABM_TIME_H +#include "memilio/io/default_serialize.h" + namespace mio { namespace abm @@ -143,6 +145,12 @@ class TimeSpan } /**@}*/ + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("TimeSpan").add("seconds", m_seconds); + } + private: int m_seconds; ///< The duration of time in seconds. }; @@ -284,6 +292,12 @@ class TimePoint return TimeSpan{m_seconds - p2.seconds()}; } + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("TimePoint").add("seconds", m_seconds); + } + private: int m_seconds; ///< The number of seconds after the epoch. }; diff --git a/cpp/models/abm/trip_list.h b/cpp/models/abm/trip_list.h index f20db0df29..b96861ffef 100644 --- a/cpp/models/abm/trip_list.h +++ b/cpp/models/abm/trip_list.h @@ -24,6 +24,7 @@ #include "abm/mobility_data.h" #include "abm/person_id.h" #include "abm/time.h" +#include "memilio/io/default_serialize.h" #include namespace mio @@ -89,38 +90,13 @@ struct Trip { (origin == other.origin); } - /** - * serialize this. - * @see mio::serialize - */ - template - void serialize(IOContext& io) const - { - auto obj = io.create_object("Trip"); - obj.add_element("person_id", person_id); - obj.add_element("time", time.seconds()); - obj.add_element("destination", destination); - obj.add_element("origin", origin); - } - - /** - * deserialize an object of this class. - * @see mio::deserialize - */ - template - static IOResult deserialize(IOContext& io) + auto default_serialize() { - auto obj = io.expect_object("Trip"); - auto person_id = obj.expect_element("person_id", Tag{}); - auto time = obj.expect_element("time", Tag{}); - auto destination_id = obj.expect_element("destination", Tag{}); - auto origin_id = obj.expect_element("origin", Tag{}); - return apply( - io, - [](auto&& person_id_, auto&& time_, auto&& destination_id_, auto&& origin_id_) { - return Trip(person_id_, TimePoint(time_), destination_id_, origin_id_); - }, - person_id, time, destination_id, origin_id); + return Members("Trip") + .add("person_id", person_id) + .add("time", time) + .add("destination", destination) + .add("origin", origin); } }; @@ -192,6 +168,15 @@ class TripList return m_current_index; } + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("TestingScheme") + .add("trips_weekday", m_trips_weekday) + .add("trips_weekend", m_trips_weekend) + .add("index", m_current_index); + } + private: std::vector m_trips_weekday; ///< The list of Trip%s a Person makes on a weekday. std::vector m_trips_weekend; ///< The list of Trip%s a Person makes on a weekend day. @@ -199,6 +184,16 @@ class TripList }; } // namespace abm + +/// @brief Creates an instance of abm::Trip for default serialization. +template <> +struct DefaultFactory { + static abm::Trip create() + { + return abm::Trip{abm::PersonId{}, abm::TimePoint{}, abm::LocationId{}}; + } +}; + } // namespace mio #endif diff --git a/cpp/models/abm/vaccine.h b/cpp/models/abm/vaccine.h index 0277415bc3..d613409067 100644 --- a/cpp/models/abm/vaccine.h +++ b/cpp/models/abm/vaccine.h @@ -20,6 +20,7 @@ #ifndef MIO_ABM_VACCINE_H #define MIO_ABM_VACCINE_H +#include "memilio/io/default_serialize.h" #include "abm/time.h" #include @@ -44,7 +45,7 @@ enum class ExposureType : std::uint32_t /** * @brief A tuple of #TimePoint and #ExposureType (i.e. type of the Vaccine). * The #TimePoint describes the time of administration of the Vaccine. -*/ + */ struct Vaccination { Vaccination(ExposureType exposure, TimePoint t) : exposure_type(exposure) @@ -52,11 +53,27 @@ struct Vaccination { { } + /// This method is used by the default serialization feature. + auto default_serialize() + { + return Members("Vaccination").add("exposure_type", exposure_type).add("time", time); + } + ExposureType exposure_type; TimePoint time; }; } // namespace abm + +/// @brief Creates an instance of abm::Vaccination for default serialization. +template <> +struct DefaultFactory { + static abm::Vaccination create() + { + return abm::Vaccination(abm::ExposureType::Count, abm::TimePoint()); + } +}; + } // namespace mio #endif diff --git a/cpp/simulations/abm.cpp b/cpp/simulations/abm.cpp index 0cb6f79673..cc3d1cef5c 100644 --- a/cpp/simulations/abm.cpp +++ b/cpp/simulations/abm.cpp @@ -21,8 +21,8 @@ #include "abm/common_abm_loggers.h" #include "abm/household.h" #include "abm/lockdown_rules.h" +#include "memilio/config.h" #include "memilio/io/result_io.h" -#include "memilio/math/interpolation.h" #include "memilio/utils/random_number_generator.h" #include "memilio/utils/uncertain_value.h" @@ -469,10 +469,9 @@ void set_parameters(mio::abm::Parameters params) params.set({{mio::abm::VirusVariant::Count, mio::AgeGroup(num_age_groups)}, 4.}); // Set protection level from high viral load. Information based on: https://doi.org/10.1093/cid/ciaa886 - params.get() = [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.863}, {1, 0.969}, {7, 0.029}, {10, 0.002}, {14, 0.0014}, {21, 0}}, days); - }; + params.get() = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.863}, {1, 0.969}, {7, 0.029}, {10, 0.002}, {14, 0.0014}, {21, 0}}}; //0-4 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.276; diff --git a/cpp/simulations/abm_braunschweig.cpp b/cpp/simulations/abm_braunschweig.cpp index 8ced5e0f66..5aa486bf2f 100644 --- a/cpp/simulations/abm_braunschweig.cpp +++ b/cpp/simulations/abm_braunschweig.cpp @@ -20,13 +20,13 @@ #include "abm/common_abm_loggers.h" #include "abm/location_id.h" #include "abm/lockdown_rules.h" +#include "abm/parameters.h" #include "abm/person.h" #include "abm/simulation.h" #include "abm/model.h" #include "memilio/epidemiology/age_group.h" #include "memilio/io/io.h" #include "memilio/io/result_io.h" -#include "memilio/math/interpolation.h" #include "memilio/utils/uncertain_value.h" #include "boost/algorithm/string/split.hpp" #include "boost/algorithm/string/classification.hpp" @@ -400,10 +400,9 @@ void set_parameters(mio::abm::Parameters params) params.set({{mio::abm::VirusVariant::Count, mio::AgeGroup(num_age_groups)}, 4.}); // Set protection level from high viral load. Information based on: https://doi.org/10.1093/cid/ciaa886 - params.get() = [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.863}, {1, 0.969}, {7, 0.029}, {10, 0.002}, {14, 0.0014}, {21, 0}}, days); - }; + params.get() = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.863}, {1, 0.969}, {7, 0.029}, {10, 0.002}, {14, 0.0014}, {21, 0}}}; //0-4 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.276; @@ -493,58 +492,51 @@ void set_parameters(mio::abm::Parameters params) // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_0_to_4, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.852}, - {180, 0.852}, - {210, 0.845}, - {240, 0.828}, - {270, 0.797}, - {300, 0.759}, - {330, 0.711}, - {360, 0.661}, - {390, 0.616}, - {420, 0.580}, - {450, 0.559}, - {450, 0.550}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.852}, + {180, 0.852}, + {210, 0.845}, + {240, 0.828}, + {270, 0.797}, + {300, 0.759}, + {330, 0.711}, + {360, 0.661}, + {390, 0.616}, + {420, 0.580}, + {450, 0.559}, + {450, 0.550}}}; + // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_0_to_4, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.91}, {60, 0.92}, {90, 0.88}, {120, 0.84}, {150, 0.81}, {180, 0.88}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.5}, {30, 0.91}, {60, 0.92}, {90, 0.88}, {120, 0.84}, {150, 0.81}, {180, 0.88}, {450, 0.5}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_0_to_4, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.967}, - {30, 0.975}, - {60, 0.977}, - {90, 0.974}, - {120, 0.963}, - {150, 0.947}, - {180, 0.93}, - {210, 0.929}, - {240, 0.923}, - {270, 0.908}, - {300, 0.893}, - {330, 0.887}, - {360, 0.887}, - {450, 0.5}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.967}, + {30, 0.975}, + {60, 0.977}, + {90, 0.974}, + {120, 0.963}, + {150, 0.947}, + {180, 0.93}, + {210, 0.929}, + {240, 0.923}, + {270, 0.908}, + {300, 0.893}, + {330, 0.887}, + {360, 0.887}, + {450, 0.5}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_0_to_4, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}}; //5-14 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.161; @@ -560,57 +552,50 @@ void set_parameters(mio::abm::Parameters params) // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_5_to_14, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.852}, - {180, 0.852}, - {210, 0.845}, - {240, 0.828}, - {270, 0.797}, - {300, 0.759}, - {330, 0.711}, - {360, 0.661}, - {390, 0.616}, - {420, 0.580}, - {450, 0.559}, - {450, 0.550}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.852}, + {180, 0.852}, + {210, 0.845}, + {240, 0.828}, + {270, 0.797}, + {300, 0.759}, + {330, 0.711}, + {360, 0.661}, + {390, 0.616}, + {420, 0.580}, + {450, 0.559}, + {450, 0.550}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_5_to_14, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.91}, {60, 0.92}, {90, 0.88}, {120, 0.84}, {150, 0.81}, {180, 0.88}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.5}, {30, 0.91}, {60, 0.92}, {90, 0.88}, {120, 0.84}, {150, 0.81}, {180, 0.88}, {450, 0.5}}}; + // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_5_to_14, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.967}, - {30, 0.975}, - {60, 0.977}, - {90, 0.974}, - {120, 0.963}, - {150, 0.947}, - {180, 0.93}, - {210, 0.929}, - {240, 0.923}, - {270, 0.908}, - {300, 0.893}, - {330, 0.887}, - {360, 0.887}, - {450, 0.5}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.967}, + {30, 0.975}, + {60, 0.977}, + {90, 0.974}, + {120, 0.963}, + {150, 0.947}, + {180, 0.93}, + {210, 0.929}, + {240, 0.923}, + {270, 0.908}, + {300, 0.893}, + {330, 0.887}, + {360, 0.887}, + {450, 0.5}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_5_to_14, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}}; //15-34 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = @@ -626,57 +611,49 @@ void set_parameters(mio::abm::Parameters params) params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.0; // Set up personal infection and vaccine protection levels, based on: https://doi.org/10.1038/s41577-021-00550-x, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_15_to_34, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.852}, - {180, 0.852}, - {210, 0.845}, - {240, 0.828}, - {270, 0.797}, - {300, 0.759}, - {330, 0.711}, - {360, 0.661}, - {390, 0.616}, - {420, 0.580}, - {450, 0.559}, - {450, 0.550}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.852}, + {180, 0.852}, + {210, 0.845}, + {240, 0.828}, + {270, 0.797}, + {300, 0.759}, + {330, 0.711}, + {360, 0.661}, + {390, 0.616}, + {420, 0.580}, + {450, 0.559}, + {450, 0.550}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_15_to_34, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.89}, {60, 0.84}, {90, 0.78}, {120, 0.68}, {150, 0.57}, {180, 0.39}, {450, 0.1}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.5}, {30, 0.89}, {60, 0.84}, {90, 0.78}, {120, 0.68}, {150, 0.57}, {180, 0.39}, {450, 0.1}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_15_to_34, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.967}, - {30, 0.975}, - {60, 0.977}, - {90, 0.974}, - {120, 0.963}, - {150, 0.947}, - {180, 0.93}, - {210, 0.929}, - {240, 0.923}, - {270, 0.908}, - {300, 0.893}, - {330, 0.887}, - {360, 0.887}, - {450, 0.5}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.967}, + {30, 0.975}, + {60, 0.977}, + {90, 0.974}, + {120, 0.963}, + {150, 0.947}, + {180, 0.93}, + {210, 0.929}, + {240, 0.923}, + {270, 0.908}, + {300, 0.893}, + {330, 0.887}, + {360, 0.887}, + {450, 0.5}}}; // Information is from: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_15_to_34, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}}; //35-59 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = @@ -693,58 +670,49 @@ void set_parameters(mio::abm::Parameters params) // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_35_to_59, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.852}, - {180, 0.852}, - {210, 0.845}, - {240, 0.828}, - {270, 0.797}, - {300, 0.759}, - {330, 0.711}, - {360, 0.661}, - {390, 0.616}, - {420, 0.580}, - {450, 0.559}, - {450, 0.550}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.852}, + {180, 0.852}, + {210, 0.845}, + {240, 0.828}, + {270, 0.797}, + {300, 0.759}, + {330, 0.711}, + {360, 0.661}, + {390, 0.616}, + {420, 0.580}, + {450, 0.559}, + {450, 0.550}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_35_to_59, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.89}, {60, 0.84}, {90, 0.78}, {120, 0.68}, {150, 0.57}, {180, 0.39}, {450, 0.1}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.5}, {30, 0.89}, {60, 0.84}, {90, 0.78}, {120, 0.68}, {150, 0.57}, {180, 0.39}, {450, 0.1}}}; // Set up age-related severe protection levels, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_35_to_59, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.967}, - {30, 0.975}, - {60, 0.977}, - {90, 0.974}, - {120, 0.963}, - {150, 0.947}, - {180, 0.93}, - {210, 0.929}, - {240, 0.923}, - {270, 0.908}, - {300, 0.893}, - {330, 0.887}, - {360, 0.887}, - {450, 0.5}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.967}, + {30, 0.975}, + {60, 0.977}, + {90, 0.974}, + {120, 0.963}, + {150, 0.947}, + {180, 0.93}, + {210, 0.929}, + {240, 0.923}, + {270, 0.908}, + {300, 0.893}, + {330, 0.887}, + {360, 0.887}, + {450, 0.5}}}; // Information is from: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_35_to_59, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}, days); - }; - + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}}; //60-79 params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.179; @@ -760,57 +728,49 @@ void set_parameters(mio::abm::Parameters params) // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_60_to_79, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.852}, - {180, 0.852}, - {210, 0.845}, - {240, 0.828}, - {270, 0.797}, - {300, 0.759}, - {330, 0.711}, - {360, 0.661}, - {390, 0.616}, - {420, 0.580}, - {450, 0.559}, - {450, 0.550}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.852}, + {180, 0.852}, + {210, 0.845}, + {240, 0.828}, + {270, 0.797}, + {300, 0.759}, + {330, 0.711}, + {360, 0.661}, + {390, 0.616}, + {420, 0.580}, + {450, 0.559}, + {450, 0.550}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_60_to_79, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.87}, {60, 0.85}, {90, 0.78}, {120, 0.67}, {150, 0.61}, {180, 0.50}, {450, 0.1}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.5}, {30, 0.87}, {60, 0.85}, {90, 0.78}, {120, 0.67}, {150, 0.61}, {180, 0.50}, {450, 0.1}}}; // Set up personal severe protection levels. // Protection of severe infection of age group 65 + is different from other age group, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_60_to_79, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.967}, - {30, 0.975}, - {60, 0.977}, - {90, 0.974}, - {120, 0.963}, - {150, 0.947}, - {180, 0.93}, - {210, 0.929}, - {240, 0.923}, - {270, 0.908}, - {300, 0.893}, - {330, 0.887}, - {360, 0.887}, - {360, 0.5}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.967}, + {30, 0.975}, + {60, 0.977}, + {90, 0.974}, + {120, 0.963}, + {150, 0.947}, + {180, 0.93}, + {210, 0.929}, + {240, 0.923}, + {270, 0.908}, + {300, 0.893}, + {330, 0.887}, + {360, 0.887}, + {360, 0.5}}}; params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_60_to_79, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.91}, {60, 0.86}, {90, 0.91}, {120, 0.94}, {150, 0.95}, {180, 0.90}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.5}, {30, 0.91}, {60, 0.86}, {90, 0.91}, {120, 0.94}, {150, 0.95}, {180, 0.90}, {450, 0.5}}}; //80+ params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.179; @@ -826,58 +786,50 @@ void set_parameters(mio::abm::Parameters params) // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_80_plus, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.852}, - {180, 0.852}, - {210, 0.845}, - {240, 0.828}, - {270, 0.797}, - {300, 0.759}, - {330, 0.711}, - {360, 0.661}, - {390, 0.616}, - {420, 0.580}, - {450, 0.559}, - {450, 0.550}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.852}, + {180, 0.852}, + {210, 0.845}, + {240, 0.828}, + {270, 0.797}, + {300, 0.759}, + {330, 0.711}, + {360, 0.661}, + {390, 0.616}, + {420, 0.580}, + {450, 0.559}, + {450, 0.550}}}; // Information is from: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_80_plus, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.80}, {60, 0.79}, {90, 0.75}, {120, 0.56}, {150, 0.49}, {180, 0.43}, {450, 0.1}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.5}, {30, 0.80}, {60, 0.79}, {90, 0.75}, {120, 0.56}, {150, 0.49}, {180, 0.43}, {450, 0.1}}}; // Set up personal severe protection levels. // Protection of severe infection of age group 65 + is different from other age group, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5 params.get()[{mio::abm::ExposureType::NaturalInfection, age_group_0_to_4, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.967}, - {30, 0.975}, - {60, 0.977}, - {90, 0.974}, - {120, 0.963}, - {150, 0.947}, - {180, 0.93}, - {210, 0.929}, - {240, 0.923}, - {270, 0.908}, - {300, 0.893}, - {330, 0.887}, - {360, 0.887}, - {360, 0.5}}, - days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.967}, + {30, 0.975}, + {60, 0.977}, + {90, 0.974}, + {120, 0.963}, + {150, 0.947}, + {180, 0.93}, + {210, 0.929}, + {240, 0.923}, + {270, 0.908}, + {300, 0.893}, + {330, 0.887}, + {360, 0.887}, + {360, 0.5}}}; // Information is based on: https://doi.org/10.1016/S0140-6736(21)02183-8 params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_80_plus, - mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set( - {{0, 0.5}, {30, 0.84}, {60, 0.88}, {90, 0.89}, {120, 0.86}, {150, 0.85}, {180, 0.83}, {450, 0.5}}, days); - }; + mio::abm::VirusVariant::Wildtype}] = { + mio::TimeSeriesFunctorType::LinearInterpolation, + {{0, 0.5}, {30, 0.84}, {60, 0.88}, {90, 0.89}, {120, 0.86}, {150, 0.85}, {180, 0.83}, {450, 0.5}}}; } /** @@ -952,7 +904,7 @@ void write_log_to_file_trip_data(const T& history) int start_index = mobility_data_index - 1; using Type = std::tuple; + mio::abm::TransportMode, mio::abm::ActivityType, mio::abm::InfectionState>; while (!std::binary_search(std::begin(mobility_data[start_index]), std::end(mobility_data[start_index]), mobility_data[mobility_data_index][trip_index], [](const Type& v1, const Type& v2) { diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index fda62c5ef7..4a7be57cb2 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -37,9 +37,11 @@ set(TESTSOURCES test_abm_mobility_rules.cpp test_abm_person.cpp test_abm_simulation.cpp + test_abm_serialization.cpp test_abm_testing_strategy.cpp test_abm_model.cpp test_math_floating_point.cpp + test_math_time_series_functor.cpp test_analyze_result.cpp test_contact_matrix.cpp test_type_safe.cpp @@ -59,20 +61,22 @@ set(TESTSOURCES test_metaprogramming.cpp test_history.cpp test_utils.cpp - abm_helpers.h - abm_helpers.cpp test_ide_seir.cpp test_ide_secir.cpp test_state_age_function.cpp + test_lct_secir.cpp + test_lct_initializer_flows.cpp + test_ad.cpp + abm_helpers.h + abm_helpers.cpp + actions.h distributions_helpers.h distributions_helpers.cpp - actions.h + matchers.cpp matchers.h - temp_file_register.h + random_number_test.h sanitizers.cpp - test_lct_secir.cpp - test_lct_initializer_flows.cpp - test_ad.cpp + temp_file_register.h ) if(MEMILIO_HAS_JSONCPP) diff --git a/cpp/tests/matchers.cpp b/cpp/tests/matchers.cpp new file mode 100644 index 0000000000..9c222af0f4 --- /dev/null +++ b/cpp/tests/matchers.cpp @@ -0,0 +1,64 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: René Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#include "matchers.h" + +#ifdef MEMILIO_HAS_JSONCPP + +void Json::PrintTo(const Value& json, std::ostream* os) +{ + if (json.isObject()) { + // move opening bracket into its own line + *os << "\n"; + } + const static auto js_writer = [] { + StreamWriterBuilder swb; + swb["indentation"] = " "; + return std::unique_ptr(swb.newStreamWriter()); + }(); + js_writer->write(json, os); +} + +std::string json_type_to_string(Json::ValueType t) +{ + using namespace Json; + switch (t) { + case nullValue: + return "Null"; + case intValue: + return "Int"; + case uintValue: + return "UInt"; + case realValue: + return "Real"; + case stringValue: + return "String"; + case booleanValue: + return "Boolean"; + case arrayValue: + return "Array"; + case objectValue: + return "Object"; + default: + assert(false && "Unreachable"); + return ""; + } +} + +#endif // MEMILIO_HAS_JSONCPP diff --git a/cpp/tests/matchers.h b/cpp/tests/matchers.h index 3d89323cfe..4953625f75 100644 --- a/cpp/tests/matchers.h +++ b/cpp/tests/matchers.h @@ -25,6 +25,63 @@ #include "memilio/io/io.h" #include "gmock/gmock.h" +#ifdef MEMILIO_HAS_JSONCPP + +#include "json/json.h" + +namespace Json +{ +void PrintTo(const Value& json, std::ostream* os); +} // namespace Json + +std::string json_type_to_string(Json::ValueType t); + +MATCHER_P(JsonEqual, expected_json, testing::PrintToString(expected_json)) +{ + auto match_rec = [&](auto&& match, const Json::Value& a, const Json::Value& b, std::string name) { + // first check if the types match + if (a.type() != b.type()) { + *result_listener << "type mismatch for " << name << ", expected " << json_type_to_string(a.type()) + << ", actual " << json_type_to_string(b.type()); + return false; + } + // handle object types by recursively matching members + if (a.isObject()) { + for (auto& key : a.getMemberNames()) { + if (!b.isMember(key)) { + *result_listener << "missing key \"" << key << "\" in " << name; + return false; + } + if (!match(match, a[key], b[key], name + "[\"" + key + "\"]")) { + return false; + } + } + } + // handle arrays by recursively matching each item + else if (a.isArray()) { + if (a.size() != b.size()) { + *result_listener << "wrong number of items in " << name << ", expected " << a.size() << ", actual " + << b.size(); + return false; + } + for (Json::ArrayIndex i = 0; i < a.size(); ++i) { + if (!match(match, a[i], b[i], name + "[\"" + std::to_string(i) + "\"]")) { + return false; + } + } + } + // handle value types using Json::Value::operator== + else if (a != b) { + *result_listener << "value mismatch in " << name << ", expected " << testing::PrintToString(a) + << ", actual " << testing::PrintToString(b); + return false; + } + return true; + }; + return match_rec(match_rec, expected_json, arg, "Json::Value"); +} +#endif //MEMILIO_HAS_JSONCPP + /** * @brief overload gtest printer function for eigen matrices. * @note see https://stackoverflow.com/questions/25146997/teach-google-test-how-to-print-eigen-matrix diff --git a/cpp/tests/random_number_test.h b/cpp/tests/random_number_test.h new file mode 100644 index 0000000000..8d5fdfbacc --- /dev/null +++ b/cpp/tests/random_number_test.h @@ -0,0 +1,55 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: René Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#include "memilio/utils/random_number_generator.h" + +#include + +#include + +class RandomNumberTest : public ::testing::Test +{ +public: + /** + * @brief Draws a uniformly distributed random number. + * @tparam FP A floating point type, defaults to double. + * @param[in] min, max Lower and upper bound to the uniform distribution. + * @return A random value between min and max. + */ + template + double random_number(FP min = FP{-1e+3}, FP max = {1e+3}) + { + return std::uniform_real_distribution(min, max)(m_rng); + } + + /// @brief Access the random number generator. Should only be used for debugging. + mio::RandomNumberGenerator get_rng() + { + return m_rng; + } + +protected: + void SetUp() override + { + log_rng_seeds(m_rng, mio::LogLevel::warn); + } + +private: + mio::RandomNumberGenerator m_rng{}; ///< Seeded rng used by this test fixture. +}; diff --git a/cpp/tests/test_abm_infection.cpp b/cpp/tests/test_abm_infection.cpp index 3ee2369000..c192a2e639 100644 --- a/cpp/tests/test_abm_infection.cpp +++ b/cpp/tests/test_abm_infection.cpp @@ -21,7 +21,6 @@ #include "abm/location_type.h" #include "abm/person.h" #include "abm_helpers.h" -#include "memilio/math/interpolation.h" #include "memilio/utils/random_number_generator.h" #include "abm_helpers.h" @@ -79,12 +78,10 @@ TEST(TestInfection, init) EXPECT_NEAR(infection.get_infectivity(mio::abm::TimePoint(0) + mio::abm::days(3)), 0.2689414213699951, 1e-14); params.get()[{mio::abm::ExposureType::GenericVaccine, age_group_test, - virus_variant_test}] = [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.91}, {30, 0.81}}, days); - }; - params.get() = [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{0, 0.91}, {30, 0.81}}, days); - }; + virus_variant_test}] = + mio::TimeSeriesFunctor{mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.91}, {30, 0.81}}}; + params.get() = + mio::TimeSeriesFunctor{mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.91}, {30, 0.81}}}; auto infection_w_previous_exp = mio::abm::Infection(rng, mio::abm::VirusVariant::Wildtype, age_group_test, params, mio::abm::TimePoint(0), mio::abm::InfectionState::InfectedSymptoms, @@ -174,7 +171,8 @@ TEST(TestInfection, drawInfectionCourseBackward) TEST(TestInfection, getPersonalProtectiveFactor) { - auto rng = mio::RandomNumberGenerator(); + const ScalarType eps = 1e-4; + auto rng = mio::RandomNumberGenerator(); auto location = mio::abm::Location(mio::abm::LocationType::School, 0, num_age_groups); auto person = mio::abm::Person(rng, location.get_type(), location.get_id(), age_group_15_to_34); @@ -187,79 +185,71 @@ TEST(TestInfection, getPersonalProtectiveFactor) mio::abm::ExposureType::GenericVaccine, mio::AgeGroup(0), mio::abm::VirusVariant::Wildtype}](0); auto defaut_severity_protection = params.get()[{ mio::abm::ExposureType::GenericVaccine, mio::AgeGroup(0), mio::abm::VirusVariant::Wildtype}](0); - ASSERT_NEAR(defaut_infection_protection, 0, 0.0001); - ASSERT_NEAR(defaut_severity_protection, 0, 0.0001); + EXPECT_NEAR(defaut_infection_protection, 0, eps); + EXPECT_NEAR(defaut_severity_protection, 0, eps); // Test linear interpolation with one node - mio::set_log_level(mio::LogLevel::critical); //this throws an error either way params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{2, 0.91}}, days); - }; + mio::TimeSeriesFunctor{mio::TimeSeriesFunctorType::LinearInterpolation, {{2, 0.91}}}; auto t = mio::abm::TimePoint(6 * 24 * 60 * 60); - ASSERT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0, 0.001); - mio::set_log_level(mio::LogLevel::warn); //this throws an error either way + + EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.91, eps); params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{2, 0.91}, {30, 0.81}}, days); - }; + mio::TimeSeriesFunctor{mio::TimeSeriesFunctorType::LinearInterpolation, {{2, 0.91}, {30, 0.81}}}; params.get()[{mio::abm::ExposureType::GenericVaccine, person.get_age(), mio::abm::VirusVariant::Wildtype}] = - [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{2, 0.91}, {30, 0.81}}, days); - }; - params.get() = [](ScalarType days) -> ScalarType { - return mio::linear_interpolation_of_data_set({{2, 0.91}, {30, 0.81}}, days); - }; + mio::TimeSeriesFunctor{mio::TimeSeriesFunctorType::LinearInterpolation, {{2, 0.91}, {30, 0.81}}}; + params.get() = + mio::TimeSeriesFunctor{mio::TimeSeriesFunctorType::LinearInterpolation, {{2, 0.91}, {30, 0.81}}}; // Test Parameter InfectionProtectionFactor and get_protection_factor() t = mio::abm::TimePoint(0) + mio::abm::days(2); auto infection_protection_factor = params.get()[{ latest_protection.first, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( t.days() - latest_protection.second.days()); - ASSERT_NEAR(infection_protection_factor, 0.91, 0.0001); - ASSERT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.91, 0.0001); + EXPECT_NEAR(infection_protection_factor, 0.91, eps); + EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.91, eps); t = mio::abm::TimePoint(0) + mio::abm::days(15); infection_protection_factor = params.get()[{ latest_protection.first, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( t.days() - latest_protection.second.days()); - ASSERT_NEAR(infection_protection_factor, 0.8635, 0.0001); - ASSERT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.8635, 0.0001); + EXPECT_NEAR(infection_protection_factor, 0.8635, eps); + EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.8635, eps); t = mio::abm::TimePoint(0) + mio::abm::days(40); infection_protection_factor = params.get()[{ latest_protection.first, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( t.days() - latest_protection.second.days()); - ASSERT_NEAR(infection_protection_factor, 0, 0.0001); - ASSERT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0, 0.0001); + EXPECT_NEAR(infection_protection_factor, 0.81, eps); + EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.81, eps); // Test Parameter SeverityProtectionFactor t = mio::abm::TimePoint(0) + mio::abm::days(2); auto severity_protection_factor = params.get()[{ latest_protection.first, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( t.days() - latest_protection.second.days()); - ASSERT_NEAR(severity_protection_factor, 0.91, 0.0001); + EXPECT_NEAR(severity_protection_factor, 0.91, eps); t = mio::abm::TimePoint(0) + mio::abm::days(15); severity_protection_factor = params.get()[{ latest_protection.first, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( t.days() - latest_protection.second.days()); - ASSERT_NEAR(severity_protection_factor, 0.8635, 0.0001); + EXPECT_NEAR(severity_protection_factor, 0.8635, eps); t = mio::abm::TimePoint(0) + mio::abm::days(40); severity_protection_factor = params.get()[{ latest_protection.first, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( t.days() - latest_protection.second.days()); - ASSERT_NEAR(severity_protection_factor, 0, 0.0001); + EXPECT_NEAR(severity_protection_factor, 0.81, eps); // Test Parameter HighViralLoadProtectionFactor t = mio::abm::TimePoint(0) + mio::abm::days(2); - ASSERT_NEAR(params.get()(t.days()), 0.91, 0.0001); + EXPECT_NEAR(params.get()(t.days()), 0.91, eps); t = mio::abm::TimePoint(0) + mio::abm::days(15); - ASSERT_NEAR(params.get()(t.days()), 0.8635, 0.0001); + EXPECT_NEAR(params.get()(t.days()), 0.8635, eps); t = mio::abm::TimePoint(0) + mio::abm::days(40); - ASSERT_NEAR(params.get()(t.days()), 0, 0.0001); + EXPECT_NEAR(params.get()(t.days()), 0.81, eps); } diff --git a/cpp/tests/test_abm_model.cpp b/cpp/tests/test_abm_model.cpp index 0cbdd6d810..738c31cec6 100644 --- a/cpp/tests/test_abm_model.cpp +++ b/cpp/tests/test_abm_model.cpp @@ -18,6 +18,7 @@ * limitations under the License. */ #include "abm/person.h" +#include "abm/model.h" #include "abm_helpers.h" #include "memilio/utils/random_number_generator.h" @@ -38,8 +39,8 @@ TEST(TestModel, addLocation) auto work_id = model.add_location(mio::abm::LocationType::Work); auto home_id = model.add_location(mio::abm::LocationType::Home); - ASSERT_EQ(school_id1.get(), 1u); - ASSERT_EQ(school_id2.get(), 2u); + EXPECT_EQ(school_id1.get(), 1u); + EXPECT_EQ(school_id2.get(), 2u); auto& school1 = model.get_location(school_id1); auto& school2 = model.get_location(school_id2); @@ -52,12 +53,12 @@ TEST(TestModel, addLocation) count_schools++; } } - ASSERT_EQ(count_schools, 2); + EXPECT_EQ(count_schools, 2); - ASSERT_EQ(model.get_locations()[1], school1); - ASSERT_EQ(model.get_locations()[2], school2); - ASSERT_EQ(model.get_locations()[3], work); - ASSERT_EQ(model.get_locations()[4], home); + EXPECT_EQ(model.get_locations()[1], school1); + EXPECT_EQ(model.get_locations()[2], school2); + EXPECT_EQ(model.get_locations()[3], work); + EXPECT_EQ(model.get_locations()[4], home); } TEST(TestModel, addPerson) @@ -68,9 +69,9 @@ TEST(TestModel, addPerson) model.add_person(location, age_group_15_to_34); model.add_person(location, age_group_35_to_59); - ASSERT_EQ(model.get_persons().size(), 2); - ASSERT_EQ(model.get_person(0).get_age(), age_group_15_to_34); - ASSERT_EQ(model.get_person(1).get_age(), age_group_35_to_59); + EXPECT_EQ(model.get_persons().size(), 2); + EXPECT_EQ(model.get_person(0).get_age(), age_group_15_to_34); + EXPECT_EQ(model.get_person(1).get_age(), age_group_35_to_59); } TEST(TestModel, getSubpopulationCombined) @@ -88,13 +89,13 @@ TEST(TestModel, getSubpopulationCombined) add_test_person(model, school3, age_group_15_to_34, mio::abm::InfectionState::InfectedNoSymptoms); add_test_person(model, home1, age_group_15_to_34, mio::abm::InfectionState::InfectedNoSymptoms); - ASSERT_EQ(model.get_subpopulation_combined_per_location_type(t, mio::abm::InfectionState::Susceptible, + EXPECT_EQ(model.get_subpopulation_combined_per_location_type(t, mio::abm::InfectionState::Susceptible, mio::abm::LocationType::School), 3); - ASSERT_EQ(model.get_subpopulation_combined_per_location_type(t, mio::abm::InfectionState::InfectedNoSymptoms, + EXPECT_EQ(model.get_subpopulation_combined_per_location_type(t, mio::abm::InfectionState::InfectedNoSymptoms, mio::abm::LocationType::School), 2); - ASSERT_EQ(model.get_subpopulation_combined(t, mio::abm::InfectionState::InfectedNoSymptoms), 3); + EXPECT_EQ(model.get_subpopulation_combined(t, mio::abm::InfectionState::InfectedNoSymptoms), 3); } TEST(TestModel, findLocation) @@ -456,7 +457,7 @@ TEST(TestModelTestingCriteria, testAddingAndUpdatingAndRunningTestingSchemes) mio::abm::TestingScheme(testing_criteria, validity_period, start_date, end_date, test_params_pcr, probability); model.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme); - ASSERT_EQ(model.get_testing_strategy().run_strategy(rng_person, person, work, current_time), + EXPECT_EQ(model.get_testing_strategy().run_strategy(rng_person, person, work, current_time), true); // no active testing scheme -> person can enter current_time = mio::abm::TimePoint(30); model.get_testing_strategy().update_activity_status(current_time); @@ -465,12 +466,12 @@ TEST(TestModelTestingCriteria, testAddingAndUpdatingAndRunningTestingSchemes) .Times(testing::AtLeast(2)) .WillOnce(testing::Return(0.7)) .WillOnce(testing::Return(0.4)); - ASSERT_EQ(model.get_testing_strategy().run_strategy(rng_person, person, work, current_time), false); + EXPECT_EQ(model.get_testing_strategy().run_strategy(rng_person, person, work, current_time), false); model.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme); //doesn't get added because of == operator model.get_testing_strategy().remove_testing_scheme(mio::abm::LocationType::Work, testing_scheme); - ASSERT_EQ(model.get_testing_strategy().run_strategy(rng_person, person, work, current_time), + EXPECT_EQ(model.get_testing_strategy().run_strategy(rng_person, person, work, current_time), true); // no more testing_schemes } @@ -499,67 +500,68 @@ TEST(TestModel, checkParameterConstraints) params.get()[mio::abm::MaskType::FFP2] = 0.6; params.get()[mio::abm::MaskType::Surgical] = 0.7; params.get() = mio::abm::TimePoint(0); - ASSERT_EQ(params.check_constraints(), false); + EXPECT_EQ(params.check_constraints(), false); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -1.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 1.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -2.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 2.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -3.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 3.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -4.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 4.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -5.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 5.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -6.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 6.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -7.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 7.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -8.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 8.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -9.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 9.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -10.; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 10.; params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 1.1; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.3; params.get()[age_group_35_to_59] = mio::abm::hours(30); - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[age_group_35_to_59] = mio::abm::hours(4); params.get()[age_group_35_to_59] = mio::abm::hours(30); - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[age_group_35_to_59] = mio::abm::hours(8); params.get()[age_group_0_to_4] = mio::abm::hours(30); - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[age_group_0_to_4] = mio::abm::hours(3); params.get()[age_group_0_to_4] = mio::abm::hours(30); - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[age_group_0_to_4] = mio::abm::hours(6); params.get()[mio::abm::MaskType::Community] = 1.2; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[mio::abm::MaskType::Community] = 0.5; params.get()[mio::abm::MaskType::FFP2] = 1.2; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[mio::abm::MaskType::FFP2] = 0.6; params.get()[mio::abm::MaskType::Surgical] = 1.2; - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); params.get()[mio::abm::MaskType::Surgical] = 0.7; params.get() = mio::abm::TimePoint(-2); - ASSERT_EQ(params.check_constraints(), true); + EXPECT_EQ(params.check_constraints(), true); + mio::set_log_level(mio::LogLevel::warn); } TEST(TestModel, mobilityRulesWithAppliedNPIs) diff --git a/cpp/tests/test_abm_serialization.cpp b/cpp/tests/test_abm_serialization.cpp new file mode 100644 index 0000000000..6a69f9aadb --- /dev/null +++ b/cpp/tests/test_abm_serialization.cpp @@ -0,0 +1,273 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: René Schmieding +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#include "matchers.h" +#include "abm/config.h" +#include "abm/infection_state.h" +#include "abm/parameters.h" +#include "abm/test_type.h" +#include "abm/testing_strategy.h" +#include "abm/vaccine.h" +#include "memilio/epidemiology/age_group.h" +#include "memilio/io/json_serializer.h" +#include "memilio/utils/custom_index_array.h" +#include "memilio/utils/uncertain_value.h" +#include "models/abm/location.h" +#include "models/abm/person.h" +#include "models/abm/trip_list.h" +#include "models/abm/model.h" + +#ifdef MEMILIO_HAS_JSONCPP + +#include "json/value.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +/** + * @brief Test de- and serialization of an object without equality operator. + * + * Test that a json value x representing type T is equal to serialize(deserialize(x)). + * + * Assuming the (de)serialization functions' general behavior is independent of specific values of member variables, + * i.e. the function does not contain conditionals (`if (t > 0)`), optionals (`add_optional`/`expect_optional`), etc., + * and assuming that deserialize is injective (meaning that two unequal instances of T do not have the same json + * representation, which can happen e.g. if not all member variables are serialized), + * this sufficiently tests that serialize and deserialize are inverse functions to each other. + * + * @tparam T The type to test. + * @param reference_json A json value representing an instance of T. + */ +template +void test_json_serialization(const Json::Value& reference_json) +{ + // check that the json is deserializable (i.e. a valid representation) + auto t_result = mio::deserialize_json(reference_json, mio::Tag()); + ASSERT_THAT(print_wrap(t_result), IsSuccess()); + + // check that the resulting type T is serializable + auto json_result = mio::serialize_json(t_result.value()); + ASSERT_THAT(print_wrap(json_result), IsSuccess()); + + EXPECT_THAT(json_result.value(), JsonEqual(reference_json)); +} + +TEST(TestAbmSerialization, Trip) +{ + // See test_json_serialization for info on this test. + + unsigned i = 1; // counter s.t. members have different values + + Json::Value reference_json; + reference_json["person_id"] = Json::UInt(i++); + reference_json["time"]["seconds"] = Json::Int(i++); + reference_json["destination"] = Json::UInt(i++); + reference_json["origin"] = Json::UInt(i++); + + test_json_serialization(reference_json); +} + +TEST(TestAbmSerialization, Vaccination) +{ + // See test_json_serialization for info on this test. + + Json::Value reference_json; + reference_json["exposure_type"] = Json::UInt(1); + reference_json["time"]["seconds"] = Json::Int(2); + + test_json_serialization(reference_json); +} + +TEST(TestAbmSerialization, Infection) +{ + // See test_json_serialization for info on this test. + + unsigned i = 1; // counter s.t. members have different values + + Json::Value viral_load; + viral_load["decline"] = Json::Value((double)i++); + viral_load["end_date"]["seconds"] = Json::Int(i++); + viral_load["incline"] = Json::Value((double)i++); + viral_load["peak"] = Json::Value((double)i++); + viral_load["start_date"]["seconds"] = Json::Int(i++); + + Json::Value reference_json; + reference_json["infection_course"] = Json::Value(Json::arrayValue); + reference_json["virus_variant"] = Json::UInt(0); + reference_json["viral_load"] = viral_load; + reference_json["log_norm_alpha"] = Json::Value((double)i++); + reference_json["log_norm_beta"] = Json::Value((double)i++); + reference_json["detected"] = Json::Value((bool)0); + + test_json_serialization(reference_json); +} + +TEST(TestAbmSerialization, TestingScheme) +{ + // See test_json_serialization for info on this test. + + unsigned i = 1; // counter s.t. members have different values + + Json::Value testing_criteria; + std::vector ages_bits(mio::abm::MAX_NUM_AGE_GROUPS, false); + ages_bits[i++] = true; + testing_criteria["ages"]["bitset"] = mio::serialize_json(ages_bits).value(); + std::vector inf_st_bits((size_t)mio::abm::InfectionState::Count, false); + inf_st_bits[i++] = true; + testing_criteria["infection_states"]["bitset"] = mio::serialize_json(inf_st_bits).value(); + + Json::Value test_parameters; + test_parameters["sensitivity"] = mio::serialize_json(mio::UncertainValue{(double)i++}).value(); + test_parameters["specificity"] = mio::serialize_json(mio::UncertainValue{(double)i++}).value(); + test_parameters["required_time"]["seconds"] = Json::Int(i++); + test_parameters["test_type"] = Json::UInt(0); + + Json::Value reference_json; + reference_json["criteria"] = testing_criteria; + reference_json["validity_period"]["seconds"] = Json::Int(i++); + reference_json["start_date"]["seconds"] = Json::Int(i++); + reference_json["end_date"]["seconds"] = Json::Int(i++); + reference_json["test_params"] = test_parameters; + reference_json["probability"] = Json::Value((double)i++); + reference_json["is_active"] = Json::Value((bool)0); + + test_json_serialization(reference_json); +} + +TEST(TestAbmSerialization, TestingStrategy) +{ + // See test_json_serialization for info on this test. + + unsigned i = 1; // counter s.t. members have different values + + Json::Value local_strategy; + local_strategy["id"] = Json::UInt(i++); + local_strategy["schemes"] = Json::Value(Json::arrayValue); + local_strategy["type"] = Json::UInt(i++); + + Json::Value reference_json; + reference_json["schemes"][0] = local_strategy; + + test_json_serialization(reference_json); +} + +TEST(TestAbmSerialization, TestResult) +{ + // See test_json_serialization for info on this test. + + Json::Value reference_json; + reference_json["result"] = Json::Value(false); + reference_json["time_of_testing"]["seconds"] = Json::Int(1); + + test_json_serialization(reference_json); +} + +TEST(TestAbmSerialization, Person) +{ + // See test_json_serialization for info on this test. + + auto json_uint_array = [](std::vector values) { + return mio::serialize_json(values).value(); + }; + auto json_double_array = [](std::vector values) { + return mio::serialize_json(values).value(); + }; + + unsigned i = 1; // counter s.t. members have different values + + Json::Value reference_json; + reference_json["age_group"] = Json::UInt(i++); + reference_json["assigned_locations"] = json_uint_array({i++, i++, i++, i++, i++, i++, i++, i++, i++, i++, i++}); + reference_json["cells"] = json_uint_array({i++}); + reference_json["compliance"] = json_double_array({(double)i++, (double)i++, (double)i++}); + reference_json["id"] = Json::UInt(i++); + reference_json["infections"] = Json::Value(Json::arrayValue); + reference_json["last_transport_mode"] = Json::UInt(i++); + reference_json["location"] = Json::UInt(i++); + reference_json["location_type"] = Json::UInt(0); + reference_json["mask"]["mask_type"] = Json::UInt(0); + reference_json["mask"]["time_first_used"]["seconds"] = Json::Int(i++); + reference_json["home_isolation_start"]["seconds"] = Json::Int(i++); + reference_json["rnd_go_to_school_hour"] = Json::Value((double)i++); + reference_json["rnd_go_to_work_hour"] = Json::Value((double)i++); + reference_json["rnd_schoolgroup"] = Json::Value((double)i++); + reference_json["rnd_workgroup"] = Json::Value((double)i++); + reference_json["rng_counter"] = Json::UInt(i++); + reference_json["test_results"] = + mio::serialize_json(mio::CustomIndexArray{}).value(); + reference_json["time_at_location"]["seconds"] = Json::Int(i++); + reference_json["vaccinations"] = Json::Value(Json::arrayValue); + + test_json_serialization(reference_json); +} + +TEST(TestAbmSerialization, Location) +{ + // See test_json_serialization for info on this test. + + unsigned i = 1; // counter s.t. members have different values + + Json::Value contact_rates = mio::serialize_json(mio::abm::ContactRates::get_default(i++)).value(); + + Json::Value reference_json; + reference_json["cells"][0]["capacity"]["persons"] = Json::UInt(i++); + reference_json["cells"][0]["capacity"]["volume"] = Json::UInt(i++); + reference_json["geographical_location"]["latitude"] = Json::Value((double)i++); + reference_json["geographical_location"]["longitude"] = Json::Value((double)i++); + reference_json["id"] = Json::UInt(i++); + reference_json["parameters"]["ContactRates"] = contact_rates; + reference_json["parameters"]["MaximumContacts"] = Json::Value((double)i++); + reference_json["parameters"]["UseLocationCapacityForTransmissions"] = Json::Value(false); + reference_json["required_mask"] = Json::UInt(0); + reference_json["type"] = Json::UInt(0); + + test_json_serialization(reference_json); +} + +TEST(TestAbmSerialization, Model) +{ + // See test_json_serialization for info on this test. + + auto json_uint_array = [](std::vector values) { + return mio::serialize_json(values).value(); + }; + + unsigned i = 1; // counter s.t. members have different values + + Json::Value abm_parameters = mio::serialize_json(mio::abm::Parameters(i++)).value(); + + Json::Value reference_json; + reference_json["cemetery_id"] = Json::UInt(i++); + reference_json["location_types"] = Json::UInt(i++); + reference_json["locations"] = Json::Value(Json::arrayValue); + reference_json["parameters"] = abm_parameters; + reference_json["persons"] = Json::Value(Json::arrayValue); + reference_json["rng"]["counter"] = Json::UInt(i++); + reference_json["rng"]["key"] = Json::UInt(i++); + reference_json["rng"]["seeds"] = json_uint_array({i++, i++, i++, i++, i++, i++}); + reference_json["testing_strategy"]["schemes"] = Json::Value(Json::arrayValue); + reference_json["trip_list"]["index"] = Json::UInt(i++); + reference_json["trip_list"]["trips_weekday"] = Json::Value(Json::arrayValue); + reference_json["trip_list"]["trips_weekend"] = Json::Value(Json::arrayValue); + reference_json["use_mobility_rules"] = Json::Value(false); + + test_json_serialization(reference_json); +} + +#endif diff --git a/cpp/tests/test_flows.cpp b/cpp/tests/test_flows.cpp index ca74b1db61..fad2a71f6c 100644 --- a/cpp/tests/test_flows.cpp +++ b/cpp/tests/test_flows.cpp @@ -96,8 +96,6 @@ TEST(TestFlows, FlowChart) TEST(TestFlows, FlowSimulation) { - mio::set_log_level(mio::LogLevel::off); - double t0 = 0; double tmax = 1; double dt = 0.001; @@ -119,9 +117,12 @@ TEST(TestFlows, FlowSimulation) model.parameters.set>(0.04); model.parameters.get>().get_cont_freq_mat()[0].get_baseline().setConstant(10); + mio::set_log_level(mio::LogLevel::off); // Suppress log output of check_constraints and the Simulation. model.check_constraints(); auto IC = std::make_shared>(); auto seir = mio::simulate_flows>(t0, tmax, dt, model, IC); + mio::set_log_level(mio::LogLevel::warn); + // verify results (computed using flows) auto results = seir[0].get_last_value(); EXPECT_NEAR(results[0], 9660.5835936179408, 1e-14); @@ -137,8 +138,6 @@ TEST(TestFlows, FlowSimulation) TEST(TestFlows, CompareSimulations) { - mio::set_log_level(mio::LogLevel::off); - double t0 = 0; double tmax = 1; double dt = 0.001; @@ -162,11 +161,14 @@ TEST(TestFlows, CompareSimulations) model.parameters.get>().get_cont_freq_mat(); contact_matrix[0].get_baseline().setConstant(10); + mio::set_log_level(mio::LogLevel::off); // Suppress log output of check_constraints and the Simulation. model.check_constraints(); auto seir_sim_flows = simulate_flows(t0, tmax, dt, model); auto seir_sim = simulate(t0, tmax, dt, model); - auto results_flows = seir_sim_flows[0].get_last_value(); - auto results = seir_sim.get_last_value(); + mio::set_log_level(mio::LogLevel::warn); + + auto results_flows = seir_sim_flows[0].get_last_value(); + auto results = seir_sim.get_last_value(); EXPECT_NEAR(results[0], results_flows[0], 1e-10); EXPECT_NEAR(results[1], results_flows[1], 1e-10); diff --git a/cpp/tests/test_io_cli.cpp b/cpp/tests/test_io_cli.cpp index dc376a0f88..5301a82cce 100644 --- a/cpp/tests/test_io_cli.cpp +++ b/cpp/tests/test_io_cli.cpp @@ -151,6 +151,8 @@ using Params = mio::ParameterSet; // using BadParams = mio::ParameterSet; TEST(TestCLI, test_option_verifier) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + EXPECT_DEBUG_DEATH(mio::details::cli::verify_options(mio::ParameterSet()), ".*Options may not have duplicate fields\\. \\(field required\\)"); EXPECT_DEBUG_DEATH(mio::details::cli::verify_options(mio::ParameterSet()), @@ -212,6 +214,8 @@ TEST(TestCLI, test_set_param) TEST(TestCLI, test_write_help) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + std::stringstream ss; const std::string help = "Usage: TestSuite