commit 2c8fcb8dd6e3b67ffebf51f4d0ca0c612cd428b6
parent a67325eb6c28db74bb43e9197ba5eaa4afb3faf9
Author: Henry Wilson <henry@henryandlizzy.uk>
Date: Fri, 14 Jul 2023 23:51:18 +0100
coroutine_owner: add owning version of std::coroutine_handle
Diffstat:
7 files changed, 282 insertions(+), 148 deletions(-)
diff --git a/src/coro-generator-consumer.cpp b/src/coro-generator-consumer.cpp
@@ -1,9 +1,9 @@
-#include <coroutine> // coroutine_handle noop_coroutine suspend_always suspend_never
+#include "coroutine_owner.hpp" // coroutine_owner noop_coroutine suspend_always suspend_never
+
#include <exception> // terminate
#include <iostream> // cout
#include <optional> // optional
#include <cassert> // assert
-#include <utility> // exchange
struct sentry
{
@@ -23,20 +23,16 @@ struct symmetric_transfer : std::suspend_always
struct generator
{
struct promise_type;
- ~generator()
- {
- h.destroy();
- }
- promise_type& promise() { return h.promise(); }
+ promise_type& promise() { return h->promise(); }
- std::coroutine_handle<promise_type> h;
+ coroutine_owner<promise_type> h;
};
struct generator::promise_type
{
promise_type() { std::cout << __PRETTY_FUNCTION__ << '\n'; }
~promise_type() { std::cout << __PRETTY_FUNCTION__ << '\n'; }
- generator get_return_object() { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
+ generator get_return_object() { return {make_coroutine_owner<promise_type>(*this)}; }
std::suspend_always initial_suspend() { return {}; }
symmetric_transfer yield_value(unsigned v)
{
diff --git a/src/coro-poll.cpp b/src/coro-poll.cpp
@@ -1,10 +1,12 @@
+#include "coroutine_owner.hpp"
+
#include <unistd.h>
#include <poll.h>
#include <sys/timerfd.h>
-#include <coroutine>
#include <iostream>
#include <map>
+#include <cstdint>
static void check_err(char const* msg)
{
@@ -36,12 +38,14 @@ struct coroutine_task
{
struct promise_type
{
- void get_return_object(void) const noexcept {}
+ coroutine_task get_return_object(void) noexcept { return {make_coroutine_owner(*this)}; }
std::suspend_never initial_suspend(void) const noexcept { return {}; }
- std::suspend_never final_suspend() const noexcept { return {}; }
- void unhandled_exception() const noexcept {}
+ std::suspend_always final_suspend() const noexcept { return {}; }
+ void unhandled_exception() const noexcept { std::terminate(); }
void return_void(void) const noexcept {}
};
+
+ coroutine_owner<promise_type> c;
};
struct poll_scheduler
@@ -50,12 +54,6 @@ struct poll_scheduler
poll_scheduler(poll_scheduler const&) = delete;
poll_scheduler& operator =(poll_scheduler const&) = delete;
- ~poll_scheduler()
- {
- for (auto& [_, h] : waiting)
- h.destroy();
- }
-
struct awaitable : std::suspend_always
{
poll_scheduler& scheduler;
@@ -158,9 +156,9 @@ int main()
{
poll_scheduler s;
- echo_task(0, s);
- timer_task(4, s);
- timer_task(8, s);
+ auto task1 = echo_task(0, s);
+ auto task2 = timer_task(4, s);
+ auto task3 = timer_task(8, s);
s.flush();
}
diff --git a/src/coro-round-robin.cpp b/src/coro-round-robin.cpp
@@ -1,10 +1,11 @@
-#include <coroutine>
+#include "coroutine_owner.hpp"
+
#include <iostream>
#include <deque>
struct round_robin_scheduler : std::suspend_always
{
- std::deque<std::coroutine_handle<>> tasks;
+ std::deque<coroutine_owner<>> tasks;
bool await_ready(void) const noexcept
{
@@ -13,12 +14,12 @@ struct round_robin_scheduler : std::suspend_always
void schedule(std::coroutine_handle<> h)
{
- tasks.push_back(h);
+ tasks.push_back(coroutine_owner<>{h});
}
std::coroutine_handle<> next(void) noexcept
{
- auto h = tasks.front();
+ auto h = tasks.front().release();
tasks.pop_front();
return h;
}
@@ -35,9 +36,6 @@ struct round_robin_scheduler : std::suspend_always
{
if (not tasks.empty())
std::cout << "Finishing early, cleaning up state of suspended tasks..." << std::endl;
-
- for (auto h : tasks)
- h.destroy();
}
~round_robin_scheduler()
diff --git a/src/coro-state-machine.cpp b/src/coro-state-machine.cpp
@@ -0,0 +1,109 @@
+#include "coroutine_owner.hpp"
+
+#include <exception>
+#include <cassert>
+#include <iostream>
+
+using std::cerr, std::endl;
+
+struct task
+{
+ struct promise_type
+ {
+ task get_return_object() { return make_coroutine_owner(*this); }
+ std::suspend_never initial_suspend() { return {}; }
+ void return_value(task);
+ void unhandled_exception() { std::terminate(); }
+ std::suspend_always final_suspend() noexcept { return {}; }
+ coroutine_owner<promise_type> continuation;
+ };
+
+ task() = default;
+ task(coroutine_owner<promise_type> c)
+ : h{std::move(c)}
+ {}
+
+ coroutine_owner<promise_type> h;
+
+ struct await_done
+ {
+ bool await_ready();
+ void await_suspend(std::coroutine_handle<> s);
+ coroutine_owner<promise_type> await_resume();
+ std::coroutine_handle<> awaiting;
+ promise_type& promise;
+ };
+ await_done done() { return {{}, h->promise()}; }
+};
+
+void task::promise_type::return_value(task t) { continuation = std::move(t.h); }
+
+
+bool task::await_done::await_ready()
+{
+ return false;
+}
+void task::await_done::await_suspend(std::coroutine_handle<> s)
+{
+ assert(not awaiting);
+ awaiting = s;
+}
+coroutine_owner<task::promise_type> task::await_done::await_resume()
+{
+ awaiting = {};
+ return std::exchange(promise.continuation, {});
+}
+
+struct awaitable
+{
+ std::coroutine_handle<> awaiting;
+ bool await_ready() { return false; }
+ void await_suspend(std::coroutine_handle<> s) { assert(not std::exchange(awaiting, s)); }
+ void await_resume() {}
+
+ void event() {
+ std::exchange(awaiting, {}).resume();
+ }
+};
+
+task state_machine(awaitable& event);
+
+task loop_machine(awaitable& event)
+{
+ cerr << "init loop\n";
+ for (unsigned i = 0; i < 3; ++i)
+ {
+ co_await event;
+ cerr << "loop!\n";
+ }
+ co_return state_machine(event);
+}
+
+task state_machine(awaitable& event)
+{
+ cerr << "init\n";
+ co_await event;
+ cerr << "button pressed\n";
+ co_await event;
+ cerr << "button pressed 2\n";
+ co_await event;
+ cerr << "button pressed 3\n";
+ co_return loop_machine(event);
+}
+
+task state_engine(task t)
+{
+ cerr << "se init\n";
+ for (;;)
+ t = task{co_await t.done()};
+}
+
+int main()
+{
+ awaitable event;
+ auto h = state_machine(event);
+ auto e = state_engine(std::move(h));
+ for (unsigned i = 0; i < 10; ++i)
+ event.event();
+ cerr << "Done!\n";
+}
diff --git a/src/coro-throwing.cpp b/src/coro-throwing.cpp
@@ -1,98 +1,88 @@
-#include <coroutine>
-#include <iostream>
-#include <optional>
-#include <stdexcept>
+#include "coroutine_owner.hpp"
-using namespace std;
+#include <stdexcept>
+#include <utility>
+#include <iostream>
-struct coroutine
+struct coro
{
- struct promise_type;
- coroutine_handle<> h;
- ~coroutine()
+ struct promise_type
{
- if (not h.done())
- h.destroy();
- }
-};
-
-struct coroutine::promise_type
-{
- coroutine get_return_object() noexcept { return {coroutine_handle<promise_type>::from_promise(*this)}; }
- suspend_never initial_suspend() noexcept { return {}; }
- suspend_always final_suspend() noexcept { return {}; }
- void return_void() noexcept {}
- void unhandled_exception() { throw; }
-};
+ coro get_return_object() { return {*this}; }
+ std::suspend_never initial_suspend() { return {}; }
+ std::suspend_always final_suspend() noexcept { return {}; }
+ void unhandled_exception() noexcept { e = std::current_exception(); }
+ void return_void() noexcept {}
+ std::exception_ptr e;
+ };
+ coroutine_owner<promise_type> h;
-struct awaitable : suspend_always
-{
- coroutine_handle<> suspended;
- void await_suspend(coroutine_handle<> h)
+ void result()
{
- suspended = h;
+ if (not *h)
+ throw std::logic_error("cannot get result() from empty coroutine");
+ else if (not h->done())
+ throw std::logic_error("cannot get result() from unfinished coroutine");
+ else if (auto e = std::exchange(h->promise().e, {}))
+ std::rethrow_exception(e);
+ std::cerr << "coroutine finished successfully.\n";
}
-} awaitable;
-
-coroutine no_throw()
-{
- cerr << "no_throw()\n";
- co_return;
-}
-
-
-coroutine initial_throw()
-{
- cerr << "\ninitial_throw()\n";
- throw runtime_error("initial_throw");
- co_return;
-}
-coroutine resume_throw()
-{
- cerr << "\nresume_throw()\n";
- for (;;)
+ coro() = default;
+ coro(promise_type& p)
+ : h{std::coroutine_handle<promise_type>::from_promise(p)}
+ {}
+ ~coro()
{
- co_await awaitable;
- cerr << "resume_throw() RESUMED\n";
- throw runtime_error("resume_throw");
+ cancel();
}
-}
-
-coroutine resume_nothrow()
-{
- cerr << "\nresume_nothrow()\n";
- for (;;)
+ void cancel() noexcept
{
- co_await awaitable;
- cerr << "resume_nothrow() RESUMED\n";
+ if (auto x = std::exchange(h, {}); *x)
+ if (auto& e = x->promise().e)
+ try
+ {
+ std::rethrow_exception(e);
+ } catch (std::exception const& e)
+ {
+ std::cerr << "Dropped exception: '" << e.what() << "'\n";
+ } catch (...) {
+ std::cerr << "Dropped exception\n";
+ }
}
-}
-
-int main()
-{
- try {
- {
- no_throw();
- }
- [[maybe_unused]] auto coro = initial_throw();
- } catch (exception& e)
+ coro(coro const&) = delete;
+ coro& operator =(coro const&) = delete;
+ coro(coro&& old)
+ : h{std::move(old.h)}
+ {}
+ coro& operator =(coro&& old)
{
- cerr << "Exception caught: " << e.what() << endl;
+ cancel();
+ h = std::exchange(old.h, {});
+ return *this;
}
+};
- try {
- {
- [[maybe_unused]] auto coro = resume_nothrow();
- cerr << " suspend OK\n";
- awaitable.suspended.resume();
- }
- [[maybe_unused]] auto coro = resume_throw();
- cerr << " suspend OK\n";
- awaitable.suspended.resume();
- } catch (exception& e)
- {
- cerr << "Exception caught: " << e.what() << endl;
- }
+coro mycoro(bool success)
+{
+ if (success)
+ co_return;
+ throw std::runtime_error("hi");
+}
+
+int main() try
+{
+ auto a = mycoro(true);
+ a.result();
+ a = mycoro(false);
+ auto c = mycoro(false);
+ auto c2 = std::move(c);
+ a = std::move(c2);
+ a.result();
+}
+catch (std::exception const& e)
+{
+ std::cerr << "Caught exception: '" << e.what() << "'\n";
+ return 1;
}
diff --git a/src/coro-unconditional-dispatch.cpp b/src/coro-unconditional-dispatch.cpp
@@ -1,68 +1,42 @@
+#include "coroutine_owner.hpp"
+
#include <coroutine>
#include <iostream>
#include <vector>
-struct unique_coroutine
+struct task
{
struct promise_type
{
~promise_type()
{}
- unique_coroutine get_return_object() { return *this; }
+ task get_return_object() { return {make_coroutine_owner(*this)}; }
std::suspend_always initial_suspend() { return {}; }
std::suspend_always final_suspend() noexcept { return {}; }
void unhandled_exception() {}
void return_void() {}
};
- unique_coroutine(promise_type& p)
- : h(std::coroutine_handle<promise_type>::from_promise(p))
+ task(coroutine_owner<promise_type> c)
+ : h{std::move(c)}
{}
- ~unique_coroutine()
- {
- if (h)
- h.destroy();
- }
-
- unique_coroutine(unique_coroutine const&) = delete;
- unique_coroutine& operator =(unique_coroutine const&) = delete;
-
- unique_coroutine(unique_coroutine&& old)
- : h(old.h)
- {
- old.h = h.from_address(nullptr);
- }
-
- unique_coroutine& operator =(unique_coroutine&& old)
- {
- if (&old == this)
- return *this;
-
- if (h)
- h.destroy();
-
- h = old.h;
- old.h = h.from_address(nullptr);
- return *this;
- }
-
bool done(void) const
{
- return h.done();
+ return h->done();
}
bool operator()(void)
{
- h();
+ h->resume();
return done();
}
private:
- std::coroutine_handle<> h;
+ coroutine_owner<promise_type> h;
};
-bool done(unique_coroutine const& c)
+bool done(task const& c)
{
return c.done();
}
@@ -75,7 +49,7 @@ struct state
}
};
-unique_coroutine counter(int from, int to)
+task counter(int from, int to)
{
state s;
for (int i = from; i < to; ++i)
@@ -93,7 +67,7 @@ unique_coroutine counter(int from, int to)
int main()
{
- std::vector<unique_coroutine> tasks;
+ std::vector<task> tasks;
tasks.push_back(counter(1,3));
tasks.push_back(counter(4,5));
tasks.push_back(counter(0,7));
diff --git a/src/coroutine_owner.hpp b/src/coroutine_owner.hpp
@@ -0,0 +1,69 @@
+#pragma once
+
+#include <coroutine> // coroutine_handle
+#include <utility> // exchange
+
+template <typename T = void>
+struct coroutine_owner
+{
+ coroutine_owner() = default;
+
+ explicit coroutine_owner(std::coroutine_handle<T> h)
+ : h{h}
+ {}
+
+ coroutine_owner(coroutine_owner const&) = delete;
+ coroutine_owner& operator =(coroutine_owner const&) = delete;
+
+ coroutine_owner(coroutine_owner&& old)
+ : h{std::exchange(old.h, {})}
+ {};
+
+ coroutine_owner& operator =(coroutine_owner&& old)
+ {
+ if (h)
+ h.destroy();
+ h = std::exchange(old.h, {});
+ return *this;
+ }
+
+ ~coroutine_owner()
+ {
+ if (h)
+ h.destroy();
+ }
+
+ std::coroutine_handle<T> release()
+ {
+ return std::exchange(h, {});
+ }
+
+ std::coroutine_handle<T> const* operator ->() const
+ {
+ return &h;
+ }
+
+ std::coroutine_handle<T>* operator ->()
+ {
+ return &h;
+ }
+
+ std::coroutine_handle<T> const& operator *() const
+ {
+ return h;
+ }
+
+ std::coroutine_handle<T>& operator *()
+ {
+ return h;
+ }
+
+private:
+ std::coroutine_handle<T> h;
+};
+
+template <typename T, typename U = T>
+coroutine_owner<U> make_coroutine_owner(T& promise)
+{
+ return coroutine_owner<U>{std::coroutine_handle<T>::from_promise(promise)};
+}