SChannelConnection.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. #define SECURITY_WIN32
  2. #include <windows.h>
  3. #include <security.h>
  4. #include <schnlsp.h>
  5. #include <assert.h>
  6. #include <algorithm>
  7. #include <memory>
  8. #include "common/config.h"
  9. #include "SChannelConnection.h"
  10. #ifndef SCH_USE_STRONG_CRYPTO
  11. # define SCH_USE_STRONG_CRYPTO 0x00400000
  12. #endif
  13. #ifndef SP_PROT_TLS1_1_CLIENT
  14. # define SP_PROT_TLS1_1_CLIENT 0x00000200
  15. #endif
  16. #ifndef SP_PROT_TLS1_2_CLIENT
  17. # define SP_PROT_TLS1_2_CLIENT 0x00000800
  18. #endif
  19. #ifdef DEBUG_SCHANNEL
  20. #include <iostream>
  21. std::ostream &debug = std::cout;
  22. #else
  23. struct Debug
  24. {
  25. template<typename T>
  26. Debug &operator<<(const T&) { return *this; }
  27. } debug;
  28. #endif
  29. static void enqueue(std::vector<char> &buffer, char *data, size_t size)
  30. {
  31. size_t oldSize = buffer.size();
  32. buffer.resize(oldSize + size);
  33. memcpy(&buffer[oldSize], data, size);
  34. }
  35. static void enqueue_prepend(std::vector<char> &buffer, char *data, size_t size)
  36. {
  37. size_t oldSize = buffer.size();
  38. buffer.resize(oldSize + size);
  39. memmove(&buffer[size], &buffer[0], oldSize);
  40. memcpy(&buffer[0], data, size);
  41. }
  42. static size_t dequeue(std::vector<char> &buffer, char *data, size_t size)
  43. {
  44. size = std::min(size, buffer.size());
  45. size_t remaining = buffer.size() - size;
  46. memcpy(data, &buffer[0], size);
  47. memmove(&buffer[0], &buffer[size], remaining);
  48. buffer.resize(remaining);
  49. return size;
  50. }
  51. SChannelConnection::SChannelConnection()
  52. : context(nullptr)
  53. {
  54. }
  55. SChannelConnection::~SChannelConnection()
  56. {
  57. // TODO?
  58. if (CtxtHandle *context = static_cast<CtxtHandle*>(this->context))
  59. {
  60. DeleteSecurityContext(context);
  61. delete context;
  62. }
  63. }
  64. SECURITY_STATUS InitializeSecurityContext(CredHandle *phCredential, std::unique_ptr<CtxtHandle>& phContext, const std::string& szTargetName, ULONG fContextReq, std::vector<char>& inputBuffer, std::vector<char>& outputBuffer, ULONG *pfContextAttr)
  65. {
  66. std::array<SecBuffer, 1> recvBuffers;
  67. recvBuffers[0].BufferType = SECBUFFER_TOKEN;
  68. recvBuffers[0].pvBuffer = outputBuffer.data();
  69. recvBuffers[0].cbBuffer = outputBuffer.size();
  70. std::array<SecBuffer, 2> sendBuffers;
  71. sendBuffers[0].BufferType = SECBUFFER_TOKEN;
  72. sendBuffers[0].pvBuffer = inputBuffer.data();
  73. sendBuffers[0].cbBuffer = inputBuffer.size();
  74. sendBuffers[1].BufferType = SECBUFFER_EMPTY;
  75. sendBuffers[1].pvBuffer = nullptr;
  76. sendBuffers[1].cbBuffer = 0;
  77. SecBufferDesc recvBufferDesc, sendBufferDesc;
  78. recvBufferDesc.ulVersion = sendBufferDesc.ulVersion = SECBUFFER_VERSION;
  79. recvBufferDesc.pBuffers = &recvBuffers[0];
  80. recvBufferDesc.cBuffers = recvBuffers.size();
  81. if (!inputBuffer.empty())
  82. {
  83. sendBufferDesc.pBuffers = &sendBuffers[0];
  84. sendBufferDesc.cBuffers = sendBuffers.size();
  85. }
  86. else
  87. {
  88. sendBufferDesc.pBuffers = nullptr;
  89. sendBufferDesc.cBuffers = 0;
  90. }
  91. CtxtHandle* phOldContext = nullptr;
  92. CtxtHandle* phNewContext = nullptr;
  93. if (!phContext)
  94. {
  95. phContext = std::make_unique<CtxtHandle>();
  96. phNewContext = phContext.get();
  97. }
  98. else
  99. {
  100. phOldContext = phContext.get();
  101. }
  102. auto ret = InitializeSecurityContext(phCredential, phOldContext, const_cast<char*>(szTargetName.c_str()), fContextReq, 0, 0, &sendBufferDesc, 0, phNewContext, &recvBufferDesc, pfContextAttr, nullptr);
  103. outputBuffer.resize(recvBuffers[0].cbBuffer);
  104. // Clear the input buffer, so the reader can append
  105. // If we have unprocessed data, leave it in the buffer
  106. size_t unprocessed = 0;
  107. if (sendBuffers[1].BufferType == SECBUFFER_EXTRA)
  108. unprocessed = sendBuffers[1].cbBuffer;
  109. if (unprocessed > 0)
  110. memmove(inputBuffer.data(), inputBuffer.data() + inputBuffer.size() - unprocessed, unprocessed);
  111. inputBuffer.resize(unprocessed);
  112. return ret;
  113. }
  114. bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
  115. {
  116. debug << "Trying to connect to " << hostname << ":" << port << "\n";
  117. if (!socket.connect(hostname, port))
  118. return false;
  119. debug << "Connected\n";
  120. SCHANNEL_CRED cred;
  121. memset(&cred, 0, sizeof(cred));
  122. cred.dwVersion = SCHANNEL_CRED_VERSION;
  123. cred.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT;
  124. cred.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_NO_DEFAULT_CREDS | SCH_USE_STRONG_CRYPTO | SCH_CRED_REVOCATION_CHECK_CHAIN;
  125. CredHandle credHandle;
  126. if (AcquireCredentialsHandle(nullptr, (char*) UNISP_NAME, SECPKG_CRED_OUTBOUND, nullptr, &cred, nullptr, nullptr, &credHandle, nullptr) != SEC_E_OK)
  127. {
  128. debug << "Failed to acquire handle\n";
  129. socket.close();
  130. return false;
  131. }
  132. debug << "Acquired handle\n";
  133. static constexpr size_t bufferSize = 8192;
  134. bool done = false, success = false, contextCreated = false;
  135. ULONG contextAttr;
  136. std::unique_ptr<CtxtHandle> context;
  137. std::vector<char> inputBuffer;
  138. std::vector<char> outputBuffer;
  139. do
  140. {
  141. outputBuffer.resize(bufferSize);
  142. bool recvData = false;
  143. bool sendData = false;
  144. auto ret = InitializeSecurityContext(&credHandle, context, hostname, ISC_REQ_STREAM, inputBuffer, outputBuffer, &contextAttr);
  145. switch (ret)
  146. {
  147. /*case SEC_I_COMPLETE_NEEDED:
  148. case SEC_I_COMPLETE_AND_CONTINUE:
  149. if (CompleteAuthToken(context.get(), &outputBuffer) != SEC_E_OK)
  150. done = true;
  151. else if (ret == SEC_I_COMPLETE_NEEDED)
  152. success = done = true;
  153. break;*/
  154. case SEC_I_CONTINUE_NEEDED:
  155. debug << "Initialize: continue needed\n";
  156. recvData = true;
  157. sendData = true;
  158. break;
  159. case SEC_E_INCOMPLETE_CREDENTIALS:
  160. debug << "Initialize failed: incomplete credentials\n";
  161. done = true;
  162. break;
  163. case SEC_E_INCOMPLETE_MESSAGE:
  164. debug << "Initialize: incomplete message\n";
  165. recvData = true;
  166. break;
  167. case SEC_E_OK:
  168. debug << "Initialize succeeded\n";
  169. success = done = true;
  170. sendData = true;
  171. break;
  172. default:
  173. debug << "Initialize done: " << outputBuffer.size() << " bytes of output and status " << ret << "\n";
  174. done = true;
  175. // TODO: error
  176. break;
  177. }
  178. if (!done)
  179. contextCreated = true;
  180. if (sendData && !outputBuffer.empty())
  181. {
  182. socket.write(outputBuffer.data(), outputBuffer.size());
  183. debug << "Sent " << outputBuffer.size() << " bytes of data\n";
  184. }
  185. if (recvData)
  186. {
  187. size_t unprocessed = inputBuffer.size();
  188. inputBuffer.resize(unprocessed + bufferSize);
  189. size_t actual = socket.read(inputBuffer.data() + unprocessed, bufferSize);
  190. inputBuffer.resize(actual + unprocessed);
  191. debug << "Received " << actual << " bytes of data\n";
  192. if (unprocessed > 0)
  193. debug << " had " << unprocessed << " bytes of remaining, unprocessed data\n";
  194. if (actual + unprocessed == 0)
  195. {
  196. debug << "No data to submit, break\n";
  197. break;
  198. }
  199. }
  200. // TODO: A bunch of frees?
  201. } while (!done);
  202. debug << "Done!\n";
  203. // TODO: Check resulting context attributes
  204. if (success)
  205. this->context = static_cast<void*>(context.release());
  206. else if (contextCreated)
  207. DeleteSecurityContext(context.get());
  208. return success;
  209. }
  210. size_t SChannelConnection::read(char *buffer, size_t size)
  211. {
  212. if (decRecvBuffer.size() > 0)
  213. {
  214. size = dequeue(decRecvBuffer, buffer, size);
  215. debug << "Read " << size << " bytes of previously decoded data\n";
  216. return size;
  217. }
  218. else if (encRecvBuffer.size() > 0)
  219. {
  220. size = dequeue(encRecvBuffer, buffer, size);
  221. debug << "Read " << size << " bytes of extra data\n";
  222. }
  223. else
  224. {
  225. size = socket.read(buffer, size);
  226. debug << "Received " << size << " bytes of data\n";
  227. }
  228. return decrypt(buffer, size);
  229. }
  230. size_t SChannelConnection::decrypt(char *buffer, size_t size, bool recurse)
  231. {
  232. if (size == 0)
  233. return 0;
  234. SecBuffer secBuffers[4];
  235. secBuffers[0].cbBuffer = size;
  236. secBuffers[0].BufferType = SECBUFFER_DATA;
  237. secBuffers[0].pvBuffer = buffer;
  238. for (size_t i = 1; i < 4; ++i)
  239. {
  240. secBuffers[i].BufferType = SECBUFFER_EMPTY;
  241. secBuffers[i].pvBuffer = nullptr;
  242. secBuffers[i].cbBuffer = 0;
  243. }
  244. SecBufferDesc secBufferDesc;
  245. secBufferDesc.ulVersion = SECBUFFER_VERSION;
  246. secBufferDesc.cBuffers = 4;
  247. secBufferDesc.pBuffers = &secBuffers[0];
  248. auto ret = DecryptMessage(static_cast<CtxtHandle*>(context), &secBufferDesc, 0, nullptr); // FIXME
  249. debug << "DecryptMessage returns: " << ret << "\n";
  250. switch (ret)
  251. {
  252. case SEC_E_OK:
  253. {
  254. void *actualDataStart = buffer;
  255. for (size_t i = 0; i < 4; ++i)
  256. {
  257. auto &buffer = secBuffers[i];
  258. if (buffer.BufferType == SECBUFFER_DATA)
  259. {
  260. actualDataStart = buffer.pvBuffer;
  261. size = buffer.cbBuffer;
  262. }
  263. else if (buffer.BufferType == SECBUFFER_EXTRA)
  264. {
  265. debug << "\tExtra data in buffer " << i << " (" << buffer.cbBuffer << " bytes)\n";
  266. enqueue(encRecvBuffer, static_cast<char*>(buffer.pvBuffer), buffer.cbBuffer);
  267. }
  268. else if (buffer.BufferType != SECBUFFER_EMPTY)
  269. debug << "\tBuffer of type " << buffer.BufferType << "\n";
  270. }
  271. if (actualDataStart)
  272. memmove(buffer, actualDataStart, size);
  273. break;
  274. }
  275. case SEC_E_INCOMPLETE_MESSAGE:
  276. {
  277. // Move all our current data to encRecvBuffer
  278. enqueue(encRecvBuffer, buffer, size);
  279. // Now try to read some more data from the socket
  280. size_t bufferSize = encRecvBuffer.size() + 8192;
  281. char *recvBuffer = new char[bufferSize];
  282. size_t recvd = socket.read(recvBuffer+encRecvBuffer.size(), 8192);
  283. debug << recvd << " bytes of extra data read from socket\n";
  284. if (recvd == 0 && !recurse)
  285. {
  286. debug << "Recursion prevented, bailing\n";
  287. return 0;
  288. }
  289. // Fill our buffer with the queued data and the newly received data
  290. size_t totalSize = encRecvBuffer.size() + recvd;
  291. dequeue(encRecvBuffer, recvBuffer, encRecvBuffer.size());
  292. debug << "Trying to decrypt with " << totalSize << " bytes of data\n";
  293. // Now try to decrypt that
  294. size_t decrypted = decrypt(recvBuffer, totalSize, false);
  295. debug << "\tObtained " << decrypted << " bytes of decrypted data\n";
  296. // Copy the first size bytes to the output buffer
  297. size = std::min(size, decrypted);
  298. memcpy(buffer, recvBuffer, size);
  299. // And write the remainder to our queued decrypted data...
  300. // Note: we prepend, since our recursive call may already have written
  301. // something and we can be sure decrypt wasn't called if the buffer was
  302. // non-empty in read
  303. enqueue_prepend(decRecvBuffer, recvBuffer+size, decrypted-size);
  304. debug << "\tStoring " << decrypted-size << " bytes of extra decrypted data\n";
  305. return size;
  306. }
  307. // TODO: More?
  308. default:
  309. size = 0;
  310. break;
  311. }
  312. debug << "\tDecrypted " << size << " bytes of data\n";
  313. return size;
  314. }
  315. size_t SChannelConnection::write(const char *buffer, size_t size)
  316. {
  317. static constexpr size_t bufferSize = 8192;
  318. assert(size <= bufferSize);
  319. SecPkgContext_StreamSizes Sizes;
  320. QueryContextAttributes(
  321. static_cast<CtxtHandle*>(context),
  322. SECPKG_ATTR_STREAM_SIZES,
  323. &Sizes);
  324. debug << "stream sizes:\n\theader: " << Sizes.cbHeader << "\n\tfooter: " << Sizes.cbTrailer << "\n";
  325. char *sendBuffer = new char[bufferSize + Sizes.cbHeader + Sizes.cbTrailer];
  326. memcpy(sendBuffer+Sizes.cbHeader, buffer, size);
  327. SecBuffer secBuffers[4];
  328. secBuffers[0].cbBuffer = Sizes.cbHeader;
  329. secBuffers[0].BufferType = SECBUFFER_STREAM_HEADER;
  330. secBuffers[0].pvBuffer = sendBuffer;
  331. secBuffers[1].cbBuffer = size;
  332. secBuffers[1].BufferType = SECBUFFER_DATA;
  333. secBuffers[1].pvBuffer = sendBuffer+Sizes.cbHeader;
  334. secBuffers[2].cbBuffer = Sizes.cbTrailer;
  335. secBuffers[2].pvBuffer = sendBuffer+Sizes.cbHeader+size;
  336. secBuffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
  337. secBuffers[3].cbBuffer = 0;
  338. secBuffers[3].BufferType = SECBUFFER_EMPTY;
  339. secBuffers[3].pvBuffer = nullptr;
  340. SecBufferDesc secBufferDesc;
  341. secBufferDesc.ulVersion = SECBUFFER_VERSION;
  342. secBufferDesc.cBuffers = 4;
  343. secBufferDesc.pBuffers = secBuffers;
  344. auto ret = EncryptMessage(static_cast<CtxtHandle*>(context), 0, &secBufferDesc, 0); // FIXME
  345. debug << "Send:\n\tHeader size: " << secBuffers[0].cbBuffer << "\n\t\ttype: " << secBuffers[0].BufferType << "\n\tData size: " << secBuffers[1].cbBuffer << "\n\t\ttype: " << secBuffers[1].BufferType << "\n\tFooter size: " << secBuffers[2].cbBuffer << "\n\t\ttype: " << secBuffers[2].BufferType << "\n";
  346. size_t sendSize = 0;
  347. for (size_t i = 0; i < 4; ++i)
  348. if (secBuffers[i].cbBuffer != bufferSize)
  349. sendSize += secBuffers[i].cbBuffer;
  350. debug << "\tReal length? " << sendSize << "\n";
  351. switch (ret)
  352. {
  353. case SEC_E_OK:
  354. socket.write(sendBuffer, sendSize);
  355. break;
  356. // TODO: More?
  357. default:
  358. size = 0;
  359. break;
  360. }
  361. delete[] sendBuffer;
  362. return size;
  363. }
  364. void SChannelConnection::close()
  365. {
  366. // TODO
  367. }
  368. bool SChannelConnection::valid()
  369. {
  370. return true;
  371. }