commit 2c8fcb8dd6e3b67ffebf51f4d0ca0c612cd428b6
parent a67325eb6c28db74bb43e9197ba5eaa4afb3faf9
Author: Henry Wilson <henry@henryandlizzy.uk>
Date:   Fri, 14 Jul 2023 23:51:18 +0100
coroutine_owner: add owning version of std::coroutine_handle
Diffstat:
7 files changed, 282 insertions(+), 148 deletions(-)
diff --git a/src/coro-generator-consumer.cpp b/src/coro-generator-consumer.cpp
@@ -1,9 +1,9 @@
-#include <coroutine> // coroutine_handle noop_coroutine suspend_always suspend_never
+#include "coroutine_owner.hpp" // coroutine_owner noop_coroutine suspend_always suspend_never
+
 #include <exception> // terminate
 #include <iostream> // cout
 #include <optional> // optional
 #include <cassert> // assert
-#include <utility> // exchange
 
 struct sentry
 {
@@ -23,20 +23,16 @@ struct symmetric_transfer : std::suspend_always
 struct generator
 {
 	struct promise_type;
-	~generator()
-	{
-		h.destroy();
-	}
-	promise_type& promise() { return h.promise(); }
+	promise_type& promise() { return h->promise(); }
 
-	std::coroutine_handle<promise_type> h;
+	coroutine_owner<promise_type> h;
 };
 
 struct generator::promise_type
 {
 	promise_type() { std::cout << __PRETTY_FUNCTION__ << '\n'; }
 	~promise_type() { std::cout << __PRETTY_FUNCTION__ << '\n'; }
-	generator get_return_object() { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
+	generator get_return_object() { return {make_coroutine_owner<promise_type>(*this)}; }
 	std::suspend_always initial_suspend() { return {}; }
 	symmetric_transfer yield_value(unsigned v)
 	{
diff --git a/src/coro-poll.cpp b/src/coro-poll.cpp
@@ -1,10 +1,12 @@
+#include "coroutine_owner.hpp"
+
 #include <unistd.h>
 #include <poll.h>
 #include <sys/timerfd.h>
 
-#include <coroutine>
 #include <iostream>
 #include <map>
+#include <cstdint>
 
 static void check_err(char const* msg)
 {
@@ -36,12 +38,14 @@ struct coroutine_task
 {
 	struct promise_type
 	{
-		void get_return_object(void) const noexcept {}
+		coroutine_task get_return_object(void) noexcept { return {make_coroutine_owner(*this)}; }
 		std::suspend_never initial_suspend(void) const noexcept { return {}; }
-		std::suspend_never final_suspend() const noexcept { return {}; }
-		void unhandled_exception() const noexcept {}
+		std::suspend_always final_suspend() const noexcept { return {}; }
+		void unhandled_exception() const noexcept { std::terminate(); }
 		void return_void(void) const noexcept {}
 	};
+
+	coroutine_owner<promise_type> c;
 };
 
 struct poll_scheduler
@@ -50,12 +54,6 @@ struct poll_scheduler
 	poll_scheduler(poll_scheduler const&) = delete;
 	poll_scheduler& operator =(poll_scheduler const&) = delete;
 
-	~poll_scheduler()
-	{
-		for (auto& [_, h] : waiting)
-			h.destroy();
-	}
-
 	struct awaitable : std::suspend_always
 	{
 		poll_scheduler& scheduler;
@@ -158,9 +156,9 @@ int main()
 {
 	poll_scheduler s;
 
-	echo_task(0, s);
-	timer_task(4, s);
-	timer_task(8, s);
+	auto task1 = echo_task(0, s);
+	auto task2 = timer_task(4, s);
+	auto task3 = timer_task(8, s);
 
 	s.flush();
 }
diff --git a/src/coro-round-robin.cpp b/src/coro-round-robin.cpp
@@ -1,10 +1,11 @@
-#include <coroutine>
+#include "coroutine_owner.hpp"
+
 #include <iostream>
 #include <deque>
 
 struct round_robin_scheduler : std::suspend_always
 {
-	std::deque<std::coroutine_handle<>> tasks;
+	std::deque<coroutine_owner<>> tasks;
 
 	bool await_ready(void) const noexcept
 	{
@@ -13,12 +14,12 @@ struct round_robin_scheduler : std::suspend_always
 
 	void schedule(std::coroutine_handle<> h)
 	{
-		tasks.push_back(h);
+		tasks.push_back(coroutine_owner<>{h});
 	}
 
 	std::coroutine_handle<> next(void) noexcept
 	{
-		auto h = tasks.front();
+		auto h = tasks.front().release();
 		tasks.pop_front();
 		return h;
 	}
@@ -35,9 +36,6 @@ struct round_robin_scheduler : std::suspend_always
 	{
 		if (not tasks.empty())
 			std::cout << "Finishing early, cleaning up state of suspended tasks..." << std::endl;
-
-		for (auto h : tasks)
-			h.destroy();
 	}
 
 	~round_robin_scheduler()
diff --git a/src/coro-state-machine.cpp b/src/coro-state-machine.cpp
@@ -0,0 +1,109 @@
+#include "coroutine_owner.hpp"
+
+#include <exception>
+#include <cassert>
+#include <iostream>
+
+using std::cerr, std::endl;
+
+struct task
+{
+	struct promise_type
+	{
+		task get_return_object() { return make_coroutine_owner(*this); }
+		std::suspend_never initial_suspend() { return {}; }
+		void return_value(task);
+		void unhandled_exception() { std::terminate(); }
+		std::suspend_always final_suspend() noexcept { return {}; }
+		coroutine_owner<promise_type> continuation;
+	};
+
+	task() = default;
+	task(coroutine_owner<promise_type> c)
+	:	h{std::move(c)}
+	{}
+
+	coroutine_owner<promise_type> h;
+
+	struct await_done
+	{
+		bool await_ready();
+		void await_suspend(std::coroutine_handle<> s);
+		coroutine_owner<promise_type> await_resume();
+		std::coroutine_handle<> awaiting;
+		promise_type& promise;
+	};
+	await_done done() { return {{}, h->promise()}; }
+};
+
+void task::promise_type::return_value(task t) { continuation = std::move(t.h); }
+
+
+bool task::await_done::await_ready()
+{
+	return false;
+}
+void task::await_done::await_suspend(std::coroutine_handle<> s)
+{
+	assert(not awaiting);
+	awaiting = s;
+}
+coroutine_owner<task::promise_type> task::await_done::await_resume()
+{
+	awaiting = {};
+	return std::exchange(promise.continuation, {});
+}
+
+struct awaitable
+{
+	std::coroutine_handle<> awaiting;
+	bool await_ready() { return false; }
+	void await_suspend(std::coroutine_handle<> s) { assert(not std::exchange(awaiting, s)); }
+	void await_resume() {}
+
+	void event() {
+		std::exchange(awaiting, {}).resume();
+	}
+};
+
+task state_machine(awaitable& event);
+
+task loop_machine(awaitable& event)
+{
+	cerr << "init loop\n";
+	for (unsigned i = 0; i < 3; ++i)
+	{
+		co_await event;
+		cerr << "loop!\n";
+	}
+	co_return state_machine(event);
+}
+
+task state_machine(awaitable& event)
+{
+	cerr << "init\n";
+	co_await event;
+	cerr << "button pressed\n";
+	co_await event;
+	cerr << "button pressed 2\n";
+	co_await event;
+	cerr << "button pressed 3\n";
+	co_return loop_machine(event);
+}
+
+task state_engine(task t)
+{
+	cerr << "se init\n";
+	for (;;)
+		t = task{co_await t.done()};
+}
+
+int main()
+{
+	awaitable event;
+	auto h = state_machine(event);
+	auto e = state_engine(std::move(h));
+	for (unsigned i = 0; i < 10; ++i)
+		event.event();
+	cerr << "Done!\n";
+}
diff --git a/src/coro-throwing.cpp b/src/coro-throwing.cpp
@@ -1,98 +1,88 @@
-#include <coroutine>
-#include <iostream>
-#include <optional>
-#include <stdexcept>
+#include "coroutine_owner.hpp"
 
-using namespace std;
+#include <stdexcept>
+#include <utility>
+#include <iostream>
 
-struct coroutine
+struct coro
 {
-	struct promise_type;
-	coroutine_handle<> h;
-	~coroutine()
+	struct promise_type
 	{
-		if (not h.done())
-			h.destroy();
-	}
-};
-
-struct coroutine::promise_type
-{
-	coroutine get_return_object() noexcept { return {coroutine_handle<promise_type>::from_promise(*this)}; }
-	suspend_never initial_suspend() noexcept { return {}; }
-	suspend_always final_suspend() noexcept { return {}; }
-	void return_void() noexcept {}
-	void unhandled_exception() { throw; }
-};
+		coro get_return_object() { return {*this}; }
+		std::suspend_never initial_suspend() { return {}; }
+		std::suspend_always final_suspend() noexcept { return {}; }
+		void unhandled_exception() noexcept { e = std::current_exception(); }
+		void return_void() noexcept {}
+		std::exception_ptr e;
+	};
 
+	coroutine_owner<promise_type> h;
 
-struct awaitable : suspend_always
-{
-	coroutine_handle<> suspended;
-	void await_suspend(coroutine_handle<> h)
+	void result()
 	{
-		suspended = h;
+		if (not *h)
+			throw std::logic_error("cannot get result() from empty coroutine");
+		else if (not h->done())
+			throw std::logic_error("cannot get result() from unfinished coroutine");
+		else if (auto e = std::exchange(h->promise().e, {}))
+			std::rethrow_exception(e);
+		std::cerr << "coroutine finished successfully.\n";
 	}
-} awaitable;
-
-coroutine no_throw()
-{
-	cerr << "no_throw()\n";
-	co_return;
-}
-
-
-coroutine initial_throw()
-{
-	cerr << "\ninitial_throw()\n";
-	throw runtime_error("initial_throw");
-	co_return;
-}
 
-coroutine resume_throw()
-{
-	cerr << "\nresume_throw()\n";
-	for (;;)
+	coro() = default;
+	coro(promise_type& p)
+	:   h{std::coroutine_handle<promise_type>::from_promise(p)}
+	{}
+	~coro()
 	{
-		co_await awaitable;
-		cerr << "resume_throw() RESUMED\n";
-		throw runtime_error("resume_throw");
+		cancel();
 	}
-}
-
-coroutine resume_nothrow()
-{
-	cerr << "\nresume_nothrow()\n";
-	for (;;)
+	void cancel() noexcept
 	{
-		co_await awaitable;
-		cerr << "resume_nothrow() RESUMED\n";
+		if (auto x = std::exchange(h, {}); *x)
+			if (auto& e = x->promise().e)
+				try
+				{
+					std::rethrow_exception(e);
+				} catch (std::exception const& e)
+				{
+					std::cerr << "Dropped exception: '" << e.what() << "'\n";
+				} catch (...) {
+					std::cerr << "Dropped exception\n";
+				}
 	}
-}
-
-int main()
-{
-	try {
-		{
-			no_throw();
-		}
-		[[maybe_unused]] auto coro = initial_throw();
-	} catch (exception& e)
+	coro(coro const&) = delete;
+	coro& operator =(coro const&) = delete;
+	coro(coro&& old)
+	:   h{std::move(old.h)}
+	{}
+	coro& operator =(coro&& old)
 	{
-		cerr << "Exception caught: " << e.what() << endl;
+		cancel();
+		h = std::exchange(old.h, {});
+		return *this;
 	}
+};
 
-	try {
-		{
-			[[maybe_unused]] auto coro = resume_nothrow();
-			cerr << "  suspend OK\n";
-			awaitable.suspended.resume();
-		}
-		[[maybe_unused]] auto coro = resume_throw();
-		cerr << "  suspend OK\n";
-		awaitable.suspended.resume();
-	} catch (exception& e)
-	{
-		cerr << "Exception caught: " << e.what() << endl;
-	}
+coro mycoro(bool success)
+{
+	if (success)
+		co_return;
+	throw std::runtime_error("hi");
+}
+
+int main() try
+{
+	auto a = mycoro(true);
+	a.result();
+	a = mycoro(false);
+	auto c = mycoro(false);
+	auto c2 = std::move(c);
+	a = std::move(c2);
+	a.result();
+}
+catch (std::exception const& e)
+{
+	std::cerr << "Caught exception: '" << e.what() << "'\n";
+	return 1;
 }
diff --git a/src/coro-unconditional-dispatch.cpp b/src/coro-unconditional-dispatch.cpp
@@ -1,68 +1,42 @@
+#include "coroutine_owner.hpp"
+
 #include <coroutine>
 #include <iostream>
 #include <vector>
 
-struct unique_coroutine
+struct task
 {
     struct promise_type
     {
         ~promise_type()
         {}
-        unique_coroutine get_return_object() { return *this; }
+        task get_return_object() { return {make_coroutine_owner(*this)}; }
         std::suspend_always initial_suspend() { return {}; }
         std::suspend_always final_suspend() noexcept { return {}; }
         void unhandled_exception() {}
         void return_void() {}
     };
 
-    unique_coroutine(promise_type& p)
-    :   h(std::coroutine_handle<promise_type>::from_promise(p))
+    task(coroutine_owner<promise_type> c)
+    :   h{std::move(c)}
     {}
 
-    ~unique_coroutine()
-    {
-        if (h)
-            h.destroy();
-    }
-
-    unique_coroutine(unique_coroutine const&) = delete;
-    unique_coroutine& operator =(unique_coroutine const&) = delete;
-
-    unique_coroutine(unique_coroutine&& old)
-    :   h(old.h)
-    {
-        old.h = h.from_address(nullptr);
-    }
-
-    unique_coroutine& operator =(unique_coroutine&& old)
-    {
-        if (&old == this)
-            return *this;
-
-        if (h)
-            h.destroy();
-
-        h = old.h;
-        old.h = h.from_address(nullptr);
-        return *this;
-    }
-
     bool done(void) const
     {
-        return h.done();
+        return h->done();
     }
 
     bool operator()(void)
     {
-        h();
+        h->resume();
         return done();
     }
 
 private:
-    std::coroutine_handle<> h;
+    coroutine_owner<promise_type> h;
 };
 
-bool done(unique_coroutine const& c)
+bool done(task const& c)
 {
     return c.done();
 }
@@ -75,7 +49,7 @@ struct state
 	}
 };
 
-unique_coroutine counter(int from, int to)
+task counter(int from, int to)
 {
 	state s;
 	for (int i = from; i < to; ++i)
@@ -93,7 +67,7 @@ unique_coroutine counter(int from, int to)
 
 int main()
 {
-	std::vector<unique_coroutine> tasks;
+	std::vector<task> tasks;
 	tasks.push_back(counter(1,3));
 	tasks.push_back(counter(4,5));
 	tasks.push_back(counter(0,7));
diff --git a/src/coroutine_owner.hpp b/src/coroutine_owner.hpp
@@ -0,0 +1,69 @@
+#pragma once
+
+#include <coroutine> // coroutine_handle
+#include <utility> // exchange
+
+template <typename T = void>
+struct coroutine_owner
+{
+	coroutine_owner() = default;
+
+	explicit coroutine_owner(std::coroutine_handle<T> h)
+	:	h{h}
+	{}
+
+	coroutine_owner(coroutine_owner const&) = delete;
+	coroutine_owner& operator =(coroutine_owner const&) = delete;
+
+	coroutine_owner(coroutine_owner&& old)
+	:   h{std::exchange(old.h, {})}
+	{};
+
+	coroutine_owner& operator =(coroutine_owner&& old)
+	{
+		if (h)
+			h.destroy();
+		h = std::exchange(old.h, {});
+		return *this;
+	}
+
+	~coroutine_owner()
+	{
+		if (h)
+			h.destroy();
+	}
+
+	std::coroutine_handle<T> release()
+	{
+		return std::exchange(h, {});
+	}
+
+	std::coroutine_handle<T> const* operator ->() const
+	{
+		return &h;
+	}
+
+	std::coroutine_handle<T>* operator ->()
+	{
+		return &h;
+	}
+
+	std::coroutine_handle<T> const& operator *() const
+	{
+		return h;
+	}
+
+	std::coroutine_handle<T>& operator *()
+	{
+		return h;
+	}
+
+private:
+	std::coroutine_handle<T> h;
+};
+
+template <typename T, typename U = T>
+coroutine_owner<U> make_coroutine_owner(T& promise)
+{
+	return coroutine_owner<U>{std::coroutine_handle<T>::from_promise(promise)};
+}