commit 22802b32fb1c96c3c785b5c43eb3fed162bdff2d
parent 2ca7d44946664efa74301c04a675754acb673fce
Author: Henry Wilson <henry@henryandlizzy.uk>
Date: Fri, 8 Nov 2024 17:39:53 +0000
io_uring: keep track of inflight jobs and cancel with destructors
Diffstat:
4 files changed, 276 insertions(+), 94 deletions(-)
diff --git a/src/coro-poll.cpp b/src/coro-poll.cpp
@@ -1,4 +1,5 @@
#include "coroutine_owner.hpp"
+#include "fd.hpp"
#include <unistd.h>
#include <poll.h>
@@ -15,26 +16,6 @@ static void check_err(char const* msg)
throw std::system_error(errno, std::generic_category(), msg);
}
-struct file_descriptor
-{
- explicit file_descriptor(int f) noexcept
- : fd{f}
- {}
- ~file_descriptor()
- {
- std::cout << "close fd[" << fd << "]\n";
- ::close(fd);
- check_err("close");
- }
-
- file_descriptor(file_descriptor const&) = delete;
- file_descriptor& operator =(file_descriptor const&) = delete;
-
- operator int() noexcept { return fd; };
-
- int fd;
-};
-
struct coroutine_task
{
struct promise_type
diff --git a/src/coroutine_owner.hpp b/src/coroutine_owner.hpp
@@ -16,7 +16,7 @@ struct [[nodiscard]] coroutine_owner
coroutine_owner& operator =(coroutine_owner const&) = delete;
coroutine_owner(coroutine_owner&& old)
- : h{std::exchange(old.h, {})}
+ : h{old.release()}
{};
coroutine_owner& operator =(coroutine_owner&& old)
diff --git a/src/fd.hpp b/src/fd.hpp
@@ -0,0 +1,57 @@
+#pragma once
+
+#include <unistd.h>
+#include <utility>
+
+struct file_descriptor
+{
+ explicit file_descriptor(int f) noexcept
+ : fd{f}
+ {}
+
+ ~file_descriptor() noexcept
+ {
+ if (*this)
+ ::close(fd);
+ }
+
+ file_descriptor(file_descriptor const&) = delete;
+ file_descriptor& operator =(file_descriptor const&) = delete;
+
+ file_descriptor(file_descriptor&& old)
+ : fd{old.release()}
+ {}
+
+ file_descriptor& operator =(file_descriptor&& old)
+ {
+ reset(old.release());
+ return *this;
+ }
+
+ explicit operator bool() const noexcept
+ {
+ return fd >= 0;
+ }
+
+ int get() const noexcept {
+ return fd;
+ };
+
+ operator int() const noexcept {
+ return fd;
+ };
+
+ int release() noexcept
+ {
+ return std::exchange(fd, -1);
+ }
+
+ void reset(int f = -1) noexcept
+ {
+ file_descriptor{std::exchange(fd, f)};
+ }
+
+private:
+ int fd = -1;
+};
+
diff --git a/src/io_uring.cpp b/src/io_uring.cpp
@@ -1,13 +1,55 @@
+#include "coroutine_owner.hpp"
+#include "fd.hpp"
+
#include <liburing.h>
-#include <iostream>
-#include <coroutine>
+#include <sys/timerfd.h>
+
+#include <cstring>
+#include <print>
#include <span>
+#include <map>
+#include <memory>
+#include <cassert>
-struct awaitable;
+struct task
+{
+ bool done() const noexcept
+ {
+ return coro->done();
+ }
+
+ struct promise_type
+ {
+ task get_return_object(void) noexcept
+ {
+ return task{make_coroutine_owner(*this)};
+ }
+ auto initial_suspend(void) const noexcept
+ {
+ return std::suspend_never{};
+ }
+ auto final_suspend(void) const noexcept
+ {
+ return std::suspend_always{};
+ }
+ void unhandled_exception(void) const noexcept
+ {
+ std::terminate();
+ }
+ void return_void(void) const noexcept
+ {}
+ };
+
+ coroutine_owner<promise_type> coro;
+};
struct uring
{
- uring(unsigned entries, unsigned flags)
+ struct awaitable;
+private:
+ struct rendezvous;
+public:
+ uring(unsigned entries, unsigned flags)
{
if (io_uring_queue_init(entries, &ring, flags))
throw std::system_error(errno, std::system_category());
@@ -15,127 +57,229 @@ struct uring
~uring(void)
{
+ assert(inflight.empty());
io_uring_queue_exit(&ring);
}
uring(const uring&) = delete;
uring& operator =(const uring&) = delete;
- io_uring_sqe* get_sqe(void)
+ bool process_completion(void);
+ awaitable read(int, std::span<char>, off_t);
+ awaitable write(int, std::span<char const>, off_t);
+
+private:
+ std::pair<io_uring_sqe*, awaitable> get_sqe(void)
{
- return io_uring_get_sqe(&ring);
+ auto sqe = io_uring_get_sqe(&ring);
+ assert(sqe);
+ io_uring_sqe_set_data64(sqe, ++last_id);
+ return {sqe, {*this , std::make_shared<rendezvous>(), last_id}};
}
- int submit(void)
+ void submit(std::shared_ptr<rendezvous> r)
{
- return io_uring_submit(&ring);
+ int n = io_uring_submit(&ring);
+ if (n < 0)
+ throw std::system_error(-n, std::system_category());
+ assert(n == 1);
+ inflight[last_id] = std::move(r);
}
- bool process_completion(void);
- awaitable read(int, std::span<char>, unsigned);
- awaitable write(int, std::span<char const>, unsigned);
- unsigned n = 0;
-private:
- io_uring ring;
-};
-
-struct task
-{
- struct promise_type
+ struct rendezvous
{
- int res;
+ std::optional<int> result = std::nullopt;
+ std::coroutine_handle<> continuation = {};
+ };
- void get_return_object(void) const noexcept
+public:
+ struct awaitable
+ {
+ awaitable(uring& r, std::shared_ptr<rendezvous> p, uint64_t i) noexcept
+ : ring{r}
+ , rp{std::move(p)}
+ , id{i}
{}
- auto initial_suspend(void) const noexcept
+ ~awaitable() noexcept
{
- return std::suspend_never{};
+ if (!rp || await_ready())
+ return;
+
+ rp->continuation = {};
+
+ io_uring_sync_cancel_reg reg = {
+ .addr = id,
+ .fd = 0,
+ .flags = 0,
+ .timeout = {-1, -1},
+ .opcode = 0,
+ .pad = {},
+ .pad2 = {},
+ };
+ int result = io_uring_register_sync_cancel(&ring.ring, ®);
+ switch (result)
+ {
+ case 0: // Job was cancelled
+ ring.inflight.erase(id);
+ break;
+ case -ENOENT: // No job was found: result in completion queue. Continuation already removed, so nothing to do.
+ break;
+ case -EINVAL: // bad args
+ assert(false);
+ break;
+ case -EALREADY: // not sure what to do here, can't cancel, can't finish destructor until complete?
+ std::println("cancel: EALREADY");
+ break;
+ default:
+ std::println("cancel: {}", result);
+ assert(false);
+ }
}
- auto final_suspend(void) const noexcept
+ awaitable(awaitable const&) = delete;
+ awaitable& operator =(awaitable const&) = delete;
+ awaitable(awaitable&& old)
+ : ring{old.ring}
+ , rp{std::move(old.rp)}
+ , id{old.id}
{
- return std::suspend_always{};
+ //std::println("Awaitable {} moved", id);
}
- void unhandled_exception(void) const noexcept
+ bool await_ready(void) const noexcept
{
- std::terminate();
+ return rp->result.has_value();
}
- void return_void(void) const noexcept
- {}
+ void await_suspend(std::coroutine_handle<> h) const noexcept
+ {
+ rp->continuation = h;
+ }
+ int await_resume() const
+ {
+ //std::println("Job {} resumed", id);
+ return rp->result.value();
+ }
+// private:
+ uring& ring;
+ std::shared_ptr<rendezvous> rp;
+ uint64_t id;
};
-};
-
-struct awaitable
-{
- uring& u;
- io_uring_sqe& sqe;
- task::promise_type* p;
- bool await_ready(void) const noexcept
- {
- return false;
- }
- void await_suspend(std::coroutine_handle<task::promise_type> h)
- {
- p = &h.promise();
- io_uring_sqe_set_data(&sqe, h.address());
- u.submit();
- u.n++;
- }
- int await_resume()
- {
- return p->res;
- }
+private:
+ io_uring ring;
+ std::map<uint64_t, std::shared_ptr<rendezvous>> inflight;
+ uint64_t last_id = 0;
};
+
bool uring::process_completion(void)
{
- if (!n)
+ if (inflight.empty())
return false;
- --n;
io_uring_cqe* cqe;
if (io_uring_wait_cqe(&ring, &cqe))
throw std::system_error(errno, std::system_category());
- auto addr = io_uring_cqe_get_data(cqe);
- auto h = std::coroutine_handle<task::promise_type>::from_address(addr);
- h.promise().res = cqe->res;
- h.resume();
+ //std::println("Job {} completed", io_uring_cqe_get_data64(cqe));
+ auto it = inflight.find(io_uring_cqe_get_data64(cqe));
+ assert(it != inflight.end());
+ auto& s = *it->second;
+
+ assert(not s.result);
+ s.result = cqe->res;
+ io_uring_cq_advance(&ring, 1);
+
+ if (s.continuation)
+ s.continuation();
+
+ inflight.erase(it);
return true;
}
-awaitable uring::read(int fd, std::span<char> buf, unsigned flags)
+uring::awaitable uring::read(int fd, std::span<char> buf, off_t offset = 0)
{
- auto sqe = get_sqe();
- io_uring_prep_read(sqe, fd, buf.data(), buf.size(), flags);
- return {*this, *sqe, {}};
+ auto [sqe, sub] = get_sqe();
+ io_uring_prep_read(sqe, fd, buf.data(), buf.size(), offset);
+ submit(sub.rp);
+ return sub;
}
-awaitable uring::write(int fd, std::span<char const> buf, unsigned flags)
+
+uring::awaitable uring::write(int fd, std::span<char const> buf, off_t offset = 0)
{
- auto sqe = get_sqe();
- io_uring_prep_write(sqe, fd, buf.data(), buf.size(), flags);
- return {*this, *sqe, {}};
+ auto [sqe, sub] = get_sqe();
+ io_uring_prep_write(sqe, fd, buf.data(), buf.size(), offset);
+ submit(sub.rp);
+ return sub;
}
task read_routine(uring& u)
{
char b[128];
- std::span<char> buf(b);
- int n = co_await u.read(0, buf, 0);
- if (n > 0)
- co_await u.write(1, {b, (std::size_t)n}, 0);
- co_return;
+ for (;;)
+ {
+ std::span<char> buf(b);
+ int n = co_await u.read(0, buf);
+ if (n == 0)
+ co_return;
+ if (n < 0)
+ throw std::system_error(-n, std::system_category());
+
+ buf = buf.subspan(0, n);
+
+ write(1, buf.data(), buf.size());
+ continue;
+
+ while (!buf.empty())
+ {
+ n = /*co_await u.*/write(1, buf.data(), buf.size());
+ if (n <= 0)
+ throw std::system_error(-n, std::system_category());
+
+ buf = buf.subspan(n);
+ }
+ }
+}
+
+task timer_routine(uring& u)
+{
+ auto fd = file_descriptor{timerfd_create(CLOCK_MONOTONIC, 0)};
+ assert(fd);
+
+ itimerspec ts
+ {
+ {2, 0},
+ {2, 0},
+ };
+
+ timerfd_settime(fd, 0, &ts, nullptr);
+
+ for (uint64_t i = 0;;)
+ {
+ uint64_t timeouts;
+ auto s = std::span{reinterpret_cast<char*>(&timeouts), sizeof timeouts};
+ assert(co_await u.read(fd, s) == sizeof timeouts);
+
+ i += timeouts;
+
+ auto x = std::format("Timeout {}\n", i);
+ //std::println("{}", x);
+ int n = co_await u.write(1, std::span<char>{x.data(), x.size()});
+ if (n <= 0)
+ throw std::system_error(-n, std::system_category());
+ }
}
int main()
{
uring r(128, 0);
- read_routine(r);
+ task rr = read_routine(r);
+ task tr = timer_routine(r);
+ //task tr2 = timer_routine(r);
+ //task tr3 = timer_routine(r);
- while (r.process_completion())
- {}
+ while (not rr.done())
+ r.process_completion();
return 0;
}