console_server.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. /*
  2. * Copyright (c) 2012-2023 Daniele Bartolini et al.
  3. * SPDX-License-Identifier: MIT
  4. */
  5. #include "core/containers/array.inl"
  6. #include "core/containers/hash_map.inl"
  7. #include "core/containers/vector.inl"
  8. #include "core/filesystem/file_buffer.inl"
  9. #include "core/filesystem/reader_writer.inl"
  10. #include "core/json/json_object.inl"
  11. #include "core/json/sjson.h"
  12. #include "core/memory/temp_allocator.inl"
  13. #include "core/network/ip_address.h"
  14. #include "core/strings/dynamic_string.inl"
  15. #include "core/strings/string_id.inl"
  16. #include "core/strings/string_stream.inl"
  17. #include "core/thread/scoped_mutex.inl"
  18. #include "device/console_server.h"
  19. #include "device/log.h"
  20. LOG_SYSTEM(CONSOLE_SERVER, "console_server")
  21. namespace crown
  22. {
  23. namespace console_server_internal
  24. {
  25. static void message_command(ConsoleServer &cs, u32 client_id, const char *json, void *user_data)
  26. {
  27. TempAllocator4096 ta;
  28. JsonObject obj(ta);
  29. JsonArray args(ta);
  30. sjson::parse(obj, json);
  31. sjson::parse_array(args, obj["args"]);
  32. DynamicString command_name(ta);
  33. sjson::parse_string(command_name, args[0]);
  34. ConsoleServer::CommandData cmd;
  35. cmd.command_function = NULL;
  36. cmd.user_data = NULL;
  37. cmd = hash_map::get(cs._commands, command_name.to_string_id(), cmd);
  38. if (cmd.command_function == NULL) {
  39. ((ConsoleServer *)user_data)->error(client_id, "Command not found");
  40. return;
  41. }
  42. cmd.command_function(cs, client_id, args, cmd.user_data);
  43. }
  44. static void command_help(ConsoleServer &cs, u32 client_id, const JsonArray &args, void * /*user_data*/)
  45. {
  46. if (array::size(args) != 1) {
  47. cs.error(client_id, "Usage: help");
  48. return;
  49. }
  50. u32 longest = 0;
  51. auto cur = hash_map::begin(cs._commands);
  52. auto end = hash_map::end(cs._commands);
  53. for (; cur != end; ++cur) {
  54. HASH_MAP_SKIP_HOLE(cs._commands, cur);
  55. if (longest < strlen32(cur->second.name))
  56. longest = strlen32(cur->second.name);
  57. }
  58. cur = hash_map::begin(cs._commands);
  59. end = hash_map::end(cs._commands);
  60. for (; cur != end; ++cur) {
  61. HASH_MAP_SKIP_HOLE(cs._commands, cur);
  62. logi(CONSOLE_SERVER, "%s%*s%s"
  63. , cur->second.name
  64. , longest - strlen32(cur->second.name) + 2
  65. , " "
  66. , cur->second.brief
  67. );
  68. }
  69. }
  70. static u32 add_client(ConsoleServer &cs, const TCPSocket &socket)
  71. {
  72. ScopedMutex scoped_mutex(cs._clients_mutex);
  73. ConsoleServer::Client client;
  74. client.socket = socket;
  75. client.id = cs._next_client_id++;
  76. vector::push_back(cs._clients, client);
  77. return client.id;
  78. }
  79. static void remove_client_by_socket(ConsoleServer &cs, const TCPSocket &socket)
  80. {
  81. ScopedMutex scoped_mutex(cs._clients_mutex);
  82. const u32 last = vector::size(cs._clients) - 1;
  83. for (u32 cc = 0; cc < vector::size(cs._clients); ++cc) {
  84. if (cs._clients[cc].socket == socket) {
  85. cs._clients[cc] = cs._clients[last];
  86. vector::pop_back(cs._clients);
  87. return;
  88. }
  89. }
  90. }
  91. static u32 get_client_id(ConsoleServer &cs, const TCPSocket &socket)
  92. {
  93. ScopedMutex scoped_mutex(cs._clients_mutex);
  94. const u32 num_clients = vector::size(cs._clients);
  95. for (u32 cc = 0; cc < num_clients; ++cc) {
  96. if (cs._clients[cc].socket == socket)
  97. return cs._clients[cc].id;
  98. }
  99. return UINT32_MAX;
  100. }
  101. static bool get_socket_by_id(TCPSocket *socket, ConsoleServer &cs, u32 id)
  102. {
  103. ScopedMutex scoped_mutex(cs._clients_mutex);
  104. const u32 num_clients = vector::size(cs._clients);
  105. for (u32 cc = 0; cc < num_clients; ++cc) {
  106. if (cs._clients[cc].id == id) {
  107. *socket = cs._clients[cc].socket;
  108. return true;
  109. }
  110. }
  111. return false;
  112. }
  113. } // namespace console_server_internal
  114. ConsoleServer::ConsoleServer(Allocator &a)
  115. : _port(UINT16_MAX)
  116. , _next_client_id(0)
  117. , _clients(a)
  118. , _messages(a)
  119. , _commands(a)
  120. , _thread_exit(false)
  121. , _input_0(a)
  122. , _input_1(a)
  123. , _input_write(&_input_0)
  124. , _input_read(&_input_1)
  125. , _output_0(a)
  126. , _output_1(a)
  127. , _output_write(&_output_0)
  128. , _output_read(&_output_1)
  129. {
  130. this->register_message_type("command", console_server_internal::message_command, this);
  131. this->register_command_name("help", "List all commands", console_server_internal::command_help, this);
  132. }
  133. void ConsoleServer::listen(u16 port, bool wait)
  134. {
  135. const BindResult br = _server.bind(port);
  136. if (br.error != BindResult::SUCCESS)
  137. return;
  138. _port = port;
  139. _server.listen(5);
  140. _active_socket_set.set(&_server);
  141. _input_thread.start([](void *thiz) { return ((ConsoleServer *)thiz)->run_input_thread(); }, this);
  142. _output_thread.start([](void *thiz) { return ((ConsoleServer *)thiz)->run_output_thread(); }, this);
  143. // Connect a dummy client to the _server to
  144. // unlock the input_thread later at exit.
  145. _dummy_client.connect(IP_ADDRESS_LOOPBACK, _port);
  146. _client_connected.wait();
  147. // Wait for real clients to connect.
  148. if (wait)
  149. _client_connected.wait();
  150. }
  151. void ConsoleServer::close()
  152. {
  153. _thread_exit = true;
  154. // Unlock input thread if it is stuck inside the select().
  155. u32 blank_header = 0;
  156. if (_dummy_client.is_open())
  157. _dummy_client.write(&blank_header, sizeof(blank_header));
  158. }
  159. void ConsoleServer::shutdown()
  160. {
  161. close();
  162. _dummy_client.close();
  163. _handlers_semaphore.post();
  164. if (_input_thread.is_running())
  165. _input_thread.stop();
  166. _output_condition.signal();
  167. if (_output_thread.is_running())
  168. _output_thread.stop();
  169. ScopedMutex scoped_mutex(_clients_mutex);
  170. for (u32 i = 0; i < vector::size(_clients); ++i)
  171. _clients[i].socket.close();
  172. _server.close();
  173. }
  174. void ConsoleServer::send(u32 client_id, const char *json)
  175. {
  176. TCPSocket socket;
  177. if (!console_server_internal::get_socket_by_id(&socket, *this, client_id))
  178. return;
  179. const u32 msg_len = strlen32(json);
  180. _output_mutex.lock();
  181. FileBuffer fb(*_output_write);
  182. fb.seek_to_end();
  183. BinaryWriter bw(fb);
  184. bw.write(client_id);
  185. bw.write(msg_len);
  186. bw.write(json, msg_len);
  187. _output_condition.signal();
  188. _output_mutex.unlock();
  189. }
  190. void ConsoleServer::error(u32 client_id, const char *msg)
  191. {
  192. TempAllocator4096 ta;
  193. StringStream ss(ta);
  194. ss << "{\"type\":\"error\",\"message\":\"" << msg << "\"}";
  195. send(client_id, string_stream::c_str(ss));
  196. }
  197. void ConsoleServer::broadcast(const char *json)
  198. {
  199. for (u32 i = 0; i < vector::size(_clients); ++i)
  200. send(_clients[i].id, json);
  201. }
  202. void ConsoleServer::execute_message_handlers(bool sync)
  203. {
  204. bool locked = true;
  205. if (sync)
  206. _input_semaphore.wait();
  207. else
  208. locked = _input_semaphore.try_wait();
  209. if (!locked)
  210. return;
  211. Buffer *temp = _input_read;
  212. _input_read = _input_write;
  213. _input_write = temp;
  214. _handlers_semaphore.post();
  215. // Do not execute message handlers at exit, because when _thread_exit is
  216. // set by shutdown(), handlers may reference stale objects.
  217. if (_thread_exit)
  218. return;
  219. FileBuffer fb(*_input_read);
  220. BinaryReader br(fb);
  221. while (!fb.end_of_file()) {
  222. // Read client, message size and message.
  223. u32 client_id;
  224. u32 msg_len;
  225. br.read(client_id);
  226. br.read(msg_len);
  227. const char *msg = array::begin(*_input_read) + fb.position();
  228. br.skip(msg_len);
  229. if (msg_len > 0) {
  230. // Process the message if any.
  231. JsonObject obj(default_allocator());
  232. sjson::parse(obj, msg);
  233. if (!json_object::has(obj, "type")) {
  234. error(client_id, "Missing command type");
  235. continue;
  236. }
  237. // Find handler for the message type.
  238. CommandData cmd;
  239. cmd.message_function = NULL;
  240. cmd.user_data = NULL;
  241. cmd = hash_map::get(_messages
  242. , sjson::parse_string_id(obj["type"])
  243. , cmd
  244. );
  245. if (!cmd.message_function) {
  246. error(client_id, "Unknown command type");
  247. continue;
  248. }
  249. // Call the handler.
  250. cmd.message_function(*this, client_id, msg, cmd.user_data);
  251. }
  252. }
  253. array::clear(*_input_read);
  254. }
  255. void ConsoleServer::register_command_name(const char *name, const char *brief, CommandTypeFunction function, void *user_data)
  256. {
  257. CE_ENSURE(NULL != name);
  258. CE_ENSURE(NULL != brief);
  259. CE_ENSURE(NULL != function);
  260. CommandData cmd;
  261. cmd.command_function = function;
  262. cmd.user_data = user_data;
  263. strncpy(cmd.name, name, sizeof(cmd.name) - 1);
  264. strncpy(cmd.brief, brief, sizeof(cmd.brief) - 1);
  265. hash_map::set(_commands, StringId32(name), cmd);
  266. }
  267. void ConsoleServer::register_message_type(const char *type, MessageTypeFunction function, void *user_data)
  268. {
  269. CE_ENSURE(NULL != type);
  270. CE_ENSURE(NULL != function);
  271. CommandData cmd;
  272. cmd.message_function = function;
  273. cmd.user_data = user_data;
  274. hash_map::set(_messages, StringId32(type), cmd);
  275. }
  276. s32 ConsoleServer::run_input_thread()
  277. {
  278. while (!_thread_exit) {
  279. // Wait for input from one of the sockets in _active_socket_set.
  280. _read_socket_set = _active_socket_set;
  281. SelectResult ret = _read_socket_set.select(UINT32_MAX);
  282. if (ret.error == SelectResult::GENERIC_ERROR) {
  283. return -1;
  284. } else if (ret.error == SelectResult::TIMEOUT) {
  285. continue;
  286. }
  287. FileBuffer fb(*_input_write);
  288. BinaryWriter bw(fb);
  289. // Read data from all clients that are ready.
  290. const u32 num_sockets = _read_socket_set.num();
  291. for (u32 ii = 0; ii < num_sockets; ++ii) {
  292. TCPSocket cur_socket = _read_socket_set.get(ii);
  293. // Skip if socket is not ready for reading.
  294. if (_read_socket_set.isset(&cur_socket) == false)
  295. continue;
  296. // If ready socket is the one listening for incoming connections.
  297. if (cur_socket == _server) {
  298. if (_thread_exit)
  299. break;
  300. // Accept the incoming connection.
  301. TCPSocket client;
  302. AcceptResult ar = _server.accept_nonblock(client);
  303. if (ar.error == AcceptResult::SUCCESS) {
  304. console_server_internal::add_client(*this, client);
  305. _active_socket_set.set(&client);
  306. _client_connected.post();
  307. }
  308. } else { // Check if any other socket is ready for reading.
  309. u32 msg_len = 0;
  310. ReadResult rr = cur_socket.read(&msg_len, 4);
  311. if (rr.error != ReadResult::SUCCESS) {
  312. console_server_internal::remove_client_by_socket(*this, cur_socket);
  313. _active_socket_set.clr(&cur_socket);
  314. cur_socket.close();
  315. continue;
  316. }
  317. const u32 client_id = console_server_internal::get_client_id(*this, cur_socket);
  318. // Add client header and message length.
  319. bw.write(client_id);
  320. bw.write(msg_len);
  321. // Read message.
  322. u32 num_read;
  323. for (num_read = 0; num_read < msg_len;) {
  324. char buf[4096];
  325. const u32 num_pending = min(u32(sizeof(buf)), msg_len - num_read);
  326. rr = cur_socket.read(buf, num_pending);
  327. if (rr.error != ReadResult::SUCCESS) {
  328. console_server_internal::remove_client_by_socket(*this, cur_socket);
  329. _active_socket_set.clr(&cur_socket);
  330. cur_socket.close();
  331. break;
  332. }
  333. bw.write(buf, rr.bytes_read);
  334. num_read += rr.bytes_read;
  335. }
  336. if (num_read != msg_len) {
  337. // Remove partial data that has been written to the input buffer.
  338. for (u32 cc = 0; cc < 4 + 4 + num_read; ++cc)
  339. array::pop_back(*_input_write);
  340. }
  341. }
  342. }
  343. if (array::size(*_input_write) > 0) {
  344. _input_semaphore.post();
  345. if (!_thread_exit)
  346. _handlers_semaphore.wait();
  347. }
  348. }
  349. return 0;
  350. }
  351. s32 ConsoleServer::run_output_thread()
  352. {
  353. while (1) {
  354. _output_mutex.lock();
  355. while (array::size(*_output_write) == 0 && !_thread_exit)
  356. _output_condition.wait(_output_mutex);
  357. if (_thread_exit) {
  358. _output_mutex.unlock();
  359. break;
  360. }
  361. Buffer *temp = _output_read;
  362. _output_read = _output_write;
  363. _output_write = temp;
  364. _output_mutex.unlock();
  365. FileBuffer fb(*_output_read);
  366. BinaryReader br(fb);
  367. while (!fb.end_of_file()) {
  368. // Read client, message size and message.
  369. u32 client_id;
  370. u32 msg_len;
  371. br.read(client_id);
  372. br.read(msg_len);
  373. const char *msg = array::begin(*_output_read) + fb.position();
  374. br.skip(msg_len);
  375. // Lookup socket by its ID.
  376. TCPSocket socket;
  377. if (console_server_internal::get_socket_by_id(&socket, *this, client_id) != true)
  378. continue;
  379. socket.write(msg - 4, msg_len + 4);
  380. }
  381. array::clear(*_output_read);
  382. }
  383. return 0;
  384. }
  385. namespace console_server_globals
  386. {
  387. ConsoleServer *_console_server = NULL;
  388. void init()
  389. {
  390. _console_server = CE_NEW(default_allocator(), ConsoleServer)(default_allocator());
  391. }
  392. void shutdown()
  393. {
  394. _console_server->shutdown();
  395. CE_DELETE(default_allocator(), _console_server);
  396. _console_server = NULL;
  397. }
  398. } // namespace console_server_globals
  399. ConsoleServer *console_server()
  400. {
  401. return console_server_globals::_console_server;
  402. }
  403. } // namespace crown