SChannelConnection.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. #define SECURITY_WIN32
  2. #define NOMINMAX
  3. #include "SChannelConnection.h"
  4. #ifdef HTTPS_BACKEND_SCHANNEL
  5. #include <windows.h>
  6. #include <security.h>
  7. #include <schnlsp.h>
  8. #include <assert.h>
  9. #include <algorithm>
  10. #include <memory>
  11. #include <array>
  12. #ifndef SCH_USE_STRONG_CRYPTO
  13. # define SCH_USE_STRONG_CRYPTO 0x00400000
  14. #endif
  15. #ifndef SP_PROT_TLS1_1_CLIENT
  16. # define SP_PROT_TLS1_1_CLIENT 0x00000200
  17. #endif
  18. #ifndef SP_PROT_TLS1_2_CLIENT
  19. # define SP_PROT_TLS1_2_CLIENT 0x00000800
  20. #endif
  #ifdef DEBUG_SCHANNEL
  #include <iostream>
  #endif

  // File-local debug sink. With DEBUG_SCHANNEL defined, `debug << x` writes to
  // std::cout; otherwise every insertion is a no-op that returns the sink itself
  // so chained `<<` expressions still compile.
  // The anonymous namespace gives both the object and the helper type internal
  // linkage, so another translation unit defining its own `debug` (or `Debug`)
  // cannot collide with this one at link time.
  namespace
  {
  #ifdef DEBUG_SCHANNEL
  std::ostream &debug = std::cout;
  #else
  struct Debug
  {
      // Discards the value and returns *this to allow chaining.
      template<typename T>
      Debug &operator<<(const T &) { return *this; }
  };
  Debug debug;
  #endif
  } // namespace
  // Appends `size` bytes from `data` to the end of `buffer`.
  // A zero-length append is a no-op. vector::insert never indexes past the end
  // (or into an empty vector), unlike taking &buffer[oldSize].
  static void enqueue(std::vector<char> &buffer, const char *data, size_t size)
  {
      if (size == 0)
          return;
      buffer.insert(buffer.end(), data, data + size);
  }
  // Inserts `size` bytes from `data` at the front of `buffer`, ahead of any
  // bytes already queued. A zero-length prepend is a no-op (the old
  // &buffer[0] form was UB on an empty vector). `data` must not point into
  // `buffer` itself.
  static void enqueue_prepend(std::vector<char> &buffer, const char *data, size_t size)
  {
      if (size == 0)
          return;
      buffer.insert(buffer.begin(), data, data + size);
  }
  // Removes up to `size` bytes from the front of `buffer` and copies them into
  // `data`. Returns the number of bytes moved, min(size, buffer.size()), which
  // is 0 (with nothing touched) when the buffer is empty or size is 0.
  static size_t dequeue(std::vector<char> &buffer, char *data, size_t size)
  {
      const size_t count = std::min(size, buffer.size());
      if (count == 0)
          return 0;
      memcpy(data, buffer.data(), count);
      buffer.erase(buffer.begin(), buffer.begin() + static_cast<std::vector<char>::difference_type>(count));
      return count;
  }
  61. SChannelConnection::SChannelConnection()
  62. : context(nullptr)
  63. {
  64. }
  65. SChannelConnection::~SChannelConnection()
  66. {
  67. destroyContext();
  68. }
  69. SECURITY_STATUS InitializeSecurityContext(CredHandle *phCredential, std::unique_ptr<CtxtHandle>& phContext, const std::string& szTargetName, ULONG fContextReq, std::vector<char>& inputBuffer, std::vector<char>& outputBuffer, ULONG *pfContextAttr)
  70. {
  71. std::array<SecBuffer, 1> recvBuffers;
  72. recvBuffers[0].BufferType = SECBUFFER_TOKEN;
  73. recvBuffers[0].pvBuffer = outputBuffer.data();
  74. recvBuffers[0].cbBuffer = outputBuffer.size();
  75. std::array<SecBuffer, 2> sendBuffers;
  76. sendBuffers[0].BufferType = SECBUFFER_TOKEN;
  77. sendBuffers[0].pvBuffer = inputBuffer.data();
  78. sendBuffers[0].cbBuffer = inputBuffer.size();
  79. sendBuffers[1].BufferType = SECBUFFER_EMPTY;
  80. sendBuffers[1].pvBuffer = nullptr;
  81. sendBuffers[1].cbBuffer = 0;
  82. SecBufferDesc recvBufferDesc, sendBufferDesc;
  83. recvBufferDesc.ulVersion = sendBufferDesc.ulVersion = SECBUFFER_VERSION;
  84. recvBufferDesc.pBuffers = &recvBuffers[0];
  85. recvBufferDesc.cBuffers = recvBuffers.size();
  86. if (!inputBuffer.empty())
  87. {
  88. sendBufferDesc.pBuffers = &sendBuffers[0];
  89. sendBufferDesc.cBuffers = sendBuffers.size();
  90. }
  91. else
  92. {
  93. sendBufferDesc.pBuffers = nullptr;
  94. sendBufferDesc.cBuffers = 0;
  95. }
  96. CtxtHandle* phOldContext = nullptr;
  97. CtxtHandle* phNewContext = nullptr;
  98. if (!phContext)
  99. {
  100. phContext = std::make_unique<CtxtHandle>();
  101. phNewContext = phContext.get();
  102. }
  103. else
  104. {
  105. phOldContext = phContext.get();
  106. }
  107. auto ret = InitializeSecurityContext(phCredential, phOldContext, const_cast<char*>(szTargetName.c_str()), fContextReq, 0, 0, &sendBufferDesc, 0, phNewContext, &recvBufferDesc, pfContextAttr, nullptr);
  108. outputBuffer.resize(recvBuffers[0].cbBuffer);
  109. // Clear the input buffer, so the reader can append
  110. // If we have unprocessed data, leave it in the buffer
  111. size_t unprocessed = 0;
  112. if (sendBuffers[1].BufferType == SECBUFFER_EXTRA)
  113. unprocessed = sendBuffers[1].cbBuffer;
  114. if (unprocessed > 0)
  115. memmove(inputBuffer.data(), inputBuffer.data() + inputBuffer.size() - unprocessed, unprocessed);
  116. inputBuffer.resize(unprocessed);
  117. return ret;
  118. }
  119. bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
  120. {
  121. debug << "Trying to connect to " << hostname << ":" << port << "\n";
  122. if (!socket.connect(hostname, port))
  123. return false;
  124. debug << "Connected\n";
  125. SCHANNEL_CRED cred;
  126. memset(&cred, 0, sizeof(cred));
  127. cred.dwVersion = SCHANNEL_CRED_VERSION;
  128. cred.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT;
  129. cred.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_NO_DEFAULT_CREDS | SCH_USE_STRONG_CRYPTO | SCH_CRED_REVOCATION_CHECK_CHAIN;
  130. CredHandle credHandle;
  131. if (AcquireCredentialsHandle(nullptr, (char*) UNISP_NAME, SECPKG_CRED_OUTBOUND, nullptr, &cred, nullptr, nullptr, &credHandle, nullptr) != SEC_E_OK)
  132. {
  133. debug << "Failed to acquire handle\n";
  134. socket.close();
  135. return false;
  136. }
  137. debug << "Acquired handle\n";
  138. static constexpr size_t bufferSize = 8192;
  139. bool done = false, success = false, contextCreated = false;
  140. ULONG contextAttr;
  141. std::unique_ptr<CtxtHandle> context;
  142. std::vector<char> inputBuffer;
  143. std::vector<char> outputBuffer;
  144. do
  145. {
  146. outputBuffer.resize(bufferSize);
  147. bool recvData = false;
  148. bool sendData = false;
  149. auto ret = InitializeSecurityContext(&credHandle, context, hostname, ISC_REQ_STREAM, inputBuffer, outputBuffer, &contextAttr);
  150. switch (ret)
  151. {
  152. /*case SEC_I_COMPLETE_NEEDED:
  153. case SEC_I_COMPLETE_AND_CONTINUE:
  154. if (CompleteAuthToken(context.get(), &outputBuffer) != SEC_E_OK)
  155. done = true;
  156. else if (ret == SEC_I_COMPLETE_NEEDED)
  157. success = done = true;
  158. break;*/
  159. case SEC_I_CONTINUE_NEEDED:
  160. debug << "Initialize: continue needed\n";
  161. recvData = true;
  162. sendData = true;
  163. break;
  164. case SEC_E_INCOMPLETE_CREDENTIALS:
  165. debug << "Initialize failed: incomplete credentials\n";
  166. done = true;
  167. break;
  168. case SEC_E_INCOMPLETE_MESSAGE:
  169. debug << "Initialize: incomplete message\n";
  170. recvData = true;
  171. break;
  172. case SEC_E_OK:
  173. debug << "Initialize succeeded\n";
  174. success = done = true;
  175. sendData = true;
  176. break;
  177. default:
  178. debug << "Initialize done: " << outputBuffer.size() << " bytes of output and unknown status " << ret << "\n";
  179. done = true;
  180. success = false;
  181. break;
  182. }
  183. if (!done)
  184. contextCreated = true;
  185. if (sendData && !outputBuffer.empty())
  186. {
  187. socket.write(outputBuffer.data(), outputBuffer.size());
  188. debug << "Sent " << outputBuffer.size() << " bytes of data\n";
  189. }
  190. if (recvData)
  191. {
  192. size_t unprocessed = inputBuffer.size();
  193. inputBuffer.resize(unprocessed + bufferSize);
  194. size_t actual = socket.read(inputBuffer.data() + unprocessed, bufferSize);
  195. inputBuffer.resize(actual + unprocessed);
  196. debug << "Received " << actual << " bytes of data\n";
  197. if (unprocessed > 0)
  198. debug << " had " << unprocessed << " bytes of remaining, unprocessed data\n";
  199. if (actual + unprocessed == 0)
  200. {
  201. debug << "No data to submit, break\n";
  202. break;
  203. }
  204. }
  205. } while (!done);
  206. debug << "Done!\n";
  207. if (success)
  208. {
  209. SecPkgContext_Flags resultFlags;
  210. QueryContextAttributes(context.get(), SECPKG_ATTR_FLAGS, &resultFlags);
  211. if ((resultFlags.Flags & ISC_REQ_CONFIDENTIALITY) == 0)
  212. {
  213. debug << "Resulting context is not encrypted, marking as failed\n";
  214. success = false;
  215. }
  216. if ((resultFlags.Flags & ISC_REQ_INTEGRITY) == 0)
  217. {
  218. debug << "Resulting context is not signed, marking as failed\n";
  219. success = false;
  220. }
  221. }
  222. if (success)
  223. this->context = context.release();
  224. else if (contextCreated)
  225. DeleteSecurityContext(context.get());
  226. return success;
  227. }
  228. size_t SChannelConnection::read(char *buffer, size_t size)
  229. {
  230. if (decRecvBuffer.size() > 0)
  231. {
  232. size = dequeue(decRecvBuffer, buffer, size);
  233. debug << "Read " << size << " bytes of previously decoded data\n";
  234. return size;
  235. }
  236. else if (encRecvBuffer.size() > 0)
  237. {
  238. size = dequeue(encRecvBuffer, buffer, size);
  239. debug << "Read " << size << " bytes of extra data\n";
  240. }
  241. else
  242. {
  243. size = socket.read(buffer, size);
  244. debug << "Received " << size << " bytes of data\n";
  245. }
  246. return decrypt(buffer, size);
  247. }
  248. size_t SChannelConnection::decrypt(char *buffer, size_t size, bool recurse)
  249. {
  250. if (size == 0)
  251. return 0;
  252. SecBuffer secBuffers[4];
  253. secBuffers[0].cbBuffer = size;
  254. secBuffers[0].BufferType = SECBUFFER_DATA;
  255. secBuffers[0].pvBuffer = buffer;
  256. for (size_t i = 1; i < 4; ++i)
  257. {
  258. secBuffers[i].BufferType = SECBUFFER_EMPTY;
  259. secBuffers[i].pvBuffer = nullptr;
  260. secBuffers[i].cbBuffer = 0;
  261. }
  262. SecBufferDesc secBufferDesc;
  263. secBufferDesc.ulVersion = SECBUFFER_VERSION;
  264. secBufferDesc.cBuffers = 4;
  265. secBufferDesc.pBuffers = &secBuffers[0];
  266. auto ret = DecryptMessage(static_cast<CtxtHandle*>(context), &secBufferDesc, 0, nullptr); // FIXME
  267. debug << "DecryptMessage returns: " << ret << "\n";
  268. switch (ret)
  269. {
  270. case SEC_E_OK:
  271. {
  272. void *actualDataStart = buffer;
  273. for (size_t i = 0; i < 4; ++i)
  274. {
  275. auto &buffer = secBuffers[i];
  276. if (buffer.BufferType == SECBUFFER_DATA)
  277. {
  278. actualDataStart = buffer.pvBuffer;
  279. size = buffer.cbBuffer;
  280. }
  281. else if (buffer.BufferType == SECBUFFER_EXTRA)
  282. {
  283. debug << "\tExtra data in buffer " << i << " (" << buffer.cbBuffer << " bytes)\n";
  284. enqueue(encRecvBuffer, static_cast<char*>(buffer.pvBuffer), buffer.cbBuffer);
  285. }
  286. else if (buffer.BufferType != SECBUFFER_EMPTY)
  287. debug << "\tBuffer of type " << buffer.BufferType << "\n";
  288. }
  289. if (actualDataStart)
  290. memmove(buffer, actualDataStart, size);
  291. break;
  292. }
  293. case SEC_E_INCOMPLETE_MESSAGE:
  294. {
  295. // Move all our current data to encRecvBuffer
  296. enqueue(encRecvBuffer, buffer, size);
  297. // Now try to read some more data from the socket
  298. size_t bufferSize = encRecvBuffer.size() + 8192;
  299. char *recvBuffer = new char[bufferSize];
  300. size_t recvd = socket.read(recvBuffer+encRecvBuffer.size(), 8192);
  301. debug << recvd << " bytes of extra data read from socket\n";
  302. if (recvd == 0 && !recurse)
  303. {
  304. debug << "Recursion prevented, bailing\n";
  305. return 0;
  306. }
  307. // Fill our buffer with the queued data and the newly received data
  308. size_t totalSize = encRecvBuffer.size() + recvd;
  309. dequeue(encRecvBuffer, recvBuffer, encRecvBuffer.size());
  310. debug << "Trying to decrypt with " << totalSize << " bytes of data\n";
  311. // Now try to decrypt that
  312. size_t decrypted = decrypt(recvBuffer, totalSize, false);
  313. debug << "\tObtained " << decrypted << " bytes of decrypted data\n";
  314. // Copy the first size bytes to the output buffer
  315. size = std::min(size, decrypted);
  316. memcpy(buffer, recvBuffer, size);
  317. // And write the remainder to our queued decrypted data...
  318. // Note: we prepend, since our recursive call may already have written
  319. // something and we can be sure decrypt wasn't called if the buffer was
  320. // non-empty in read
  321. enqueue_prepend(decRecvBuffer, recvBuffer+size, decrypted-size);
  322. debug << "\tStoring " << decrypted-size << " bytes of extra decrypted data\n";
  323. return size;
  324. }
  325. // TODO: More?
  326. default:
  327. size = 0;
  328. break;
  329. }
  330. debug << "\tDecrypted " << size << " bytes of data\n";
  331. return size;
  332. }
  333. size_t SChannelConnection::write(const char *buffer, size_t size)
  334. {
  335. static constexpr size_t bufferSize = 8192;
  336. assert(size <= bufferSize);
  337. SecPkgContext_StreamSizes Sizes;
  338. QueryContextAttributes(
  339. static_cast<CtxtHandle*>(context),
  340. SECPKG_ATTR_STREAM_SIZES,
  341. &Sizes);
  342. debug << "stream sizes:\n\theader: " << Sizes.cbHeader << "\n\tfooter: " << Sizes.cbTrailer << "\n";
  343. char *sendBuffer = new char[bufferSize + Sizes.cbHeader + Sizes.cbTrailer];
  344. memcpy(sendBuffer+Sizes.cbHeader, buffer, size);
  345. SecBuffer secBuffers[4];
  346. secBuffers[0].cbBuffer = Sizes.cbHeader;
  347. secBuffers[0].BufferType = SECBUFFER_STREAM_HEADER;
  348. secBuffers[0].pvBuffer = sendBuffer;
  349. secBuffers[1].cbBuffer = size;
  350. secBuffers[1].BufferType = SECBUFFER_DATA;
  351. secBuffers[1].pvBuffer = sendBuffer+Sizes.cbHeader;
  352. secBuffers[2].cbBuffer = Sizes.cbTrailer;
  353. secBuffers[2].pvBuffer = sendBuffer+Sizes.cbHeader+size;
  354. secBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
  355. secBuffers[3].cbBuffer = 0;
  356. secBuffers[3].BufferType = SECBUFFER_EMPTY;
  357. secBuffers[3].pvBuffer = nullptr;
  358. SecBufferDesc secBufferDesc;
  359. secBufferDesc.ulVersion = SECBUFFER_VERSION;
  360. secBufferDesc.cBuffers = 4;
  361. secBufferDesc.pBuffers = secBuffers;
  362. auto ret = EncryptMessage(static_cast<CtxtHandle*>(context), 0, &secBufferDesc, 0); // FIXME
  363. debug << "Send:\n\tHeader size: " << secBuffers[0].cbBuffer << "\n\t\ttype: " << secBuffers[0].BufferType << "\n\tData size: " << secBuffers[1].cbBuffer << "\n\t\ttype: " << secBuffers[1].BufferType << "\n\tFooter size: " << secBuffers[2].cbBuffer << "\n\t\ttype: " << secBuffers[2].BufferType << "\n";
  364. size_t sendSize = 0;
  365. for (size_t i = 0; i < 4; ++i)
  366. if (secBuffers[i].cbBuffer != bufferSize)
  367. sendSize += secBuffers[i].cbBuffer;
  368. debug << "\tReal length? " << sendSize << "\n";
  369. switch (ret)
  370. {
  371. case SEC_E_OK:
  372. socket.write(sendBuffer, sendSize);
  373. break;
  374. // TODO: More?
  375. default:
  376. size = 0;
  377. break;
  378. }
  379. delete[] sendBuffer;
  380. return size;
  381. }
  382. void SChannelConnection::destroyContext()
  383. {
  384. if (context)
  385. {
  386. DeleteSecurityContext(context);
  387. delete context;
  388. context = nullptr;
  389. }
  390. }
  391. void SChannelConnection::close()
  392. {
  393. destroyContext();
  394. socket.close();
  395. }
  396. bool SChannelConnection::valid()
  397. {
  398. return true;
  399. }
  400. #endif // HTTPS_BACKEND_SCHANNEL