diff --git a/src/analysis/lattices/inverted.h b/src/analysis/lattices/inverted.h index b232d89aa7e..22e3267423b 100644 --- a/src/analysis/lattices/inverted.h +++ b/src/analysis/lattices/inverted.h @@ -37,10 +37,14 @@ template struct Inverted { LatticeComparison compare(const Element& a, const Element& b) const noexcept { return lattice.compare(b, a); } - bool join(Element& joinee, Element joiner) const noexcept { + + template + bool join(Element& joinee, const Elem& joiner) const noexcept { return lattice.meet(joinee, joiner); } - bool meet(Element& meetee, Element meeter) const noexcept { + + template + bool meet(Element& meetee, const Elem& meeter) const noexcept { return lattice.join(meetee, meeter); } }; diff --git a/src/analysis/lattices/shared.h b/src/analysis/lattices/shared.h index 489ed0003c2..e75b895ea21 100644 --- a/src/analysis/lattices/shared.h +++ b/src/analysis/lattices/shared.h @@ -106,7 +106,8 @@ template struct Shared { return false; } - bool join(Element& joinee, const typename L::Element& joiner) const noexcept { + template + bool join(Element& joinee, const Elem& joiner) const noexcept { if (lattice.join(val, joiner)) { // We have moved to the next value in our ascending chain. Assign it a new // sequence number and update joinee with that sequence number. diff --git a/src/analysis/lattices/vector.h b/src/analysis/lattices/vector.h index d13380868fd..175ec5e3c36 100644 --- a/src/analysis/lattices/vector.h +++ b/src/analysis/lattices/vector.h @@ -30,6 +30,13 @@ namespace wasm::analysis { template struct Vector { using Element = std::vector; + // Represent a vector in which all but one of the elements are bottom without + // materializing the full vector. + struct SingletonElement : std::pair { + SingletonElement(size_t i, typename L::Element&& elem) + : std::pair{i, std::move(elem)} {} + }; + L lattice; const size_t size; @@ -39,13 +46,7 @@ template struct Vector { return Element(size, lattice.getBottom()); } - Element getTop() const noexcept -#if __cplusplus >= 202002L - requires FullLattice -#endif - { - return Element(size, lattice.getTop()); - } + Element getTop() const noexcept { return Element(size, lattice.getTop()); } // `a` <= `b` if their elements are pairwise <=, etc. Unless we determine // that there is no relation, we must check all the elements. @@ -84,48 +85,70 @@ template struct Vector { assert(joiner.size() == size); bool result = false; for (size_t i = 0; i < size; ++i) { - if constexpr (std::is_same_v) { - // The vector specialization does not expose references to the - // individual bools because they might be in a bitmap, so we need a - // workaround. - bool e = joinee[i]; - if (lattice.join(e, joiner[i])) { - joinee[i] = e; - result = true; - } - } else { - result |= lattice.join(joinee[i], joiner[i]); - } + result |= joinAtIndex(joinee, i, joiner[i]); } - return result; } + bool join(Element& joinee, const SingletonElement& joiner) const noexcept { + const auto& [index, elem] = joiner; + assert(index < joinee.size()); + return joinAtIndex(joinee, index, elem); + } + // Pairwise meet on the elements. - bool meet(Element& meetee, const Element& meeter) const noexcept -#if __cplusplus >= 202002L - requires FullLattice -#endif - { + bool meet(Element& meetee, const Element& meeter) const noexcept { assert(meetee.size() == size); assert(meeter.size() == size); bool result = false; for (size_t i = 0; i < size; ++i) { - if constexpr (std::is_same_v) { - // The vector specialization does not expose references to the - // individual bools because they might be in a bitmap, so we need a - // workaround. - bool e = meetee[i]; - if (lattice.meet(e, meeter[i])) { - meetee[i] = e; - result = true; - } - } else { - result |= lattice.meet(meetee[i], meeter[i]); - } + result |= meetAtIndex(meetee, i, meeter[i]); } return result; } + + bool meet(Element& meetee, const SingletonElement& meeter) const noexcept { + const auto& [index, elem] = meeter; + assert(index < meetee.size()); + return meetAtIndex(meetee, index, elem); + } + +private: + bool joinAtIndex(Element& joinee, + size_t i, + const typename L::Element& elem) const noexcept { + if constexpr (std::is_same_v) { + // The vector specialization does not expose references to the + // individual bools because they might be in a bitmap, so we need a + // workaround. + bool e = joinee[i]; + if (lattice.join(e, elem)) { + joinee[i] = e; + return true; + } + return false; + } else { + return lattice.join(joinee[i], elem); + } + } + + bool meetAtIndex(Element& meetee, + size_t i, + const typename L::Element& elem) const noexcept { + if constexpr (std::is_same_v) { + // The vector specialization does not expose references to the + // individual bools because they might be in a bitmap, so we need a + // workaround. + bool e = meetee[i]; + if (lattice.meet(e, elem)) { + meetee[i] = e; + return true; + } + return false; + } else { + return lattice.meet(meetee[i], elem); + } + } }; #if __cplusplus >= 202002L diff --git a/test/gtest/lattices.cpp b/test/gtest/lattices.cpp index ed4f48e38ea..905f03420e2 100644 --- a/test/gtest/lattices.cpp +++ b/test/gtest/lattices.cpp @@ -478,6 +478,30 @@ TEST(VectorLattice, Meet) { vector, {false, false}, {false, true}, {true, false}, {true, true}); } +TEST(VectorLattice, JoinSingleton) { + using Vec = analysis::Vector; + Vec vector{analysis::Bool{}, 2}; + auto elem = vector.getBottom(); + + EXPECT_FALSE(vector.join(elem, Vec::SingletonElement(0, false))); + EXPECT_EQ(elem, (std::vector{false, false})); + + EXPECT_TRUE(vector.join(elem, Vec::SingletonElement(1, true))); + EXPECT_EQ(elem, (std::vector{false, true})); +} + +TEST(VectorLattice, MeetSingleton) { + using Vec = analysis::Vector; + Vec vector{analysis::Bool{}, 2}; + auto elem = vector.getTop(); + + EXPECT_FALSE(vector.meet(elem, Vec::SingletonElement(1, true))); + EXPECT_EQ(elem, (std::vector{true, true})); + + EXPECT_TRUE(vector.meet(elem, Vec::SingletonElement(0, false))); + EXPECT_EQ(elem, (std::vector{false, true})); +} + TEST(TupleLattice, GetBottom) { analysis::Tuple tuple{analysis::Bool{}, analysis::UInt32{}}; @@ -656,6 +680,25 @@ TEST(SharedLattice, Join) { } } +TEST(SharedLattice, JoinVecSingleton) { + using Vec = analysis::Vector; + analysis::Shared shared{analysis::Vector{analysis::Bool{}, 2}}; + + auto elem = shared.getBottom(); + EXPECT_TRUE(shared.join(elem, Vec::SingletonElement(1, true))); + EXPECT_EQ(*elem, (std::vector{false, true})); +} + +TEST(SharedLattice, JoinInvertedVecSingleton) { + using Vec = analysis::Vector; + analysis::Shared> shared{ + analysis::Inverted{analysis::Vector{analysis::Bool{}, 2}}}; + + auto elem = shared.getBottom(); + EXPECT_TRUE(shared.join(elem, Vec::SingletonElement(1, false))); + EXPECT_EQ(*elem, (std::vector{true, false})); +} + TEST(StackLattice, GetBottom) { analysis::Stack stack{analysis::Flat{}}; EXPECT_EQ(stack.getBottom().size(), 0u);