diff --git a/src/bthread/bthread.h b/src/bthread/bthread.h index 7e42c96c9f..603cf04d0e 100644 --- a/src/bthread/bthread.h +++ b/src/bthread/bthread.h @@ -30,6 +30,7 @@ #if defined(__cplusplus) #include #include "bthread/mutex.h" // use bthread_mutex_t in the RAII way +#include "bthread/condition_variable.h" // use bthread_cond_t in the RAII way #endif // __cplusplus #include "bthread/id.h" diff --git a/src/bthread/condition_variable.h b/src/bthread/condition_variable.h index c684cf6cbd..fb6bb4bcb5 100644 --- a/src/bthread/condition_variable.h +++ b/src/bthread/condition_variable.h @@ -63,6 +63,20 @@ class ConditionVariable { bthread_cond_wait(&_cond, lock.mutex()); } + template + void wait(std::unique_lock& lock, Predicate p) { + while (!p()) { + bthread_cond_wait(&_cond, lock.mutex()->native_handler()); + } + } + + template + void wait(std::unique_lock& lock, Predicate p) { + while (!p()) { + bthread_cond_wait(&_cond, lock.mutex()); + } + } + // Unlike std::condition_variable, we return ETIMEDOUT when time expires // rather than std::timeout int wait_for(std::unique_lock& lock, diff --git a/test/bthread_cond_unittest.cpp b/test/bthread_cond_unittest.cpp index d01ef69c26..f2dcddfe8c 100644 --- a/test/bthread_cond_unittest.cpp +++ b/test/bthread_cond_unittest.cpp @@ -138,7 +138,10 @@ TEST(CondTest, sanity) { struct WrapperArg { bthread::Mutex mutex; bthread::ConditionVariable cond; + bool ready = false; + static std::atomic wake_time; }; +std::atomic WrapperArg::wake_time{0}; void* cv_signaler(void* void_arg) { WrapperArg* a = (WrapperArg*)void_arg; @@ -168,6 +171,23 @@ void* cv_mutex_waiter(void* void_arg) { return NULL; } + +void* cv_bmutex_waiter_with_pred(void* void_arg) { + WrapperArg* a = (WrapperArg*)void_arg; + std::unique_lock lck(*a->mutex.native_handler()); + a->cond.wait(lck, [&] { return a->ready; }); + WrapperArg::wake_time.fetch_add(1); + return NULL; +} + +void* cv_mutex_waiter_with_pred(void* void_arg) { + WrapperArg* a = (WrapperArg*)void_arg; + std::unique_lock lck(a->mutex); + a->cond.wait(lck, [&] { return a->ready; }); + WrapperArg::wake_time.fetch_add(1); + return NULL; +} + #define COND_IN_PTHREAD #ifndef COND_IN_PTHREAD @@ -202,6 +222,37 @@ TEST(CondTest, cpp_wrapper) { } } +TEST(CondTest, cpp_wrapper2) { + stop = false; + bthread::ConditionVariable cond; + pthread_t bmutex_waiter_threads[8]; + pthread_t mutex_waiter_threads[8]; + pthread_t signal_thread; + WrapperArg a; + for (size_t i = 0; i < ARRAY_SIZE(bmutex_waiter_threads); ++i) { + ASSERT_EQ(0, pthread_create(&bmutex_waiter_threads[i], NULL, + cv_bmutex_waiter_with_pred, &a)); + ASSERT_EQ(0, pthread_create(&mutex_waiter_threads[i], NULL, + cv_mutex_waiter_with_pred, &a)); + } + ASSERT_EQ(0, pthread_create(&signal_thread, NULL, cv_signaler, &a)); + bthread_usleep(100L * 1000); + ASSERT_EQ(WrapperArg::wake_time, 0); + { + BAIDU_SCOPED_LOCK(a.mutex); + stop = true; + a.ready = true; + + } + pthread_join(signal_thread, NULL); + a.cond.notify_all(); + for (size_t i = 0; i < ARRAY_SIZE(bmutex_waiter_threads); ++i) { + pthread_join(bmutex_waiter_threads[i], NULL); + pthread_join(mutex_waiter_threads[i], NULL); + } + ASSERT_EQ(WrapperArg::wake_time, 16); +} + #ifndef COND_IN_PTHREAD #undef pthread_join #undef pthread_create