SChannelConnection.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. #define SECURITY_WIN32
  2. #include <windows.h>
  3. #include <security.h>
  4. #include <schnlsp.h>
  5. #include <assert.h>
  6. #include <algorithm>
  7. #include <memory>
  8. #include "common/config.h"
  9. #include "SChannelConnection.h"
  10. #ifndef SCH_USE_STRONG_CRYPTO
  11. # define SCH_USE_STRONG_CRYPTO 0x00400000
  12. #endif
  13. #ifndef SP_PROT_TLS1_1_CLIENT
  14. # define SP_PROT_TLS1_1_CLIENT 0x00000200
  15. #endif
  16. #ifndef SP_PROT_TLS1_2_CLIENT
  17. # define SP_PROT_TLS1_2_CLIENT 0x00000800
  18. #endif
  #ifdef DEBUG_SCHANNEL
  #include <iostream>
  // Debug builds: route all diagnostics to standard output.
  std::ostream &debug = std::cout;
  #else
  // Regular builds: a sink that accepts anything streamed into it and
  // discards it, so the `debug << ...` statements compile to nothing useful.
  struct Debug
  {
      template<typename Value>
      Debug &operator<<(const Value &)
      {
          return *this;
      }
  } debug;
  #endif
  /**
   * Appends bytes to the end of a queue buffer.
   *
   * @param buffer Queue to append to.
   * @param data   First of `size` bytes to copy; may be null when `size` is 0.
   * @param size   Number of bytes to append.
   */
  static void enqueue(std::vector<char> &buffer, const char *data, size_t size)
  {
      // Nothing to do; also avoids touching `data` or indexing an empty vector.
      if (size == 0)
          return;

      buffer.insert(buffer.end(), data, data + size);
  }
  /**
   * Inserts bytes at the front of a queue buffer, ahead of whatever is
   * already queued.
   *
   * @param buffer Queue to prepend to.
   * @param data   First of `size` bytes to copy; may be null when `size` is 0.
   * @param size   Number of bytes to prepend.
   */
  static void enqueue_prepend(std::vector<char> &buffer, const char *data, size_t size)
  {
      // Nothing to do; also avoids touching `data` or indexing an empty vector.
      if (size == 0)
          return;

      // insert() shifts the existing contents itself, so no element is ever
      // addressed past the end (the old `&buffer[size]` was out of range
      // whenever the queue started out empty).
      buffer.insert(buffer.begin(), data, data + size);
  }
  /**
   * Removes up to `size` bytes from the front of a queue buffer.
   *
   * @param buffer Queue to take bytes from; the taken bytes are erased.
   * @param data   Destination with room for at least `size` bytes.
   * @param size   Maximum number of bytes to take.
   * @return Number of bytes actually copied to `data` (0 if the queue is empty).
   */
  static size_t dequeue(std::vector<char> &buffer, char *data, size_t size)
  {
      const size_t count = std::min<size_t>(size, buffer.size());
      // Empty queue or zero-length request: never index into the vector.
      if (count == 0)
          return 0;

      memcpy(data, buffer.data(), count);
      // erase() moves the remainder down; draining the whole queue no longer
      // addresses the element one past the end.
      buffer.erase(buffer.begin(), buffer.begin() + count);
      return count;
  }
  51. SChannelConnection::SChannelConnection()
  52. : context(nullptr)
  53. {
  54. }
  55. SChannelConnection::~SChannelConnection()
  56. {
  57. // TODO?
  58. if (CtxtHandle *context = static_cast<CtxtHandle*>(this->context))
  59. {
  60. DeleteSecurityContext(context);
  61. delete context;
  62. }
  63. }
  64. SECURITY_STATUS InitializeSecurityContext(CredHandle *phCredential, std::unique_ptr<CtxtHandle>& phContext, const std::string& szTargetName, ULONG fContextReq, std::vector<char>& inputBuffer, std::vector<char>& outputBuffer, ULONG *pfContextAttr)
  65. {
  66. std::array<SecBuffer, 1> recvBuffers;
  67. recvBuffers[0].BufferType = SECBUFFER_TOKEN;
  68. recvBuffers[0].pvBuffer = outputBuffer.data();
  69. recvBuffers[0].cbBuffer = outputBuffer.size();
  70. std::array<SecBuffer, 2> sendBuffers;
  71. sendBuffers[0].BufferType = SECBUFFER_TOKEN;
  72. sendBuffers[0].pvBuffer = inputBuffer.data();
  73. sendBuffers[0].cbBuffer = inputBuffer.size();
  74. sendBuffers[1].BufferType = SECBUFFER_EMPTY;
  75. sendBuffers[1].pvBuffer = nullptr;
  76. sendBuffers[1].cbBuffer = 0;
  77. SecBufferDesc recvBufferDesc, sendBufferDesc;
  78. recvBufferDesc.ulVersion = sendBufferDesc.ulVersion = SECBUFFER_VERSION;
  79. recvBufferDesc.pBuffers = &recvBuffers[0];
  80. recvBufferDesc.cBuffers = recvBuffers.size();
  81. if (!inputBuffer.empty())
  82. {
  83. sendBufferDesc.pBuffers = &sendBuffers[0];
  84. sendBufferDesc.cBuffers = sendBuffers.size();
  85. }
  86. else
  87. {
  88. sendBufferDesc.pBuffers = nullptr;
  89. sendBufferDesc.cBuffers = 0;
  90. }
  91. CtxtHandle* phOldContext = nullptr;
  92. CtxtHandle* phNewContext = nullptr;
  93. if (!phContext)
  94. {
  95. phContext = std::make_unique<CtxtHandle>();
  96. phNewContext = phContext.get();
  97. }
  98. else
  99. {
  100. phOldContext = phContext.get();
  101. }
  102. auto ret = InitializeSecurityContext(phCredential, phOldContext, const_cast<char*>(szTargetName.c_str()), fContextReq, 0, 0, &sendBufferDesc, 0, phNewContext, &recvBufferDesc, pfContextAttr, nullptr);
  103. outputBuffer.resize(recvBuffers[0].cbBuffer);
  104. // Clear the input buffer, so the reader can append
  105. // If we have unprocessed data, leave it in the buffer
  106. size_t unprocessed = 0;
  107. if (sendBuffers[1].BufferType == SECBUFFER_EXTRA)
  108. unprocessed = sendBuffers[1].cbBuffer;
  109. if (unprocessed > 0)
  110. memmove(inputBuffer.data(), inputBuffer.data() + inputBuffer.size() - unprocessed, unprocessed);
  111. inputBuffer.resize(unprocessed);
  112. return ret;
  113. }
  114. bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
  115. {
  116. debug << "Trying to connect to " << hostname << ":" << port << "\n";
  117. if (!socket.connect(hostname, port))
  118. return false;
  119. debug << "Connected\n";
  120. SCHANNEL_CRED cred;
  121. memset(&cred, 0, sizeof(cred));
  122. cred.dwVersion = SCHANNEL_CRED_VERSION;
  123. cred.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT;
  124. cred.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_NO_DEFAULT_CREDS | SCH_USE_STRONG_CRYPTO | SCH_CRED_REVOCATION_CHECK_CHAIN;
  125. CredHandle credHandle;
  126. if (AcquireCredentialsHandle(nullptr, (char*) UNISP_NAME, SECPKG_CRED_OUTBOUND, nullptr, &cred, nullptr, nullptr, &credHandle, nullptr) != SEC_E_OK)
  127. {
  128. debug << "Failed to acquire handle\n";
  129. socket.close();
  130. return false;
  131. }
  132. debug << "Acquired handle\n";
  133. static constexpr size_t bufferSize = 8192;
  134. bool done = false, success = false, contextCreated = false;
  135. ULONG contextAttr;
  136. std::unique_ptr<CtxtHandle> context;
  137. std::vector<char> inputBuffer;
  138. std::vector<char> outputBuffer;
  139. do
  140. {
  141. outputBuffer.resize(bufferSize);
  142. bool recvData = false;
  143. bool sendData = false;
  144. auto ret = InitializeSecurityContext(&credHandle, context, hostname, ISC_REQ_STREAM, inputBuffer, outputBuffer, &contextAttr);
  145. switch (ret)
  146. {
  147. /*case SEC_I_COMPLETE_NEEDED:
  148. case SEC_I_COMPLETE_AND_CONTINUE:
  149. if (CompleteAuthToken(context.get(), &outputBuffer) != SEC_E_OK)
  150. done = true;
  151. else if (ret == SEC_I_COMPLETE_NEEDED)
  152. success = done = true;
  153. break;*/
  154. case SEC_I_CONTINUE_NEEDED:
  155. debug << "Initialize: continue needed\n";
  156. recvData = true;
  157. sendData = true;
  158. break;
  159. case SEC_E_INCOMPLETE_CREDENTIALS:
  160. debug << "Initialize failed: incomplete credentials\n";
  161. done = true;
  162. break;
  163. case SEC_E_INCOMPLETE_MESSAGE:
  164. debug << "Initialize: incomplete message\n";
  165. recvData = true;
  166. break;
  167. case SEC_E_OK:
  168. debug << "Initialize succeeded\n";
  169. success = done = true;
  170. sendData = true;
  171. break;
  172. default:
  173. debug << "Initialize done: " << outputBuffer.size() << " bytes of output and status " << ret << "\n";
  174. done = true;
  175. // TODO: error
  176. break;
  177. }
  178. if (!done)
  179. contextCreated = true;
  180. if (sendData && !outputBuffer.empty())
  181. {
  182. socket.write(outputBuffer.data(), outputBuffer.size());
  183. debug << "Sent " << outputBuffer.size() << " bytes of data\n";
  184. }
  185. if (recvData)
  186. {
  187. size_t unprocessed = inputBuffer.size();
  188. inputBuffer.resize(unprocessed + bufferSize);
  189. size_t actual = socket.read(inputBuffer.data() + unprocessed, bufferSize);
  190. inputBuffer.resize(actual + unprocessed);
  191. debug << "Received " << actual << " bytes of data\n";
  192. if (unprocessed > 0)
  193. debug << " had " << unprocessed << " bytes of remaining, unprocessed data\n";
  194. if (actual + unprocessed == 0)
  195. {
  196. debug << "No data to submit, break\n";
  197. break;
  198. }
  199. }
  200. // TODO: A bunch of frees?
  201. } while (!done);
  202. debug << "Done!\n";
  203. if (success)
  204. {
  205. SecPkgContext_Flags resultFlags;
  206. QueryContextAttributes(context.get(), SECPKG_ATTR_FLAGS, &resultFlags);
  207. if (resultFlags.Flags & ISC_REQ_CONFIDENTIALITY == 0)
  208. {
  209. debug << "Resulting context is not encrypted, marking as failed\n";
  210. success = false;
  211. }
  212. if (resultFlags.Flags & ISC_REQ_INTEGRITY == 0)
  213. {
  214. debug << "Resulting context is not signed, marking as failed\n";
  215. success = false;
  216. }
  217. }
  218. if (success)
  219. this->context = static_cast<void*>(context.release());
  220. else if (contextCreated)
  221. DeleteSecurityContext(context.get());
  222. return success;
  223. }
  224. size_t SChannelConnection::read(char *buffer, size_t size)
  225. {
  226. if (decRecvBuffer.size() > 0)
  227. {
  228. size = dequeue(decRecvBuffer, buffer, size);
  229. debug << "Read " << size << " bytes of previously decoded data\n";
  230. return size;
  231. }
  232. else if (encRecvBuffer.size() > 0)
  233. {
  234. size = dequeue(encRecvBuffer, buffer, size);
  235. debug << "Read " << size << " bytes of extra data\n";
  236. }
  237. else
  238. {
  239. size = socket.read(buffer, size);
  240. debug << "Received " << size << " bytes of data\n";
  241. }
  242. return decrypt(buffer, size);
  243. }
  244. size_t SChannelConnection::decrypt(char *buffer, size_t size, bool recurse)
  245. {
  246. if (size == 0)
  247. return 0;
  248. SecBuffer secBuffers[4];
  249. secBuffers[0].cbBuffer = size;
  250. secBuffers[0].BufferType = SECBUFFER_DATA;
  251. secBuffers[0].pvBuffer = buffer;
  252. for (size_t i = 1; i < 4; ++i)
  253. {
  254. secBuffers[i].BufferType = SECBUFFER_EMPTY;
  255. secBuffers[i].pvBuffer = nullptr;
  256. secBuffers[i].cbBuffer = 0;
  257. }
  258. SecBufferDesc secBufferDesc;
  259. secBufferDesc.ulVersion = SECBUFFER_VERSION;
  260. secBufferDesc.cBuffers = 4;
  261. secBufferDesc.pBuffers = &secBuffers[0];
  262. auto ret = DecryptMessage(static_cast<CtxtHandle*>(context), &secBufferDesc, 0, nullptr); // FIXME
  263. debug << "DecryptMessage returns: " << ret << "\n";
  264. switch (ret)
  265. {
  266. case SEC_E_OK:
  267. {
  268. void *actualDataStart = buffer;
  269. for (size_t i = 0; i < 4; ++i)
  270. {
  271. auto &buffer = secBuffers[i];
  272. if (buffer.BufferType == SECBUFFER_DATA)
  273. {
  274. actualDataStart = buffer.pvBuffer;
  275. size = buffer.cbBuffer;
  276. }
  277. else if (buffer.BufferType == SECBUFFER_EXTRA)
  278. {
  279. debug << "\tExtra data in buffer " << i << " (" << buffer.cbBuffer << " bytes)\n";
  280. enqueue(encRecvBuffer, static_cast<char*>(buffer.pvBuffer), buffer.cbBuffer);
  281. }
  282. else if (buffer.BufferType != SECBUFFER_EMPTY)
  283. debug << "\tBuffer of type " << buffer.BufferType << "\n";
  284. }
  285. if (actualDataStart)
  286. memmove(buffer, actualDataStart, size);
  287. break;
  288. }
  289. case SEC_E_INCOMPLETE_MESSAGE:
  290. {
  291. // Move all our current data to encRecvBuffer
  292. enqueue(encRecvBuffer, buffer, size);
  293. // Now try to read some more data from the socket
  294. size_t bufferSize = encRecvBuffer.size() + 8192;
  295. char *recvBuffer = new char[bufferSize];
  296. size_t recvd = socket.read(recvBuffer+encRecvBuffer.size(), 8192);
  297. debug << recvd << " bytes of extra data read from socket\n";
  298. if (recvd == 0 && !recurse)
  299. {
  300. debug << "Recursion prevented, bailing\n";
  301. return 0;
  302. }
  303. // Fill our buffer with the queued data and the newly received data
  304. size_t totalSize = encRecvBuffer.size() + recvd;
  305. dequeue(encRecvBuffer, recvBuffer, encRecvBuffer.size());
  306. debug << "Trying to decrypt with " << totalSize << " bytes of data\n";
  307. // Now try to decrypt that
  308. size_t decrypted = decrypt(recvBuffer, totalSize, false);
  309. debug << "\tObtained " << decrypted << " bytes of decrypted data\n";
  310. // Copy the first size bytes to the output buffer
  311. size = std::min(size, decrypted);
  312. memcpy(buffer, recvBuffer, size);
  313. // And write the remainder to our queued decrypted data...
  314. // Note: we prepend, since our recursive call may already have written
  315. // something and we can be sure decrypt wasn't called if the buffer was
  316. // non-empty in read
  317. enqueue_prepend(decRecvBuffer, recvBuffer+size, decrypted-size);
  318. debug << "\tStoring " << decrypted-size << " bytes of extra decrypted data\n";
  319. return size;
  320. }
  321. // TODO: More?
  322. default:
  323. size = 0;
  324. break;
  325. }
  326. debug << "\tDecrypted " << size << " bytes of data\n";
  327. return size;
  328. }
  329. size_t SChannelConnection::write(const char *buffer, size_t size)
  330. {
  331. static constexpr size_t bufferSize = 8192;
  332. assert(size <= bufferSize);
  333. SecPkgContext_StreamSizes Sizes;
  334. QueryContextAttributes(
  335. static_cast<CtxtHandle*>(context),
  336. SECPKG_ATTR_STREAM_SIZES,
  337. &Sizes);
  338. debug << "stream sizes:\n\theader: " << Sizes.cbHeader << "\n\tfooter: " << Sizes.cbTrailer << "\n";
  339. char *sendBuffer = new char[bufferSize + Sizes.cbHeader + Sizes.cbTrailer];
  340. memcpy(sendBuffer+Sizes.cbHeader, buffer, size);
  341. SecBuffer secBuffers[4];
  342. secBuffers[0].cbBuffer = Sizes.cbHeader;
  343. secBuffers[0].BufferType = SECBUFFER_STREAM_HEADER;
  344. secBuffers[0].pvBuffer = sendBuffer;
  345. secBuffers[1].cbBuffer = size;
  346. secBuffers[1].BufferType = SECBUFFER_DATA;
  347. secBuffers[1].pvBuffer = sendBuffer+Sizes.cbHeader;
  348. secBuffers[2].cbBuffer = Sizes.cbTrailer;
  349. secBuffers[2].pvBuffer = sendBuffer+Sizes.cbHeader+size;
  350. secBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
  351. secBuffers[3].cbBuffer = 0;
  352. secBuffers[3].BufferType = SECBUFFER_EMPTY;
  353. secBuffers[3].pvBuffer = nullptr;
  354. SecBufferDesc secBufferDesc;
  355. secBufferDesc.ulVersion = SECBUFFER_VERSION;
  356. secBufferDesc.cBuffers = 4;
  357. secBufferDesc.pBuffers = secBuffers;
  358. auto ret = EncryptMessage(static_cast<CtxtHandle*>(context), 0, &secBufferDesc, 0); // FIXME
  359. debug << "Send:\n\tHeader size: " << secBuffers[0].cbBuffer << "\n\t\ttype: " << secBuffers[0].BufferType << "\n\tData size: " << secBuffers[1].cbBuffer << "\n\t\ttype: " << secBuffers[1].BufferType << "\n\tFooter size: " << secBuffers[2].cbBuffer << "\n\t\ttype: " << secBuffers[2].BufferType << "\n";
  360. size_t sendSize = 0;
  361. for (size_t i = 0; i < 4; ++i)
  362. if (secBuffers[i].cbBuffer != bufferSize)
  363. sendSize += secBuffers[i].cbBuffer;
  364. debug << "\tReal length? " << sendSize << "\n";
  365. switch (ret)
  366. {
  367. case SEC_E_OK:
  368. socket.write(sendBuffer, sendSize);
  369. break;
  370. // TODO: More?
  371. default:
  372. size = 0;
  373. break;
  374. }
  375. delete[] sendBuffer;
  376. return size;
  377. }
  378. void SChannelConnection::close()
  379. {
  380. // TODO
  381. }
  382. bool SChannelConnection::valid()
  383. {
  384. return true;
  385. }