// websocket.H
#include <cstdint>
#include <cstring>
#include <cpoll/cpoll.H>
#include <rgc.H>
using namespace CP;
using namespace RGC;
  5. namespace cppsp
  6. {
  7. struct WebSocketParser
  8. {
  9. struct ws_header1
  10. {
  11. //char flags:8;
  12. unsigned int opcode :4;
  13. bool rsv1 :1;
  14. bool rsv2 :1;
  15. bool rsv3 :1;
  16. bool fin :1;
  17. unsigned int payload_len :7;
  18. bool mask :1;
  19. }__attribute__((packed));
  20. struct ws_footer1
  21. {
  22. uint32_t masking_key;
  23. }__attribute__((packed));
  24. struct ws_header_extended16
  25. {
  26. uint16_t payload_len;
  27. }__attribute__((packed));
  28. struct ws_header_extended64
  29. {
  30. uint64_t payload_len;
  31. }__attribute__((packed));
  32. struct WSFrame
  33. {
  34. String data;
  35. char opcode;
  36. bool fin;
  37. };
  38. MemoryStream ms;
  39. int pos = 0;
  40. String beginPutData(int len) {
  41. if (ms.bufferSize - ms.bufferPos < len) ms.flushBuffer(len);
  42. return {(char*)ms.buffer + ms.bufferPos,ms.bufferSize-ms.bufferPos};
  43. }
  44. void endPutData(int len) {
  45. ms.bufferPos += len;
  46. ms.flush();
  47. }
  48. void skip(int length) {
  49. pos += length;
  50. }
  51. inline void unmask(String data, uint32_t key) {
  52. /*uint32_t* d = (uint32_t*) data.data();
  53. int len = data.length() / sizeof(*d);
  54. for (int i = 0; i < len; i++) {
  55. d[i] ^= key;
  56. }
  57. uint8_t* tmp = (uint8_t*) (d + len);
  58. uint8_t* tmp1 = (uint8_t*) &key;
  59. int leftover = data.length() % sizeof(*d);
  60. if (leftover > 0) tmp[0] ^= tmp1[0];
  61. if (leftover > 1) tmp[1] ^= tmp1[1];
  62. if (leftover > 2) tmp[2] ^= tmp1[2];
  63. if (leftover > 3) tmp[3] ^= tmp1[3];*/
  64. uint8_t* k = (uint8_t*) &key;
  65. for (int i = 0; i < data.length(); i++) {
  66. data.d[i] = data.d[i] ^ k[i % sizeof(key)];
  67. }
  68. }
  69. bool process(WSFrame& out) {
  70. char* data = (char*) ms.data() + pos;
  71. int len = ms.length() - pos;
  72. int minLen = sizeof(ws_header1);
  73. if (len < minLen) return false;
  74. ws_header1* h1 = (ws_header1*) data;
  75. uint8_t pLen1 = h1->payload_len; // & ~(uint8_t) 128;
  76. //printf("pLen1 = %i\n", pLen1);
  77. int pLen2 = 0;
  78. if (pLen1 == 126) pLen2 = 2;
  79. if (pLen1 == 127) pLen2 = 8;
  80. minLen += pLen2;
  81. if (h1->mask) minLen += 4;
  82. if (len < minLen) return false;
  83. //printf("len = %i\n", len);
  84. //printf("minLen = %i\n", minLen);
  85. uint64_t payloadLen;
  86. switch (pLen1) {
  87. case 126:
  88. {
  89. ws_header_extended16* h2 = (ws_header_extended16*) (h1 + 1);
  90. payloadLen = ntohs(h2->payload_len);
  91. break;
  92. }
  93. case 127:
  94. {
  95. ws_header_extended64* h2 = (ws_header_extended64*) (h1 + 1);
  96. payloadLen = ntohll(h2->payload_len);
  97. break;
  98. }
  99. default:
  100. payloadLen = pLen1;
  101. break;
  102. }
  103. //printf("payloadLen = %lli\n", payloadLen);
  104. if (len < int(minLen + payloadLen)) return false;
  105. char* payload = data + minLen;
  106. out.data= {payload,(int)payloadLen};
  107. out.fin = h1->fin;
  108. out.opcode = h1->opcode;
  109. pos += minLen + (int) payloadLen;
  110. if (h1->mask) unmask( { payload, (int) payloadLen },
  111. ((ws_footer1*) ((char*) (h1 + 1) + pLen2))->masking_key);
  112. return true;
  113. }
  114. //free up buffer space
  115. void reset() {
  116. if (pos > 0) {
  117. int shift = pos;
  118. if (ms.length() - shift > 0) memmove(ms.buffer, ms.buffer + shift, ms.length() - shift);
  119. ms.len -= shift;
  120. pos -= shift;
  121. ms.bufferPos = ms.len;
  122. }
  123. }
  124. };
  125. class FrameWriter
  126. {
  127. public:
  128. MemoryStream ms1, ms2;
  129. Ref<Stream> output;
  130. vector<String> queue;
  131. struct queueItem
  132. {
  133. int next; //is actually a pointer, but relative to the base of the array (MemoryStream)
  134. int len;
  135. char data[0];
  136. };
  137. int _first = -1, _last = -1, _count = 0;
  138. bool use_ms2 = false;
  139. bool _append;
  140. bool closed = false;
  141. bool writeQueued = false;
  142. inline MemoryStream& ms() {
  143. return use_ms2 ? ms2 : ms1;
  144. }
  145. inline queueItem& _item(int i) {
  146. return *(queueItem*) (ms().data() + i);
  147. }
  148. /**
  149. Prepare for the insertion of a chunk into the queue;
  150. @param append whether to append to the queue or insert at the beginning
  151. @return the allocated buffer space; may be larger than the requested length
  152. You must not call beginInsert again before calling endInsert.
  153. */
  154. String beginInsert(int len, bool append = true) {
  155. _append = append;
  156. String tmp = ms().beginAppend(len + sizeof(queueItem));
  157. return tmp.subString(sizeof(queueItem));
  158. }
  159. /**
  160. Complete the insertion of a chunk.
  161. */
  162. void endInsert(int len) {
  163. //printf("endInsert: len=%i\n",len);
  164. int tmp = ms().length();
  165. ms().endAppend(len + sizeof(queueItem));
  166. if (_append) {
  167. _item(tmp).next = -1;
  168. if (_last >= 0) _item(_last).next = tmp;
  169. _last = tmp;
  170. if (_first < 0) _first = tmp;
  171. } else {
  172. _item(tmp).next = _first;
  173. _first = tmp;
  174. if (_last < 0) _last = tmp;
  175. }
  176. _item(tmp).len = len;
  177. ++_count;
  178. }
  179. bool writing = false;
  180. void flush() {
  181. beginFlush();
  182. }
  183. void beginFlush() {
  184. if (writing) {
  185. writeQueued = true;
  186. return;
  187. }
  188. if (ms().length() <= 0 || _count <= 0) return;
  189. writing = true;
  190. int iovcnt = 0;
  191. iovec* iov = (iovec*) ms().beginAppend(sizeof(iovec) * _count).data();
  192. ms().endAppend(sizeof(iovec) * _count);
  193. for (int i = _first; i >= 0; i = _item(i).next) {
  194. iov[iovcnt++]= {_item(i).data,(size_t)_item(i).len};
  195. //printf("id=%i iovcnt=%i len=%i\n",i,iovcnt,_item(i).len);
  196. }
  197. use_ms2 = !use_ms2;
  198. _first = _last = -1;
  199. _count = 0;
  200. output->writevAll(iov, iovcnt, { &FrameWriter::_writevCB, this });
  201. }
  202. void _writevCB(int i) {
  203. writing = false;
  204. if (i <= 0) {
  205. closed = true;
  206. return;
  207. }
  208. if (writeQueued) {
  209. writeQueued = false;
  210. beginFlush();
  211. }
  212. }
  213. };
  214. String ws_beginWriteFrame(FrameWriter& fw, int len);
  215. void ws_endWriteFrame(FrameWriter& fw, String buf, int opcode);
  216. struct Page;
  217. struct Request;
  218. void ws_init(Page& p, CP::Callback cb);
  219. bool ws_iswebsocket(const cppsp::Request& req);
  220. }