io_uring.cpp (5812B)
#include "coroutine_owner.hpp"
#include "fd.hpp"

#include <liburing.h>
#include <sys/timerfd.h>
#include <unistd.h>

#include <cassert>
#include <cerrno>
#include <coroutine>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <exception>
#include <format>
#include <map>
#include <memory>
#include <optional>
#include <print>
#include <span>
#include <system_error>
#include <utility>

/// Eagerly-started coroutine return type.
///
/// The coroutine starts running as soon as it is called (initial_suspend is
/// suspend_never) and parks at its final suspend point (final_suspend is
/// suspend_always), so done() can be queried afterwards. Any exception that
/// escapes the coroutine body calls std::terminate().
struct task
{
	/// True once the coroutine has reached its final suspend point.
	bool done() const noexcept
	{
		return coro->done();
	}

	struct promise_type
	{
		task get_return_object(void) noexcept
		{
			return task{make_coroutine_owner(*this)};
		}
		auto initial_suspend(void) const noexcept
		{
			return std::suspend_never{};
		}
		auto final_suspend(void) const noexcept
		{
			return std::suspend_always{};
		}
		void unhandled_exception(void) const noexcept
		{
			std::terminate();
		}
		void return_void(void) const noexcept
		{}
	};

	// NOTE(review): assumed to destroy the coroutine frame when the task is
	// destroyed -- confirm in coroutine_owner.hpp.
	coroutine_owner<promise_type> coro;
};

/// Single-threaded io_uring wrapper whose read/write requests can be co_awaited.
///
/// Every submitted request is tagged (SQE user_data) with a fresh, increasing
/// id. `inflight` maps that id to a shared rendezvous which receives the CQE
/// result and, if a coroutine is suspended on it, the handle to resume.
/// Completions are only reaped when the owner calls process_completion() or
/// wait_and_complete().
struct uring
{
	struct awaitable;
private:
	struct rendezvous;
public:
	/// Offset value meaning "use (and advance) the file's current position"
	/// for read()/write(), per io_uring_prep_read(3) / io_uring_prep_write(3).
	static constexpr off_t current_position = -1;

	/// Sets up a ring with `entries` SQ slots and io_uring setup `flags`.
	/// @throws std::system_error if io_uring_queue_init fails.
	uring(unsigned entries, unsigned flags)
	{
		// liburing reports failure as a negative errno return value; it does
		// not set errno.
		if (int rc = io_uring_queue_init(entries, &ring, flags); rc < 0)
			throw std::system_error(-rc, std::system_category());
	}

	/// Tears the ring down; every request must have been reaped or cancelled.
	~uring(void)
	{
		assert(inflight.empty());
		io_uring_queue_exit(&ring);
	}

	uring(const uring&) = delete;
	uring& operator =(const uring&) = delete;

	/// Blocks for one CQE, stores its result and resumes the waiting coroutine
	/// (if any). Returns the id of the request that completed.
	/// @throws std::system_error if waiting for a CQE fails.
	uint64_t wait_and_complete(void);

	/// Reaps one completion if anything is in flight; false when idle.
	bool process_completion(void);

	/// Submits a read into `buf` from `fd` at `offset`
	/// (use current_position for stream-like / positioned reads).
	awaitable read(int fd, std::span<char> buf, off_t offset = 0);

	/// Submits a write of `buf` to `fd` at `offset`
	/// (use current_position for stream-like / positioned writes).
	awaitable write(int fd, std::span<char const> buf, off_t offset = 0);

private:
	/// Grabs a free SQE, tags it with a fresh id and pairs it with a new
	/// awaitable for that id. The caller must prep the SQE, then submit().
	std::pair<io_uring_sqe*, awaitable> get_sqe(void)
	{
		auto sqe = io_uring_get_sqe(&ring);
		// Each request is submitted right away, so the SQ should never be
		// full; do not rely on assert() alone, as NDEBUG would let a null
		// SQE through.
		if (!sqe)
			throw std::system_error(EBUSY, std::system_category());
		io_uring_sqe_set_data64(sqe, ++last_id);
		return {sqe, {*this , std::make_shared<rendezvous>(), last_id}};
	}

	/// Submits the SQE prepared after the last get_sqe() and records its
	/// rendezvous under last_id so wait_and_complete() can find it.
	void submit(std::shared_ptr<rendezvous> r)
	{
		int n = io_uring_submit(&ring);
		if (n < 0)
			throw std::system_error(-n, std::system_category());
		assert(n == 1);
		assert(not inflight.contains(last_id));
		inflight[last_id] = std::move(r);
	}

	/// State shared between an awaitable and the in-flight table.
	struct rendezvous
	{
		// CQE result; empty until the completion has been reaped.
		std::optional<int> result = std::nullopt;
		// Coroutine to resume on completion; null if nobody is waiting.
		std::coroutine_handle<> continuation = {};
	};

public:
	/// Handle to one submitted request; co_await it to get the CQE result.
	///
	/// Destroying an awaitable whose result has not been delivered yet (e.g.
	/// because the coroutine awaiting it was destroyed) detaches it from the
	/// request and cancels the request synchronously where possible.
	struct awaitable
	{
		awaitable(uring& r, std::shared_ptr<rendezvous> p, uint64_t i) noexcept
			: ring{r}
			, rp{std::move(p)}
			, id{i}
		{}
		~awaitable() noexcept
		{
			// Moved-from, or the result has already been delivered.
			if (!rp || await_ready())
				return;

			// Never resume a coroutine that may no longer exist.
			rp->continuation = {};

			// Cancel the request whose user_data equals our id; no timeout.
			io_uring_sync_cancel_reg reg = {
				.addr = id,
				.fd = 0,
				.flags = 0,
				.timeout = {-1, -1},
				.opcode = 0,
				.pad = {},
				.pad2 = {},
			};
			int result = io_uring_register_sync_cancel(&ring.ring, &reg);
			switch (result)
			{
			case 0: // Job was cancelled. Its -ECANCELED CQE is still posted; wait_and_complete() discards CQEs for untracked ids.
			case -ENOENT: // No job was found: already completed (CQE not yet reaped) or never submitted. Stop tracking it either way.
				ring.inflight.erase(id);
				break;
			case -EINVAL: // bad args
				assert(false);
				break;
			case -EALREADY: // Already executing and cannot be cancelled: block until its CQE is reaped (this may resume other coroutines).
				while (ring.wait_and_complete() != id)
				{}
				break;
			default:
				std::println("cancel: {}", result);
				assert(false);
			}
		}
		awaitable(awaitable const&) = delete;
		awaitable& operator =(awaitable const&) = delete;
		awaitable(awaitable&& old) noexcept
			: ring{old.ring}
			, rp{std::move(old.rp)}
			, id{old.id}
		{
			//std::println("Awaitable {} moved", id);
		}
		bool await_ready(void) const noexcept
		{
			return rp->result.has_value();
		}
		void await_suspend(std::coroutine_handle<> h) const noexcept
		{
			rp->continuation = h;
		}
		/// Returns the CQE `res` field (negative values are -errno).
		int await_resume() const
		{
			//std::println("Job {} resumed", id);
			return rp->result.value();
		}
	// Kept public: uring::read/write read `rp`, and an enclosing class has
	// no access to a nested class's private members.
		uring& ring;
		std::shared_ptr<rendezvous> rp;
		uint64_t id; // user_data tag of the request
	};

private:
	io_uring ring;
	std::map<uint64_t, std::shared_ptr<rendezvous>> inflight;
	uint64_t last_id = 0;
};


uint64_t uring::wait_and_complete(void)
{
	assert(not inflight.empty());

	io_uring_cqe* cqe = nullptr;
	int rc;
	// liburing returns -errno instead of setting errno; retry when a signal
	// interrupts the wait.
	do
		rc = io_uring_wait_cqe(&ring, &cqe);
	while (rc == -EINTR);
	if (rc < 0)
		throw std::system_error(-rc, std::system_category());

	uint64_t id = io_uring_cqe_get_data64(cqe);
	int res = cqe->res;
	// Everything needed is copied out; release the CQE slot before running
	// any coroutine code.
	io_uring_cq_advance(&ring, 1);

	auto it = inflight.find(id);
	if (it == inflight.end())
		// Completion of a request whose awaitable was destroyed and stopped
		// tracking it (e.g. a cancelled request's -ECANCELED CQE).
		return id;

	auto s = std::move(it->second);
	inflight.erase(it);

	assert(not s->result);
	s->result = res;

	if (s->continuation)
		s->continuation();

	return id;
}

bool uring::process_completion(void)
{
	if (inflight.empty())
		return false;

	wait_and_complete();
	return true;
}

uring::awaitable uring::read(int fd, std::span<char> buf, off_t offset)
{
	auto [sqe, sub] = get_sqe();
	io_uring_prep_read(sqe, fd, buf.data(), buf.size(), offset);
	submit(sub.rp);
	// Structured bindings are not implicitly moved on return, and awaitable
	// is move-only.
	return std::move(sub);
}

uring::awaitable uring::write(int fd, std::span<char const> buf, off_t offset)
{
	auto [sqe, sub] = get_sqe();
	io_uring_prep_write(sqe, fd, buf.data(), buf.size(), offset);
	submit(sub.rp);
	// See uring::read: explicit move required for a structured binding.
	return std::move(sub);
}

/// Copies stdin to stdout until end of file. Reads go through the ring;
/// writes are synchronous.
task read_routine(uring& u)
{
	char b[128];
	for (;;)
	{
		// current_position: consume stdin from its file position, so a
		// redirected regular file is not re-read from offset 0 forever.
		int n = co_await u.read(0, b, uring::current_position);
		if (n == 0)
			co_return;
		if (n < 0)
			throw std::system_error(-n, std::system_category());

		// write(2) may accept fewer bytes than asked; loop until the chunk is out.
		for (std::span<char const> rest{b, static_cast<std::size_t>(n)}; !rest.empty(); )
		{
			ssize_t w = ::write(1, rest.data(), rest.size());
			if (w < 0 && errno == EINTR)
				continue;
			if (w <= 0)
				throw std::system_error(w < 0 ? errno : EIO, std::system_category());

			rest = rest.subspan(static_cast<std::size_t>(w));
		}
	}
}

/// Every 2 seconds, prints the running count of timer expirations to stdout.
/// Runs until the coroutine is destroyed.
task timer_routine(uring& u)
{
	int tfd = timerfd_create(CLOCK_MONOTONIC, 0);
	if (tfd < 0)
		throw std::system_error(errno, std::system_category());
	auto fd = file_descriptor{tfd};

	// {it_interval, it_value}: first expiry after 2 s, then every 2 s.
	itimerspec ts
	{
		{2, 0},
		{2, 0},
	};

	if (timerfd_settime(fd, 0, &ts, nullptr) < 0)
		throw std::system_error(errno, std::system_category());

	for (uint64_t i = 0;;)
	{
		uint64_t timeouts = 0;
		auto s = std::span{reinterpret_cast<char*>(&timeouts), sizeof timeouts};
		// Keep the co_await out of assert(): under NDEBUG it would never run.
		int n = co_await u.read(fd, s);
		if (n < 0)
			throw std::system_error(-n, std::system_category());
		if (static_cast<std::size_t>(n) != sizeof timeouts)
			throw std::system_error(EIO, std::system_category());

		i += timeouts;

		auto x = std::format("Timeout {}\n", i);
		int w = co_await u.write(1, std::span<char const>{x.data(), x.size()}, uring::current_position);
		if (w <= 0)
			throw std::system_error(w < 0 ? -w : EIO, std::system_category());
	}
}

/// Runs the stdin echo and the timer on one ring until stdin reaches EOF.
int main()
{
	uring r(128, 0);

	task rr = read_routine(r);
	task tr = timer_routine(r);

	// Stop if nothing is in flight: rr could then never be resumed, and
	// looping would spin forever.
	while (not rr.done())
		if (not r.process_completion())
			break;

	// NOTE(review): tr is destroyed before r (reverse declaration order); if
	// coroutine_owner destroys the frame, the pending timer read's awaitable
	// cancels it while the ring still exists -- confirm.
	return 0;
}