SChannelConnection.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. #define SECURITY_WIN32
  2. #define NOMINMAX
  3. #include "SChannelConnection.h"
  4. #ifdef HTTPS_BACKEND_SCHANNEL
  5. #include <windows.h>
  6. #include <security.h>
  7. #include <schnlsp.h>
  8. #include <assert.h>
  9. #include <algorithm>
  10. #include <memory>
  11. #include <array>
  // Fallback values, used only when the included Windows headers do not
  // already define these Schannel constants.
  #if !defined(SCH_USE_STRONG_CRYPTO)
  #define SCH_USE_STRONG_CRYPTO 0x00400000
  #endif
  #if !defined(SP_PROT_TLS1_1_CLIENT)
  #define SP_PROT_TLS1_1_CLIENT 0x00000200
  #endif
  #if !defined(SP_PROT_TLS1_2_CLIENT)
  #define SP_PROT_TLS1_2_CLIENT 0x00000800
  #endif

  // Diagnostic sink used throughout this file as `debug << ...`.
  // With DEBUG_SCHANNEL defined it is std::cout; otherwise every streamed
  // value is accepted and discarded.
  #if defined(DEBUG_SCHANNEL)
  #include <iostream>
  std::ostream &debug = std::cout;
  #else
  struct Debug
  {
      // Swallow any value and return the sink so calls can be chained.
      template<typename T>
      Debug &operator<<(const T &)
      {
          return *this;
      }
  };
  Debug debug;
  #endif
  // Appends `size` bytes starting at `data` to the end of `buffer`.
  // Safe for an empty `buffer` and for zero-length appends: no element is
  // indexed, unlike taking &buffer[oldSize] which is out of range in both cases.
  static void enqueue(std::vector<char> &buffer, const char *data, size_t size)
  {
      buffer.insert(buffer.end(), data, data + size);
  }
  // Inserts `size` bytes starting at `data` at the FRONT of `buffer`, ahead of
  // whatever it already holds.
  // Uses vector::insert instead of indexing &buffer[size], which is one past the
  // end whenever `buffer` was empty (asserts in MSVC debug builds).
  // `data` must not point into `buffer` itself.
  static void enqueue_prepend(std::vector<char> &buffer, const char *data, size_t size)
  {
      buffer.insert(buffer.begin(), data, data + size);
  }
  // Removes up to `size` bytes from the front of `buffer` and copies them to `data`.
  // Returns the number of bytes actually moved (min(size, buffer.size())).
  // The consumed bytes are always erased, including when the buffer is drained
  // completely; previously a full drain left the buffer untouched, so the same
  // bytes were handed out again on the next call.
  static size_t dequeue(std::vector<char> &buffer, char *data, size_t size)
  {
      size = std::min(size, buffer.size());
      std::copy_n(buffer.begin(), size, data);
      buffer.erase(buffer.begin(), buffer.begin() + static_cast<std::vector<char>::difference_type>(size));
      return size;
  }
  56. SChannelConnection::SChannelConnection()
  57. : context(nullptr)
  58. {
  59. }
  60. SChannelConnection::~SChannelConnection()
  61. {
  62. destroyContext();
  63. }
  64. SECURITY_STATUS InitializeSecurityContext(CredHandle *phCredential, std::unique_ptr<CtxtHandle>& phContext, const std::string& szTargetName, ULONG fContextReq, std::vector<char>& inputBuffer, std::vector<char>& outputBuffer, ULONG *pfContextAttr)
  65. {
  66. std::array<SecBuffer, 1> recvBuffers;
  67. recvBuffers[0].BufferType = SECBUFFER_TOKEN;
  68. recvBuffers[0].pvBuffer = outputBuffer.data();
  69. recvBuffers[0].cbBuffer = outputBuffer.size();
  70. std::array<SecBuffer, 2> sendBuffers;
  71. sendBuffers[0].BufferType = SECBUFFER_TOKEN;
  72. sendBuffers[0].pvBuffer = inputBuffer.data();
  73. sendBuffers[0].cbBuffer = inputBuffer.size();
  74. sendBuffers[1].BufferType = SECBUFFER_EMPTY;
  75. sendBuffers[1].pvBuffer = nullptr;
  76. sendBuffers[1].cbBuffer = 0;
  77. SecBufferDesc recvBufferDesc, sendBufferDesc;
  78. recvBufferDesc.ulVersion = sendBufferDesc.ulVersion = SECBUFFER_VERSION;
  79. recvBufferDesc.pBuffers = &recvBuffers[0];
  80. recvBufferDesc.cBuffers = recvBuffers.size();
  81. if (!inputBuffer.empty())
  82. {
  83. sendBufferDesc.pBuffers = &sendBuffers[0];
  84. sendBufferDesc.cBuffers = sendBuffers.size();
  85. }
  86. else
  87. {
  88. sendBufferDesc.pBuffers = nullptr;
  89. sendBufferDesc.cBuffers = 0;
  90. }
  91. CtxtHandle* phOldContext = nullptr;
  92. CtxtHandle* phNewContext = nullptr;
  93. if (!phContext)
  94. {
  95. phContext = std::make_unique<CtxtHandle>();
  96. phNewContext = phContext.get();
  97. }
  98. else
  99. {
  100. phOldContext = phContext.get();
  101. }
  102. auto ret = InitializeSecurityContext(phCredential, phOldContext, const_cast<char*>(szTargetName.c_str()), fContextReq, 0, 0, &sendBufferDesc, 0, phNewContext, &recvBufferDesc, pfContextAttr, nullptr);
  103. outputBuffer.resize(recvBuffers[0].cbBuffer);
  104. // Clear the input buffer, so the reader can append
  105. // If we have unprocessed data, leave it in the buffer
  106. size_t unprocessed = 0;
  107. if (sendBuffers[1].BufferType == SECBUFFER_EXTRA)
  108. unprocessed = sendBuffers[1].cbBuffer;
  109. if (unprocessed > 0)
  110. memmove(inputBuffer.data(), inputBuffer.data() + inputBuffer.size() - unprocessed, unprocessed);
  111. inputBuffer.resize(unprocessed);
  112. return ret;
  113. }
  114. bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
  115. {
  116. debug << "Trying to connect to " << hostname << ":" << port << "\n";
  117. if (!socket.connect(hostname, port))
  118. return false;
  119. debug << "Connected\n";
  120. SCHANNEL_CRED cred;
  121. memset(&cred, 0, sizeof(cred));
  122. cred.dwVersion = SCHANNEL_CRED_VERSION;
  123. cred.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT;
  124. cred.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_NO_DEFAULT_CREDS | SCH_USE_STRONG_CRYPTO | SCH_CRED_REVOCATION_CHECK_CHAIN;
  125. CredHandle credHandle;
  126. if (AcquireCredentialsHandle(nullptr, (char*) UNISP_NAME, SECPKG_CRED_OUTBOUND, nullptr, &cred, nullptr, nullptr, &credHandle, nullptr) != SEC_E_OK)
  127. {
  128. debug << "Failed to acquire handle\n";
  129. socket.close();
  130. return false;
  131. }
  132. debug << "Acquired handle\n";
  133. static constexpr size_t bufferSize = 8192;
  134. bool done = false, success = false, contextCreated = false;
  135. ULONG contextAttr;
  136. std::unique_ptr<CtxtHandle> context;
  137. std::vector<char> inputBuffer;
  138. std::vector<char> outputBuffer;
  139. do
  140. {
  141. outputBuffer.resize(bufferSize);
  142. bool recvData = false;
  143. bool sendData = false;
  144. auto ret = InitializeSecurityContext(&credHandle, context, hostname, ISC_REQ_STREAM, inputBuffer, outputBuffer, &contextAttr);
  145. switch (ret)
  146. {
  147. /*case SEC_I_COMPLETE_NEEDED:
  148. case SEC_I_COMPLETE_AND_CONTINUE:
  149. if (CompleteAuthToken(context.get(), &outputBuffer) != SEC_E_OK)
  150. done = true;
  151. else if (ret == SEC_I_COMPLETE_NEEDED)
  152. success = done = true;
  153. break;*/
  154. case SEC_I_CONTINUE_NEEDED:
  155. debug << "Initialize: continue needed\n";
  156. recvData = true;
  157. sendData = true;
  158. break;
  159. case SEC_E_INCOMPLETE_CREDENTIALS:
  160. debug << "Initialize failed: incomplete credentials\n";
  161. done = true;
  162. break;
  163. case SEC_E_INCOMPLETE_MESSAGE:
  164. debug << "Initialize: incomplete message\n";
  165. recvData = true;
  166. break;
  167. case SEC_E_OK:
  168. debug << "Initialize succeeded\n";
  169. success = done = true;
  170. sendData = true;
  171. break;
  172. default:
  173. debug << "Initialize done: " << outputBuffer.size() << " bytes of output and unknown status " << ret << "\n";
  174. done = true;
  175. success = false;
  176. break;
  177. }
  178. if (!done)
  179. contextCreated = true;
  180. if (sendData && !outputBuffer.empty())
  181. {
  182. socket.write(outputBuffer.data(), outputBuffer.size());
  183. debug << "Sent " << outputBuffer.size() << " bytes of data\n";
  184. }
  185. if (recvData)
  186. {
  187. size_t unprocessed = inputBuffer.size();
  188. inputBuffer.resize(unprocessed + bufferSize);
  189. size_t actual = socket.read(inputBuffer.data() + unprocessed, bufferSize);
  190. inputBuffer.resize(actual + unprocessed);
  191. debug << "Received " << actual << " bytes of data\n";
  192. if (unprocessed > 0)
  193. debug << " had " << unprocessed << " bytes of remaining, unprocessed data\n";
  194. if (actual + unprocessed == 0)
  195. {
  196. debug << "No data to submit, break\n";
  197. break;
  198. }
  199. }
  200. } while (!done);
  201. debug << "Done!\n";
  202. if (success)
  203. {
  204. SecPkgContext_Flags resultFlags;
  205. QueryContextAttributes(context.get(), SECPKG_ATTR_FLAGS, &resultFlags);
  206. if (resultFlags.Flags & ISC_REQ_CONFIDENTIALITY == 0)
  207. {
  208. debug << "Resulting context is not encrypted, marking as failed\n";
  209. success = false;
  210. }
  211. if (resultFlags.Flags & ISC_REQ_INTEGRITY == 0)
  212. {
  213. debug << "Resulting context is not signed, marking as failed\n";
  214. success = false;
  215. }
  216. }
  217. if (success)
  218. this->context = context.release();
  219. else if (contextCreated)
  220. DeleteSecurityContext(context.get());
  221. return success;
  222. }
  223. size_t SChannelConnection::read(char *buffer, size_t size)
  224. {
  225. if (decRecvBuffer.size() > 0)
  226. {
  227. size = dequeue(decRecvBuffer, buffer, size);
  228. debug << "Read " << size << " bytes of previously decoded data\n";
  229. return size;
  230. }
  231. else if (encRecvBuffer.size() > 0)
  232. {
  233. size = dequeue(encRecvBuffer, buffer, size);
  234. debug << "Read " << size << " bytes of extra data\n";
  235. }
  236. else
  237. {
  238. size = socket.read(buffer, size);
  239. debug << "Received " << size << " bytes of data\n";
  240. }
  241. return decrypt(buffer, size);
  242. }
  243. size_t SChannelConnection::decrypt(char *buffer, size_t size, bool recurse)
  244. {
  245. if (size == 0)
  246. return 0;
  247. SecBuffer secBuffers[4];
  248. secBuffers[0].cbBuffer = size;
  249. secBuffers[0].BufferType = SECBUFFER_DATA;
  250. secBuffers[0].pvBuffer = buffer;
  251. for (size_t i = 1; i < 4; ++i)
  252. {
  253. secBuffers[i].BufferType = SECBUFFER_EMPTY;
  254. secBuffers[i].pvBuffer = nullptr;
  255. secBuffers[i].cbBuffer = 0;
  256. }
  257. SecBufferDesc secBufferDesc;
  258. secBufferDesc.ulVersion = SECBUFFER_VERSION;
  259. secBufferDesc.cBuffers = 4;
  260. secBufferDesc.pBuffers = &secBuffers[0];
  261. auto ret = DecryptMessage(static_cast<CtxtHandle*>(context), &secBufferDesc, 0, nullptr); // FIXME
  262. debug << "DecryptMessage returns: " << ret << "\n";
  263. switch (ret)
  264. {
  265. case SEC_E_OK:
  266. {
  267. void *actualDataStart = buffer;
  268. for (size_t i = 0; i < 4; ++i)
  269. {
  270. auto &buffer = secBuffers[i];
  271. if (buffer.BufferType == SECBUFFER_DATA)
  272. {
  273. actualDataStart = buffer.pvBuffer;
  274. size = buffer.cbBuffer;
  275. }
  276. else if (buffer.BufferType == SECBUFFER_EXTRA)
  277. {
  278. debug << "\tExtra data in buffer " << i << " (" << buffer.cbBuffer << " bytes)\n";
  279. enqueue(encRecvBuffer, static_cast<char*>(buffer.pvBuffer), buffer.cbBuffer);
  280. }
  281. else if (buffer.BufferType != SECBUFFER_EMPTY)
  282. debug << "\tBuffer of type " << buffer.BufferType << "\n";
  283. }
  284. if (actualDataStart)
  285. memmove(buffer, actualDataStart, size);
  286. break;
  287. }
  288. case SEC_E_INCOMPLETE_MESSAGE:
  289. {
  290. // Move all our current data to encRecvBuffer
  291. enqueue(encRecvBuffer, buffer, size);
  292. // Now try to read some more data from the socket
  293. size_t bufferSize = encRecvBuffer.size() + 8192;
  294. char *recvBuffer = new char[bufferSize];
  295. size_t recvd = socket.read(recvBuffer+encRecvBuffer.size(), 8192);
  296. debug << recvd << " bytes of extra data read from socket\n";
  297. if (recvd == 0 && !recurse)
  298. {
  299. debug << "Recursion prevented, bailing\n";
  300. return 0;
  301. }
  302. // Fill our buffer with the queued data and the newly received data
  303. size_t totalSize = encRecvBuffer.size() + recvd;
  304. dequeue(encRecvBuffer, recvBuffer, encRecvBuffer.size());
  305. debug << "Trying to decrypt with " << totalSize << " bytes of data\n";
  306. // Now try to decrypt that
  307. size_t decrypted = decrypt(recvBuffer, totalSize, false);
  308. debug << "\tObtained " << decrypted << " bytes of decrypted data\n";
  309. // Copy the first size bytes to the output buffer
  310. size = std::min(size, decrypted);
  311. memcpy(buffer, recvBuffer, size);
  312. // And write the remainder to our queued decrypted data...
  313. // Note: we prepend, since our recursive call may already have written
  314. // something and we can be sure decrypt wasn't called if the buffer was
  315. // non-empty in read
  316. enqueue_prepend(decRecvBuffer, recvBuffer+size, decrypted-size);
  317. debug << "\tStoring " << decrypted-size << " bytes of extra decrypted data\n";
  318. return size;
  319. }
  320. // TODO: More?
  321. default:
  322. size = 0;
  323. break;
  324. }
  325. debug << "\tDecrypted " << size << " bytes of data\n";
  326. return size;
  327. }
  328. size_t SChannelConnection::write(const char *buffer, size_t size)
  329. {
  330. static constexpr size_t bufferSize = 8192;
  331. assert(size <= bufferSize);
  332. SecPkgContext_StreamSizes Sizes;
  333. QueryContextAttributes(
  334. static_cast<CtxtHandle*>(context),
  335. SECPKG_ATTR_STREAM_SIZES,
  336. &Sizes);
  337. debug << "stream sizes:\n\theader: " << Sizes.cbHeader << "\n\tfooter: " << Sizes.cbTrailer << "\n";
  338. char *sendBuffer = new char[bufferSize + Sizes.cbHeader + Sizes.cbTrailer];
  339. memcpy(sendBuffer+Sizes.cbHeader, buffer, size);
  340. SecBuffer secBuffers[4];
  341. secBuffers[0].cbBuffer = Sizes.cbHeader;
  342. secBuffers[0].BufferType = SECBUFFER_STREAM_HEADER;
  343. secBuffers[0].pvBuffer = sendBuffer;
  344. secBuffers[1].cbBuffer = size;
  345. secBuffers[1].BufferType = SECBUFFER_DATA;
  346. secBuffers[1].pvBuffer = sendBuffer+Sizes.cbHeader;
  347. secBuffers[2].cbBuffer = Sizes.cbTrailer;
  348. secBuffers[2].pvBuffer = sendBuffer+Sizes.cbHeader+size;
  349. secBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
  350. secBuffers[3].cbBuffer = 0;
  351. secBuffers[3].BufferType = SECBUFFER_EMPTY;
  352. secBuffers[3].pvBuffer = nullptr;
  353. SecBufferDesc secBufferDesc;
  354. secBufferDesc.ulVersion = SECBUFFER_VERSION;
  355. secBufferDesc.cBuffers = 4;
  356. secBufferDesc.pBuffers = secBuffers;
  357. auto ret = EncryptMessage(static_cast<CtxtHandle*>(context), 0, &secBufferDesc, 0); // FIXME
  358. debug << "Send:\n\tHeader size: " << secBuffers[0].cbBuffer << "\n\t\ttype: " << secBuffers[0].BufferType << "\n\tData size: " << secBuffers[1].cbBuffer << "\n\t\ttype: " << secBuffers[1].BufferType << "\n\tFooter size: " << secBuffers[2].cbBuffer << "\n\t\ttype: " << secBuffers[2].BufferType << "\n";
  359. size_t sendSize = 0;
  360. for (size_t i = 0; i < 4; ++i)
  361. if (secBuffers[i].cbBuffer != bufferSize)
  362. sendSize += secBuffers[i].cbBuffer;
  363. debug << "\tReal length? " << sendSize << "\n";
  364. switch (ret)
  365. {
  366. case SEC_E_OK:
  367. socket.write(sendBuffer, sendSize);
  368. break;
  369. // TODO: More?
  370. default:
  371. size = 0;
  372. break;
  373. }
  374. delete[] sendBuffer;
  375. return size;
  376. }
  377. void SChannelConnection::destroyContext()
  378. {
  379. if (context)
  380. {
  381. DeleteSecurityContext(context);
  382. delete context;
  383. context = nullptr;
  384. }
  385. }
  386. void SChannelConnection::close()
  387. {
  388. destroyContext();
  389. socket.close();
  390. }
  391. bool SChannelConnection::valid()
  392. {
  393. return true;
  394. }
  395. #endif // HTTPS_BACKEND_SCHANNEL