socket.h 15 KB


  1. /*
  2. * This is free and unencumbered software released into the public domain.
  3. *
  4. * Anyone is free to copy, modify, publish, use, compile, sell, or
  5. * distribute this software, either in source code form or as a compiled
  6. * binary, for any purpose, commercial or non-commercial, and by any
  7. * means.
  8. *
  9. * In jurisdictions that recognize copyright laws, the author or authors
  10. * of this software dedicate any and all copyright interest in the
  11. * software to the public domain. We make this dedication for the benefit
  12. * of the public at large and to the detriment of our heirs and
  13. * successors. We intend this dedication to be an overt act of
  14. * relinquishment in perpetuity of all present and future rights to this
  15. * software under copyright law.
  16. *
  17. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  18. * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  19. * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
  20. * IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
  21. * OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
  22. * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
  23. * OTHER DEALINGS IN THE SOFTWARE.
  24. *
  25. * For more information, please refer to <http://unlicense.org>
  26. */
  27. #ifndef GUL_SOCKET__H
  28. #define GUL_SOCKET__H
  29. #ifdef _MSC_VER
  30. #define _WINSOCKAPI_
  31. #define _WINSOCK_DEPRECATED_NO_WARNINGS
  32. #endif
  #include <ctype.h>
  #include <errno.h>
  #include <fcntl.h>
  #include <cstdint>
  #include <cstdio>
  #include <cstdlib>
  #include <cstring>
  #include <chrono>
  #include <string>
  #if defined _MSC_VER
  #include <winsock2.h>
  #pragma comment(lib,"wsock32.lib")
  #else
  #include <sys/ioctl.h>
  #include <unistd.h>
  #include <arpa/inet.h>
  #include <sys/select.h>
  #include <sys/socket.h>
  #include <sys/un.h>
  #endif
  52. namespace gul
  53. {
  54. ////////////////////////////////////////////////////////////////////////////////
  55. /// New Implementations here!
  56. ////////////////////////////////////////////////////////////////////////////////
  /**
   * @brief The socket_address class
   *
   * Thin wrapper around an IPv4 `sockaddr_in`. It is used by the udp_socket
   * to indicate which client has sent a packet and to indicate which client
   * to send a message to.
   *
   * A default-constructed address is all zeros (port 0) and converts to
   * false; an address with a non-zero port converts to true.
   */
  class socket_address
  {
  protected:
      using address_t = ::sockaddr_in;
  public:
      /// Creates an all-zero address (family unset, port 0).
      socket_address()
      {
          // m_address is zero-initialised by its default member initialiser.
      }
      /**
       * @brief Wildcard address (INADDR_ANY, all local interfaces) on a port.
       * @param _port port number in host byte order
       */
      socket_address(uint16_t _port)
      {
          m_address.sin_family = AF_INET;
          m_address.sin_port = htons(_port);
          m_address.sin_addr.s_addr = htonl(INADDR_ANY);
      }
      /**
       * @brief Address built from a dotted-quad IPv4 string and a port.
       * @param ip_address e.g. "127.0.0.1"; must not be null
       * @param _port port number in host byte order
       *
       * Malformed strings are not rejected: inet_addr() then yields
       * INADDR_NONE (255.255.255.255).
       */
      socket_address(char const * ip_address, uint16_t _port)
      {
          m_address.sin_family = AF_INET;
          m_address.sin_port = htons(_port);
          m_address.sin_addr.s_addr = inet_addr(ip_address);
      }
      /**
       * @brief True when the address carries a non-zero port, i.e. it has
       * been filled in; false for a default-constructed address.
       */
      operator bool() const
      {
          return m_address.sin_port != 0;
      }
      /**
       * @brief native_address
       * @return the underlying `sockaddr_in` (read-only).
       */
      address_t const & native_address() const
      {
          return m_address;
      }
      /**
       * @brief native_address
       * @return the underlying `sockaddr_in`, writable so that native calls
       *         such as recvfrom()/accept() can fill it in.
       */
      address_t & native_address()
      {
          return m_address;
      }
      /**
       * @brief ip
       * @return the IPv4 address as a dotted-quad string.
       *
       * The text comes from inet_ntoa(), which returns a static buffer that
       * the next call overwrites; copy it if it must outlive that call.
       */
      char const * ip() const
      {
          return inet_ntoa(m_address.sin_addr);
      }
      /**
       * @brief port
       * @return the port number in host byte order.
       */
      uint16_t port() const
      {
          return ntohs(m_address.sin_port);
      }
      /// Returns a copy of the underlying `sockaddr_in`.
      operator address_t() const
      {
          // The class must stay layout-compatible with the native struct.
          static_assert( sizeof(address_t) == sizeof(socket_address), "struct sizes are not the same");
          return m_address;
      }
  protected:
      address_t m_address{};
  };
  /// Address family selected by gul::socket::create().
  enum class socket_domain : int
  {
      NET  = 0, ///< IPv4 network socket (AF_INET)
      UNIX = 1  ///< local domain socket (AF_UNIX)
  };
  /// Communication style selected by gul::socket::create().
  enum class socket_type : int
  {
      STREAM = 0, ///< connection-oriented byte stream (SOCK_STREAM)
      DGRAM  = 1  ///< connectionless datagrams (SOCK_DGRAM)
  };
  155. class socket
  156. {
  157. protected:
  158. #if defined _WIN32
  159. using native_msg_size_input_t = int;// the input message length type for send/recv/sendto/recvfrom
  160. using native_msg_size_return_t = int;// the return type for send/recv/sendto/recvfrom
  161. using native_raw_buffer_t = char;
  162. using socket_t = SOCKET;
  163. #else
  164. using native_msg_size_input_t = size_t; // the input message length type for send/recv/sendto/recvfrom
  165. using native_msg_size_return_t = ssize_t;// the return type for send/recv/sendto/recvfrom
  166. using native_raw_buffer_t = void;
  167. using socket_t = int; //
  168. #endif
  169. public:
  170. #if defined _WIN32
  171. static const socket_t invalid_socket = INVALID_SOCKET;
  172. static const int socket_error = SOCKET_ERROR;
  173. static const int msg_error = -1;
  174. #else
  175. static const socket_t invalid_socket = -1;
  176. static const int socket_error = -1;
  177. static const ssize_t msg_error = -1;
  178. #endif
  179. static const int bind_error = -1;
  180. static const int listen_error = -1;
  181. static const int connect_error = -1;
  182. typedef std::int32_t msg_size_t;
  183. static const msg_size_t error = -1;
  184. socket() : m_fd( invalid_socket )
  185. {
  186. }
  187. socket(socket const & other) : m_fd(other.m_fd)
  188. {
  189. }
  190. socket(socket && other) : m_fd(other.m_fd)
  191. {
  192. other.m_fd=invalid_socket;
  193. }
  194. socket & operator=(socket && other)
  195. {
  196. if( this != &other)
  197. {
  198. m_fd = other.m_fd;
  199. other.m_fd=invalid_socket;
  200. }
  201. return *this;
  202. }
  203. bool create(socket_domain d, socket_type t)
  204. {
  205. if( d == socket_domain::NET )
  206. {
  207. if( t == socket_type::STREAM)
  208. {
  209. return _create(AF_INET, SOCK_STREAM, 0);
  210. }
  211. else if( t==socket_type::DGRAM)
  212. {
  213. return _create(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
  214. }
  215. }
  216. else if( d == socket_domain::UNIX)
  217. {
  218. if( t == socket_type::STREAM)
  219. {
  220. return _create(AF_UNIX, SOCK_STREAM, 0);
  221. }
  222. else if( t==socket_type::DGRAM)
  223. {
  224. return _create(AF_UNIX, SOCK_DGRAM, IPPROTO_UDP);
  225. }
  226. }
  227. return false;
  228. }
  229. bool bind(const std::string endpoint)
  230. {
  231. const auto separator = endpoint.find_last_of(':');
  232. //Check if input wasn't missformed
  233. if (separator == std::string::npos)
  234. {
  235. // possibly domain socket
  236. //throw std::runtime_error("string is not of address:port form");
  237. if( !create(socket_domain::UNIX, socket_type::STREAM) )
  238. {
  239. return false;
  240. }
  241. struct sockaddr_un d_name;
  242. memset(&d_name, 0 ,sizeof(struct sockaddr_un) );
  243. d_name.sun_family = AF_UNIX;
  244. strcpy(d_name.sun_path, endpoint.c_str());
  245. ::unlink( d_name.sun_path );
  246. int ret = ::bind(m_fd, reinterpret_cast<const struct sockaddr *>(&d_name),
  247. sizeof(struct sockaddr_un));
  248. if( ret == bind_error)
  249. return false;
  250. return true;
  251. }
  252. else
  253. {
  254. if (separator == endpoint.size() - 1)
  255. {
  256. return false;
  257. //throw std::runtime_error("string has ':' as last character. Expected port number here");
  258. }
  259. //Isolate address
  260. std::string address = endpoint.substr(0, separator);
  261. //Read from string as unsigned
  262. const auto port = static_cast<uint16_t>( strtoul(endpoint.substr(separator + 1).c_str(), nullptr, 10) );
  263. if( !create(socket_domain::NET, socket_type::STREAM) )
  264. {
  265. return false;
  266. }
  267. return _bind( socket_address(port));
  268. }
  269. }
  270. socket accept()
  271. {
  272. socket client;
  273. int length = sizeof( m_address.native_address() );
  274. #if defined _MSC_VER
  275. using socklen_t = int;
  276. #endif
  277. client.m_fd = ::accept(m_fd, reinterpret_cast<struct sockaddr*>(&m_address.native_address()), reinterpret_cast<socklen_t*>(&length));
  278. ::getpeername(client.m_fd , reinterpret_cast<struct sockaddr *>(&client.m_address.native_address()), reinterpret_cast<socklen_t*>(&length) );
  279. return client;
  280. }
  281. /**
  282. * @brief close
  283. * Closes the socket. The socket will cast to false after this has been
  284. * called.
  285. */
  286. void close()
  287. {
  288. #ifdef _MSC_VER
  289. ::closesocket(m_fd);
  290. #else
  291. ::shutdown(m_fd, SHUT_RDWR);
  292. ::close(m_fd);
  293. #endif
  294. m_fd = invalid_socket;
  295. }
  296. /**
  297. * @brief native_handle
  298. * @return
  299. *
  300. * Returns the native handle of the socket descriptor
  301. */
  302. socket_t native_handle() const
  303. {
  304. return m_fd;
  305. }
  306. bool set_recv_timeout(std::chrono::microseconds ms)
  307. {
  308. #ifdef _MSC_VER
  309. // WINDOWS
  310. DWORD timeout = ms.count() / 1000;
  311. if( setsockopt(socket, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof timeout) != 0)
  312. {
  313. return false;
  314. }
  315. #else
  316. struct timeval timeout;
  317. timeout.tv_sec = static_cast< decltype(timeout.tv_usec) >( ms.count() / 1000000u );
  318. timeout.tv_usec = static_cast< decltype(timeout.tv_usec) >( ms.count() % 1000000u );
  319. if (setsockopt (m_fd, SOL_SOCKET, SO_RCVTIMEO, (char *)&timeout, sizeof(timeout)) < 0)
  320. {
  321. return false;
  322. //error("setsockopt failed\n");
  323. }
  324. #endif
  325. return true;
  326. }
  327. bool set_send_timeout(std::chrono::microseconds ms)
  328. {
  329. #ifdef _MSC_VER
  330. // WINDOWS
  331. DWORD timeout = ms.count() / 1000;
  332. if( setsockopt(socket, SOL_SOCKET, SO_SNDTIMEO, (const char*)&timeout, sizeof timeout) != 0)
  333. {
  334. return false;
  335. }
  336. #else
  337. struct timeval timeout;
  338. timeout.tv_sec = static_cast< decltype(timeout.tv_usec) >( ms.count() / 1000000u );
  339. timeout.tv_usec = static_cast< decltype(timeout.tv_usec) >( ms.count() % 1000000u );
  340. if (setsockopt (m_fd, SOL_SOCKET, SO_SNDTIMEO, (char *)&timeout, sizeof(timeout)) < 0)
  341. {
  342. return false;
  343. //error("setsockopt failed\n");
  344. }
  345. #endif
  346. return true;
  347. }
  348. bool listen( std::size_t max_connections)
  349. {
  350. decltype(socket_error) code = ::listen( m_fd, static_cast<int>(max_connections));
  351. if( code == socket_error)
  352. {
  353. return false;
  354. }
  355. return true;
  356. }
  357. operator bool() const
  358. {
  359. return !( ( m_fd == invalid_socket ) );
  360. }
  361. //=============================================================================
  362. msg_size_t sendto(void const * data, size_t length, socket_address const & addr)
  363. {
  364. native_msg_size_return_t ret = ::sendto(m_fd,
  365. reinterpret_cast<native_raw_buffer_t const*>(data),
  366. static_cast<native_msg_size_input_t>(length&0xFFFFFFFF) ,
  367. 0 ,
  368. reinterpret_cast<struct sockaddr const *>(&addr.native_address()),
  369. sizeof(struct sockaddr_in ));
  370. if ( ret == msg_error)
  371. {
  372. #ifdef _MSC_VER
  373. //printf("Send failed with error code : %d" , WSAGetLastError() );
  374. #else
  375. //printf("Send failed with error code : %d : %s" , errno, strerror(errno) );
  376. #endif
  377. return msg_size_t(ret);
  378. }
  379. return msg_size_t(ret);
  380. }
  381. msg_size_t recvfrom(void * buf, size_t length, socket_address & addr)
  382. {
  383. #if defined _MSC_VER
  384. int slen = sizeof(struct sockaddr_in);
  385. #else
  386. socklen_t slen = sizeof(struct sockaddr_in);
  387. #endif
  388. native_msg_size_return_t ret = ::recvfrom( m_fd, reinterpret_cast<native_raw_buffer_t*>(buf), static_cast<native_msg_size_input_t>(length&0xFFFFFFFF), 0, reinterpret_cast<struct sockaddr *>(&addr.native_address()), &slen);
  389. if (ret == msg_error)
  390. {
  391. #ifdef _MSC_VER
  392. //printf("Recv failed with error code : %d" , WSAGetLastError() );
  393. #else
  394. //printf("Recv failed with error code : %d : %s" , errno, strerror(errno) );
  395. #endif
  396. //return msg_error;
  397. }
  398. return msg_size_t(ret);
  399. }
  400. msg_size_t send( void const * data, size_t _size)
  401. {
  402. native_msg_size_return_t ret = ::send(m_fd, reinterpret_cast<const native_raw_buffer_t*>(data), static_cast<native_msg_size_input_t>(_size&0xFFFFFFFF), 0);
  403. return msg_size_t(ret);
  404. }
  405. msg_size_t recv(void * data, size_t _size, bool wait_for_all=true)
  406. {
  407. //bool wait_for_all = true; // default for now.
  408. native_msg_size_return_t t = ::recv( m_fd, reinterpret_cast<native_raw_buffer_t*>(data), static_cast<native_msg_size_input_t>(_size&0xFFFFFFFF), wait_for_all ? MSG_WAITALL : 0 );
  409. if( t == 0 && _size != 0 ) // gracefully closed
  410. {
  411. m_fd = invalid_socket;
  412. }
  413. return msg_size_t(t);
  414. }
  415. //=============================================================================
  416. protected:
  417. socket_t m_fd = invalid_socket;
  418. socket_address m_address;
  419. bool _create(int _domain, int _type, int _protocol)
  420. {
  421. #ifdef _MSC_VER
  422. WSADATA wsa;
  423. if (WSAStartup(MAKEWORD(2,2),&wsa) != 0)
  424. {
  425. //printf("Failed. Error Code : %d",WSAGetLastError());
  426. return false;
  427. }
  428. #endif
  429. if ( (m_fd=::socket(_domain, _type, _protocol)) == invalid_socket)
  430. {
  431. #ifdef _MSC_VER
  432. //printf("Create failed with error code : %d\n" , WSAGetLastError() );
  433. #else
  434. //printf("Create failed with error code : %d : %s\n" , errno, strerror(errno) );
  435. #endif
  436. return false;
  437. }
  438. return true;
  439. }
  440. bool _bind(socket_address const & addr)
  441. {
  442. auto ret = ::bind(m_fd, reinterpret_cast<struct sockaddr const*>(&addr.native_address()) , sizeof(socket_address));
  443. if( ret == bind_error)
  444. {
  445. #ifdef _MSC_VER
  446. //printf("Bind failed with error code : %d" , WSAGetLastError() );
  447. #else
  448. //printf("Bind failed with error code : %d : %s" , errno, strerror(errno) );
  449. #endif
  450. return false;
  451. }
  452. return true;
  453. }
  454. };
  455. }
  456. #endif