protohackers

My solutions to the protohackers.com challenges.
git clone git://henryandlizzy.uk/protohackers
Log | Files | Refs

protohack-1.cpp (7349B)


#include "inet.hpp"
#include "fd.hpp"

#include <poll.h>

#include <cerrno>
#include <climits>
#include <cmath>
#include <iostream>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

#include <openssl/bn.h>
#include <openssl/crypto.h>

using namespace std;
     14 
     15 struct bignum
     16 {
     17 	bignum()
     18 	:	p(BN_new(), BN_free)
     19 	{
     20 		BN_init(*this);
     21 	}
     22 
     23 	bignum(BIGNUM* bn)
     24 	:	p(bn, BN_free)
     25 	{}
     26 
     27 	bignum(unsigned long w)
     28 	:	bignum()
     29 	{
     30 		BN_set_word(*this, w);
     31 	}
     32 
     33 	bignum(bignum&&) = default;
     34 	bignum& operator=(bignum&&) = default;
     35 
     36 	bignum(bignum const& old)
     37 	:	bignum{BN_dup(old)}
     38 	{}
     39 	bignum& operator=(bignum const& old)
     40 	{
     41 		if (p)
     42 			BN_copy(*this, old);
     43 		else
     44 			p.reset(BN_dup(old));
     45 
     46 		return *this;
     47 	}
     48 
     49 	char const* to_str() const
     50 	{
     51 		return BN_bn2dec(*this);
     52 	}
     53 
     54 	bool is_prime() const
     55 	{
     56 		return BN_is_prime_fasttest_ex(*this, INT_MAX, nullptr, 1, nullptr);
     57 	}
     58 
     59 	bignum operator +(bignum const& rhs)
     60 	{
     61 		bignum r;
     62 		BN_add(r, *this, rhs);
     63 		return r;
     64 	}
     65 	bignum operator -(bignum const& rhs)
     66 	{
     67 		bignum r;
     68 		BN_sub(r, *this, rhs);
     69 		return r;
     70 	}
     71 
     72 private:
     73 	unique_ptr<BIGNUM, void(*)(BIGNUM*)> p;
     74 
     75 	operator BIGNUM const*() const
     76 	{
     77 		return p.get();
     78 	}
     79 	operator BIGNUM*()
     80 	{
     81 		return p.get();
     82 	}
     83 };
     84 
     85 ostream& operator <<(ostream& out, bignum const& rhs)
     86 {
     87 	return out << rhs.to_str();
     88 }
     89 
     90 struct json_object
     91 {
     92 	std::map<string_view, string_view> nv_pairs;
     93 	std::map<string_view, bignum> num_pairs;
     94 };
     95 
     96 template <typename T>
     97 struct parse_t
     98 {
     99 	T val;
    100 	string_view remaining;
    101 };
    102 
    103 std::optional<string_view> parse_char(string_view in, char c)
    104 {
    105 	if (in.empty())
    106 		return {};
    107 	if (in.front() != c)
    108 		return {};
    109 	return in.substr(1);
    110 }
    111 
    112 std::optional<string_view> parse_str(string_view in, string_view str)
    113 {
    114 	if (not in.starts_with(str))
    115 		return {};
    116 	return in.substr(str.size());
    117 }
    118 
    119 string_view parse_whitespace(string_view in)
    120 {
    121 	auto end = in.find_first_not_of(" \t\r\n");
    122 	return end != string::npos ? in.substr(end) : ""sv;
    123 }
    124 
    125 std::optional<parse_t<bool>> parse_bool(string_view in)
    126 {
    127 	if (auto next = parse_str(in, "true"))
    128 		return parse_t<bool>{true, *next};
    129 	if (auto next = parse_str(in, "false"))
    130 		return parse_t<bool>{false, *next};
    131 	return {};
    132 }
    133 
    134 size_t find_end_quote(string_view str)
    135 {
    136 	size_t start = 0;
    137 	for (;;)
    138 	{
    139 		auto pos = str.find_first_of("\"\\", start);
    140 		if ((pos == string::npos) or (str[pos] == '"'))
    141 			return pos;
    142 		start = pos + 2;
    143 		if (start > str.size())
    144 			return string::npos;
    145 	}
    146 }
    147 
    148 std::optional<parse_t<string_view>> parse_string(string_view in)
    149 {
    150 	if (auto next = parse_char(in, '"'))
    151 		in = *next;
    152 	else
    153 		return {};
    154 
    155 	if (in.empty())
    156 		return {};
    157 
    158 	auto end = find_end_quote(in);
    159 	auto str = in.substr(0, end);
    160 	in = end != string::npos ? in.substr(end) : ""sv;
    161 
    162 	if (auto next = parse_char(in, '"'))
    163 		in = *next;
    164 	else
    165 		return {};
    166 
    167 	return parse_t<string_view>{str, in};
    168 }
    169 
    170 std::optional<parse_t<bignum>> parse_bignum(string_view in)
    171 {
    172 	BIGNUM* p = nullptr;
    173 	auto end = in.find_first_not_of("-0123456789");
    174 	string str{in.substr(0, end)};
    175 	int r = BN_dec2bn(&p, str.c_str());
    176 	if (r == 0)
    177 		return {};
    178 	if (r < 0)
    179 		__builtin_trap();
    180 
    181 	in = end != string::npos ? in.substr(end) : ""sv;
    182 
    183 	if (auto next = parse_char(in, '.'))
    184 	{
    185 		in = *next;
    186 	}
    187 	else
    188 		return parse_t<bignum>{bignum{p}, in};
    189 
    190 	end = in.find_first_not_of("0123456789");
    191 	in = end != string::npos ? in.substr(end) : ""sv;
    192 	return parse_t<bignum>{bignum{0UL}, in};
    193 }
    194 
    195 std::optional<parse_t<json_object>> parse_object(string_view in)
    196 {
    197 	json_object o;
    198 
    199 	if (auto next = parse_char(in, '{'))
    200 		in = *next;
    201 	else
    202 		return {};
    203 
    204 	in = parse_whitespace(in);
    205 
    206 	if (auto next = parse_char(in, '}'))
    207 		return parse_t<json_object>{std::move(o), *next};
    208 
    209 	for (;;)
    210 	{
    211 		auto name = parse_string(in);
    212 		if (not name)
    213 			return {};
    214 
    215 		in = parse_whitespace(name->remaining);
    216 
    217 		if (auto next = parse_char(in, ':'))
    218 			in = parse_whitespace(*next);
    219 		else
    220 			return {};
    221 
    222 		if (auto value = parse_string(in))
    223 		{
    224 			auto& val = o.nv_pairs[name->val];
    225 			val = value->val;
    226 			in = value->remaining;
    227 		}
    228 		else if (auto value = parse_bignum(in))
    229 		{
    230 			auto& val = o.num_pairs[name->val];
    231 			val = value->val;
    232 			in = value->remaining;
    233 		}
    234 		else if (auto value = parse_bool(in))
    235 		{
    236 			//auto& val = o.bool_pairs[name->val];
    237 			//val = value->val;
    238 			in = value->remaining;
    239 		}
    240 		else if (auto value = parse_object(in))
    241 		{
    242 			//auto& val = o.obj_pairs[name->val];
    243 			//val = value->val;
    244 			in = value->remaining;
    245 		}
    246 		else
    247 			return {};
    248 
    249 		if (auto next = parse_char(in, ','))
    250 			in = parse_whitespace(*next);
    251 		else
    252 			break;
    253 	}
    254 	if (auto next = parse_char(in, '}'))
    255 		return parse_t<json_object>{std::move(o), *next};
    256 	else
    257 		return {};
    258 }
    259 
    260 ostream& operator <<(ostream& out, json_object const& rhs)
    261 {
    262 	out << '{';
    263 	for (auto const& nv : rhs.nv_pairs)
    264 		out << '"' << nv.first << "\": \"" << nv.second << "\",";
    265 	for (auto const& nv : rhs.num_pairs)
    266 		out << '"' << nv.first << "\": " << nv.second << ',';
    267 	return out << '}';
    268 }
    269 
    270 ostream& operator <<(ostream& out, std::optional<parse_t<json_object>> const& rhs)
    271 {
    272 	if (not rhs)
    273 		return out << "<failed parse>";
    274 	return out << rhs->val << " <" << rhs->remaining << ">";
    275 }
    276 
    277 struct client
    278 {
    279 	descriptor d;
    280 	std::string buf;
    281 
    282 	bool do_read()
    283 	{
    284 		string data = read(d);
    285 		if (data.empty())
    286 			return false;
    287 
    288 		buf.append(std::move(data));
    289 		return true;
    290 	}
    291 	optional<string> getline()
    292 	{
    293 		if (buf.empty())
    294 			return {};
    295 		auto n = buf.find('\n');
    296 		if (n == string::npos)
    297 			return {};
    298 		string line = buf.substr(0, n);
    299 		buf.erase(0, n+1);
    300 		return line;
    301 	}
    302 };
    303 
    304 std::optional<bignum> parse_json_request(string_view line)
    305 {
    306 	auto obj = parse_object(line);
    307 	if (not obj)
    308 	{
    309 		cout << "request is invalid JSON\n";
    310 		return {};
    311 	}
    312 	//std::cout << obj << std::flush;
    313 
    314 	auto srch = obj->val.nv_pairs.find("method");
    315 	if ((srch == obj->val.nv_pairs.end()) or (srch->second != "isPrime"))
    316 	{
    317 		cout << "'method' != isPrime or is missing\n";
    318 		return {};
    319 	}
    320 	auto srch2 = obj->val.num_pairs.find("number");
    321 	if (srch2 == obj->val.num_pairs.end())
    322 	{
    323 		cout << "'number' is not number type or is missing\n";
    324 		return {};
    325 	}
    326 	return srch2->second;
    327 }
    328 
    329 bool client_ready(client& c)
    330 {
    331 	if (not c.do_read())
    332 	{
    333 		cout << "* " << c.d << " disconnected\n";
    334 		return false;
    335 	}
    336 
    337 	while (auto line = c.getline())
    338 	{
    339 		if (auto number = parse_json_request(*line))
    340 		{
    341 			bool prime = number->is_prime();
    342 			std::cout << *number << ": " << prime << std::endl;
    343 
    344 			if (prime)
    345 				write(c.d, "{\"method\":\"isPrime\",\"prime\":true}\n"sv);
    346 			else
    347 				write(c.d, "{\"method\":\"isPrime\",\"prime\":false}\n"sv);
    348 		}
    349 		else
    350 		{
    351 			write(c.d, "xxx\n"sv);
    352 			return false;
    353 		}
    354 	}
    355 	return true;
    356 }
    357 
    358 int main()
    359 {
    360 	std::map<int, client> clients;
    361 	auto incoming = inet::listen(SOCK_STREAM);
    362 
    363 	std::cout << std::boolalpha << "Awaiting clients\n";
    364 
    365 	for (;;)
    366 	{
    367 		std::vector<pollfd> awaitables;
    368 		awaitables.push_back({incoming, POLLIN, 0});
    369 		for (auto const& c : clients)
    370 			awaitables.push_back({c.second.d, POLLIN, 0});
    371 
    372 		::poll(awaitables.data(), awaitables.size(), INFTIM);
    373 
    374 		for (auto& ev : awaitables)
    375 		{
    376 			if (~ev.revents & POLLIN)
    377 				continue;
    378 
    379 			if (ev.fd == incoming)
    380 			{
    381 				auto conn = inet::accept(incoming);
    382 
    383 				auto& x = clients[conn];
    384 				x.d = move(conn);
    385 				continue;
    386 			}
    387 
    388 			auto it = clients.find(ev.fd);
    389 			if (it == clients.end())
    390 			{
    391 				cerr << "Unknown fd " << ev.fd << endl;
    392 				return 1;
    393 			}
    394 
    395 			if (not client_ready(it->second))
    396 				clients.erase(it);
    397 		}
    398 	}
    399 }