SChannelConnection.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
#define SECURITY_WIN32
#include <windows.h>
#include <security.h>
#include <schnlsp.h>
#include <assert.h>

#include <algorithm>
#include <array>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "common/config.h"
#include "SChannelConnection.h"
  10. #ifndef SCH_USE_STRONG_CRYPTO
  11. # define SCH_USE_STRONG_CRYPTO 0x00400000
  12. #endif
  13. #ifndef SP_PROT_TLS1_1_CLIENT
  14. # define SP_PROT_TLS1_1_CLIENT 0x00000200
  15. #endif
  16. #ifndef SP_PROT_TLS1_2_CLIENT
  17. # define SP_PROT_TLS1_2_CLIENT 0x00000800
  18. #endif
// Diagnostic output sink used by the logging statements in this file.
// With DEBUG_SCHANNEL defined, `debug` is a reference to std::cout; otherwise
// it is an object whose operator<< accepts any value and discards it, so the
// same `debug << ...` statements compile in both configurations.
#ifdef DEBUG_SCHANNEL
#include <iostream>
std::ostream &debug = std::cout;
#else
struct Debug
{
	// Swallows the streamed value and returns the sink so calls can be chained.
	template<typename T>
	Debug &operator<<(const T&) { return *this; }
} debug;
#endif
  29. static void enqueue(std::vector<char> &buffer, char *data, size_t size)
  30. {
  31. size_t oldSize = buffer.size();
  32. buffer.resize(oldSize + size);
  33. memcpy(&buffer[oldSize], data, size);
  34. }
  35. static void enqueue_prepend(std::vector<char> &buffer, char *data, size_t size)
  36. {
  37. size_t oldSize = buffer.size();
  38. buffer.resize(oldSize + size);
  39. memmove(&buffer[size], &buffer[0], oldSize);
  40. memcpy(&buffer[0], data, size);
  41. }
  42. static size_t dequeue(std::vector<char> &buffer, char *data, size_t size)
  43. {
  44. size = std::min(size, buffer.size());
  45. size_t remaining = buffer.size() - size;
  46. memcpy(data, &buffer[0], size);
  47. memmove(&buffer[0], &buffer[size], remaining);
  48. buffer.resize(remaining);
  49. return size;
  50. }
  51. SChannelConnection::SChannelConnection()
  52. : context(nullptr)
  53. {
  54. }
// Releases the security context, if one is held, via destroyContext().
SChannelConnection::~SChannelConnection()
{
	destroyContext();
}
  59. SECURITY_STATUS InitializeSecurityContext(CredHandle *phCredential, std::unique_ptr<CtxtHandle>& phContext, const std::string& szTargetName, ULONG fContextReq, std::vector<char>& inputBuffer, std::vector<char>& outputBuffer, ULONG *pfContextAttr)
  60. {
  61. std::array<SecBuffer, 1> recvBuffers;
  62. recvBuffers[0].BufferType = SECBUFFER_TOKEN;
  63. recvBuffers[0].pvBuffer = outputBuffer.data();
  64. recvBuffers[0].cbBuffer = outputBuffer.size();
  65. std::array<SecBuffer, 2> sendBuffers;
  66. sendBuffers[0].BufferType = SECBUFFER_TOKEN;
  67. sendBuffers[0].pvBuffer = inputBuffer.data();
  68. sendBuffers[0].cbBuffer = inputBuffer.size();
  69. sendBuffers[1].BufferType = SECBUFFER_EMPTY;
  70. sendBuffers[1].pvBuffer = nullptr;
  71. sendBuffers[1].cbBuffer = 0;
  72. SecBufferDesc recvBufferDesc, sendBufferDesc;
  73. recvBufferDesc.ulVersion = sendBufferDesc.ulVersion = SECBUFFER_VERSION;
  74. recvBufferDesc.pBuffers = &recvBuffers[0];
  75. recvBufferDesc.cBuffers = recvBuffers.size();
  76. if (!inputBuffer.empty())
  77. {
  78. sendBufferDesc.pBuffers = &sendBuffers[0];
  79. sendBufferDesc.cBuffers = sendBuffers.size();
  80. }
  81. else
  82. {
  83. sendBufferDesc.pBuffers = nullptr;
  84. sendBufferDesc.cBuffers = 0;
  85. }
  86. CtxtHandle* phOldContext = nullptr;
  87. CtxtHandle* phNewContext = nullptr;
  88. if (!phContext)
  89. {
  90. phContext = std::make_unique<CtxtHandle>();
  91. phNewContext = phContext.get();
  92. }
  93. else
  94. {
  95. phOldContext = phContext.get();
  96. }
  97. auto ret = InitializeSecurityContext(phCredential, phOldContext, const_cast<char*>(szTargetName.c_str()), fContextReq, 0, 0, &sendBufferDesc, 0, phNewContext, &recvBufferDesc, pfContextAttr, nullptr);
  98. outputBuffer.resize(recvBuffers[0].cbBuffer);
  99. // Clear the input buffer, so the reader can append
  100. // If we have unprocessed data, leave it in the buffer
  101. size_t unprocessed = 0;
  102. if (sendBuffers[1].BufferType == SECBUFFER_EXTRA)
  103. unprocessed = sendBuffers[1].cbBuffer;
  104. if (unprocessed > 0)
  105. memmove(inputBuffer.data(), inputBuffer.data() + inputBuffer.size() - unprocessed, unprocessed);
  106. inputBuffer.resize(unprocessed);
  107. return ret;
  108. }
  109. bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
  110. {
  111. debug << "Trying to connect to " << hostname << ":" << port << "\n";
  112. if (!socket.connect(hostname, port))
  113. return false;
  114. debug << "Connected\n";
  115. SCHANNEL_CRED cred;
  116. memset(&cred, 0, sizeof(cred));
  117. cred.dwVersion = SCHANNEL_CRED_VERSION;
  118. cred.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT;
  119. cred.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_NO_DEFAULT_CREDS | SCH_USE_STRONG_CRYPTO | SCH_CRED_REVOCATION_CHECK_CHAIN;
  120. CredHandle credHandle;
  121. if (AcquireCredentialsHandle(nullptr, (char*) UNISP_NAME, SECPKG_CRED_OUTBOUND, nullptr, &cred, nullptr, nullptr, &credHandle, nullptr) != SEC_E_OK)
  122. {
  123. debug << "Failed to acquire handle\n";
  124. socket.close();
  125. return false;
  126. }
  127. debug << "Acquired handle\n";
  128. static constexpr size_t bufferSize = 8192;
  129. bool done = false, success = false, contextCreated = false;
  130. ULONG contextAttr;
  131. std::unique_ptr<CtxtHandle> context;
  132. std::vector<char> inputBuffer;
  133. std::vector<char> outputBuffer;
  134. do
  135. {
  136. outputBuffer.resize(bufferSize);
  137. bool recvData = false;
  138. bool sendData = false;
  139. auto ret = InitializeSecurityContext(&credHandle, context, hostname, ISC_REQ_STREAM, inputBuffer, outputBuffer, &contextAttr);
  140. switch (ret)
  141. {
  142. /*case SEC_I_COMPLETE_NEEDED:
  143. case SEC_I_COMPLETE_AND_CONTINUE:
  144. if (CompleteAuthToken(context.get(), &outputBuffer) != SEC_E_OK)
  145. done = true;
  146. else if (ret == SEC_I_COMPLETE_NEEDED)
  147. success = done = true;
  148. break;*/
  149. case SEC_I_CONTINUE_NEEDED:
  150. debug << "Initialize: continue needed\n";
  151. recvData = true;
  152. sendData = true;
  153. break;
  154. case SEC_E_INCOMPLETE_CREDENTIALS:
  155. debug << "Initialize failed: incomplete credentials\n";
  156. done = true;
  157. break;
  158. case SEC_E_INCOMPLETE_MESSAGE:
  159. debug << "Initialize: incomplete message\n";
  160. recvData = true;
  161. break;
  162. case SEC_E_OK:
  163. debug << "Initialize succeeded\n";
  164. success = done = true;
  165. sendData = true;
  166. break;
  167. default:
  168. debug << "Initialize done: " << outputBuffer.size() << " bytes of output and unknown status " << ret << "\n";
  169. done = true;
  170. success = false;
  171. break;
  172. }
  173. if (!done)
  174. contextCreated = true;
  175. if (sendData && !outputBuffer.empty())
  176. {
  177. socket.write(outputBuffer.data(), outputBuffer.size());
  178. debug << "Sent " << outputBuffer.size() << " bytes of data\n";
  179. }
  180. if (recvData)
  181. {
  182. size_t unprocessed = inputBuffer.size();
  183. inputBuffer.resize(unprocessed + bufferSize);
  184. size_t actual = socket.read(inputBuffer.data() + unprocessed, bufferSize);
  185. inputBuffer.resize(actual + unprocessed);
  186. debug << "Received " << actual << " bytes of data\n";
  187. if (unprocessed > 0)
  188. debug << " had " << unprocessed << " bytes of remaining, unprocessed data\n";
  189. if (actual + unprocessed == 0)
  190. {
  191. debug << "No data to submit, break\n";
  192. break;
  193. }
  194. }
  195. } while (!done);
  196. debug << "Done!\n";
  197. if (success)
  198. {
  199. SecPkgContext_Flags resultFlags;
  200. QueryContextAttributes(context.get(), SECPKG_ATTR_FLAGS, &resultFlags);
  201. if (resultFlags.Flags & ISC_REQ_CONFIDENTIALITY == 0)
  202. {
  203. debug << "Resulting context is not encrypted, marking as failed\n";
  204. success = false;
  205. }
  206. if (resultFlags.Flags & ISC_REQ_INTEGRITY == 0)
  207. {
  208. debug << "Resulting context is not signed, marking as failed\n";
  209. success = false;
  210. }
  211. }
  212. if (success)
  213. this->context = context.release();
  214. else if (contextCreated)
  215. DeleteSecurityContext(context.get());
  216. return success;
  217. }
  218. size_t SChannelConnection::read(char *buffer, size_t size)
  219. {
  220. if (decRecvBuffer.size() > 0)
  221. {
  222. size = dequeue(decRecvBuffer, buffer, size);
  223. debug << "Read " << size << " bytes of previously decoded data\n";
  224. return size;
  225. }
  226. else if (encRecvBuffer.size() > 0)
  227. {
  228. size = dequeue(encRecvBuffer, buffer, size);
  229. debug << "Read " << size << " bytes of extra data\n";
  230. }
  231. else
  232. {
  233. size = socket.read(buffer, size);
  234. debug << "Received " << size << " bytes of data\n";
  235. }
  236. return decrypt(buffer, size);
  237. }
  238. size_t SChannelConnection::decrypt(char *buffer, size_t size, bool recurse)
  239. {
  240. if (size == 0)
  241. return 0;
  242. SecBuffer secBuffers[4];
  243. secBuffers[0].cbBuffer = size;
  244. secBuffers[0].BufferType = SECBUFFER_DATA;
  245. secBuffers[0].pvBuffer = buffer;
  246. for (size_t i = 1; i < 4; ++i)
  247. {
  248. secBuffers[i].BufferType = SECBUFFER_EMPTY;
  249. secBuffers[i].pvBuffer = nullptr;
  250. secBuffers[i].cbBuffer = 0;
  251. }
  252. SecBufferDesc secBufferDesc;
  253. secBufferDesc.ulVersion = SECBUFFER_VERSION;
  254. secBufferDesc.cBuffers = 4;
  255. secBufferDesc.pBuffers = &secBuffers[0];
  256. auto ret = DecryptMessage(static_cast<CtxtHandle*>(context), &secBufferDesc, 0, nullptr); // FIXME
  257. debug << "DecryptMessage returns: " << ret << "\n";
  258. switch (ret)
  259. {
  260. case SEC_E_OK:
  261. {
  262. void *actualDataStart = buffer;
  263. for (size_t i = 0; i < 4; ++i)
  264. {
  265. auto &buffer = secBuffers[i];
  266. if (buffer.BufferType == SECBUFFER_DATA)
  267. {
  268. actualDataStart = buffer.pvBuffer;
  269. size = buffer.cbBuffer;
  270. }
  271. else if (buffer.BufferType == SECBUFFER_EXTRA)
  272. {
  273. debug << "\tExtra data in buffer " << i << " (" << buffer.cbBuffer << " bytes)\n";
  274. enqueue(encRecvBuffer, static_cast<char*>(buffer.pvBuffer), buffer.cbBuffer);
  275. }
  276. else if (buffer.BufferType != SECBUFFER_EMPTY)
  277. debug << "\tBuffer of type " << buffer.BufferType << "\n";
  278. }
  279. if (actualDataStart)
  280. memmove(buffer, actualDataStart, size);
  281. break;
  282. }
  283. case SEC_E_INCOMPLETE_MESSAGE:
  284. {
  285. // Move all our current data to encRecvBuffer
  286. enqueue(encRecvBuffer, buffer, size);
  287. // Now try to read some more data from the socket
  288. size_t bufferSize = encRecvBuffer.size() + 8192;
  289. char *recvBuffer = new char[bufferSize];
  290. size_t recvd = socket.read(recvBuffer+encRecvBuffer.size(), 8192);
  291. debug << recvd << " bytes of extra data read from socket\n";
  292. if (recvd == 0 && !recurse)
  293. {
  294. debug << "Recursion prevented, bailing\n";
  295. return 0;
  296. }
  297. // Fill our buffer with the queued data and the newly received data
  298. size_t totalSize = encRecvBuffer.size() + recvd;
  299. dequeue(encRecvBuffer, recvBuffer, encRecvBuffer.size());
  300. debug << "Trying to decrypt with " << totalSize << " bytes of data\n";
  301. // Now try to decrypt that
  302. size_t decrypted = decrypt(recvBuffer, totalSize, false);
  303. debug << "\tObtained " << decrypted << " bytes of decrypted data\n";
  304. // Copy the first size bytes to the output buffer
  305. size = std::min(size, decrypted);
  306. memcpy(buffer, recvBuffer, size);
  307. // And write the remainder to our queued decrypted data...
  308. // Note: we prepend, since our recursive call may already have written
  309. // something and we can be sure decrypt wasn't called if the buffer was
  310. // non-empty in read
  311. enqueue_prepend(decRecvBuffer, recvBuffer+size, decrypted-size);
  312. debug << "\tStoring " << decrypted-size << " bytes of extra decrypted data\n";
  313. return size;
  314. }
  315. // TODO: More?
  316. default:
  317. size = 0;
  318. break;
  319. }
  320. debug << "\tDecrypted " << size << " bytes of data\n";
  321. return size;
  322. }
  323. size_t SChannelConnection::write(const char *buffer, size_t size)
  324. {
  325. static constexpr size_t bufferSize = 8192;
  326. assert(size <= bufferSize);
  327. SecPkgContext_StreamSizes Sizes;
  328. QueryContextAttributes(
  329. static_cast<CtxtHandle*>(context),
  330. SECPKG_ATTR_STREAM_SIZES,
  331. &Sizes);
  332. debug << "stream sizes:\n\theader: " << Sizes.cbHeader << "\n\tfooter: " << Sizes.cbTrailer << "\n";
  333. char *sendBuffer = new char[bufferSize + Sizes.cbHeader + Sizes.cbTrailer];
  334. memcpy(sendBuffer+Sizes.cbHeader, buffer, size);
  335. SecBuffer secBuffers[4];
  336. secBuffers[0].cbBuffer = Sizes.cbHeader;
  337. secBuffers[0].BufferType = SECBUFFER_STREAM_HEADER;
  338. secBuffers[0].pvBuffer = sendBuffer;
  339. secBuffers[1].cbBuffer = size;
  340. secBuffers[1].BufferType = SECBUFFER_DATA;
  341. secBuffers[1].pvBuffer = sendBuffer+Sizes.cbHeader;
  342. secBuffers[2].cbBuffer = Sizes.cbTrailer;
  343. secBuffers[2].pvBuffer = sendBuffer+Sizes.cbHeader+size;
  344. secBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
  345. secBuffers[3].cbBuffer = 0;
  346. secBuffers[3].BufferType = SECBUFFER_EMPTY;
  347. secBuffers[3].pvBuffer = nullptr;
  348. SecBufferDesc secBufferDesc;
  349. secBufferDesc.ulVersion = SECBUFFER_VERSION;
  350. secBufferDesc.cBuffers = 4;
  351. secBufferDesc.pBuffers = secBuffers;
  352. auto ret = EncryptMessage(static_cast<CtxtHandle*>(context), 0, &secBufferDesc, 0); // FIXME
  353. debug << "Send:\n\tHeader size: " << secBuffers[0].cbBuffer << "\n\t\ttype: " << secBuffers[0].BufferType << "\n\tData size: " << secBuffers[1].cbBuffer << "\n\t\ttype: " << secBuffers[1].BufferType << "\n\tFooter size: " << secBuffers[2].cbBuffer << "\n\t\ttype: " << secBuffers[2].BufferType << "\n";
  354. size_t sendSize = 0;
  355. for (size_t i = 0; i < 4; ++i)
  356. if (secBuffers[i].cbBuffer != bufferSize)
  357. sendSize += secBuffers[i].cbBuffer;
  358. debug << "\tReal length? " << sendSize << "\n";
  359. switch (ret)
  360. {
  361. case SEC_E_OK:
  362. socket.write(sendBuffer, sendSize);
  363. break;
  364. // TODO: More?
  365. default:
  366. size = 0;
  367. break;
  368. }
  369. delete[] sendBuffer;
  370. return size;
  371. }
  372. void SChannelConnection::destroyContext()
  373. {
  374. if (context)
  375. {
  376. DeleteSecurityContext(context);
  377. delete context;
  378. context = nullptr;
  379. }
  380. }
// Shuts the connection down: releases the security context first, then closes
// the underlying socket.
void SChannelConnection::close()
{
	destroyContext();
	socket.close();
}
// Always reports true; no condition is checked here.
bool SChannelConnection::valid()
{
	return true;
}