|
@@ -4,7 +4,9 @@
|
|
|
#include <schnlsp.h>
|
|
|
#include <assert.h>
|
|
|
#include <algorithm>
|
|
|
+#include <memory>
|
|
|
|
|
|
+#include "common/config.h"
|
|
|
#include "SChannelConnection.h"
|
|
|
|
|
|
#ifndef SCH_USE_STRONG_CRYPTO
|
|
@@ -70,6 +72,57 @@ SChannelConnection::~SChannelConnection()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+SECURITY_STATUS InitializeSecurityContext(CredHandle *phCredential, std::unique_ptr<CtxtHandle>& phContext, const std::string& szTargetName, ULONG fContextReq, const std::vector<char>& inputBuffer, std::vector<char>& outputBuffer, ULONG *pfContextAttr)
|
|
|
+{
|
|
|
+ std::array<SecBuffer, 2> recvBuffers;
|
|
|
+ recvBuffers[0].BufferType = SECBUFFER_TOKEN;
|
|
|
+ recvBuffers[0].pvBuffer = outputBuffer.data();
|
|
|
+ recvBuffers[0].cbBuffer = outputBuffer.size();
|
|
|
+
|
|
|
+ recvBuffers[1].BufferType = SECBUFFER_EMPTY;
|
|
|
+ recvBuffers[1].pvBuffer = nullptr;
|
|
|
+ recvBuffers[1].cbBuffer = 0;
|
|
|
+
|
|
|
+ SecBuffer sendBuffer;
|
|
|
+ sendBuffer.BufferType = SECBUFFER_TOKEN;
|
|
|
+ sendBuffer.pvBuffer = const_cast<char*>(inputBuffer.data());
|
|
|
+ sendBuffer.cbBuffer = inputBuffer.size();
|
|
|
+
|
|
|
+ SecBufferDesc recvBufferDesc, sendBufferDesc;
|
|
|
+ recvBufferDesc.ulVersion = sendBufferDesc.ulVersion = SECBUFFER_VERSION;
|
|
|
+ recvBufferDesc.pBuffers = &recvBuffers[0];
|
|
|
+ recvBufferDesc.cBuffers = recvBuffers.size();
|
|
|
+
|
|
|
+ if (inputBuffer.size() > 0)
|
|
|
+ {
|
|
|
+ sendBufferDesc.pBuffers = &sendBuffer;
|
|
|
+ sendBufferDesc.cBuffers = 1;
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ sendBufferDesc.pBuffers = nullptr;
|
|
|
+ sendBufferDesc.cBuffers = 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ CtxtHandle* phOldContext = nullptr;
|
|
|
+ CtxtHandle* phNewContext = nullptr;
|
|
|
+ if (!phContext)
|
|
|
+ {
|
|
|
+ phContext = std::make_unique<CtxtHandle>();
|
|
|
+ phNewContext = phContext.get();
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ phOldContext = phContext.get();
|
|
|
+ }
|
|
|
+
|
|
|
+ auto ret = InitializeSecurityContext(phCredential, phOldContext, const_cast<char*>(szTargetName.c_str()), fContextReq, 0, 0, &sendBufferDesc, 0, phNewContext, &recvBufferDesc, pfContextAttr, nullptr);
|
|
|
+
|
|
|
+ outputBuffer.resize(recvBuffers[0].cbBuffer);
|
|
|
+
|
|
|
+ return ret;
|
|
|
+}
|
|
|
+
|
|
|
bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
|
|
|
{
|
|
|
debug << "Trying to connect to " << hostname << ":" << port << "\n";
|
|
@@ -93,42 +146,30 @@ bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
|
|
|
}
|
|
|
debug << "Acquired handle\n";
|
|
|
|
|
|
- CtxtHandle *context = new CtxtHandle;
|
|
|
- CtxtHandle *inHandle = nullptr, *outHandle = context;
|
|
|
-
|
|
|
- SecBufferDesc inputBuffer, outputBuffer;
|
|
|
- inputBuffer.ulVersion = outputBuffer.ulVersion = SECBUFFER_VERSION;
|
|
|
- inputBuffer.cBuffers = outputBuffer.cBuffers = 0;
|
|
|
- inputBuffer.pBuffers = outputBuffer.pBuffers = nullptr;
|
|
|
-
|
|
|
- ULONG contextAttr;
|
|
|
|
|
|
static constexpr size_t bufferSize = 8192;
|
|
|
bool done = false, success = false, contextCreated = false;
|
|
|
- char *recvBuffer = nullptr;
|
|
|
- char *sendBuffer = new char[2*bufferSize];
|
|
|
-
|
|
|
- SecBuffer recvSecBuffer, sendSecBuffer;
|
|
|
- recvSecBuffer.BufferType = sendSecBuffer.BufferType = SECBUFFER_TOKEN;
|
|
|
- sendSecBuffer.cbBuffer = bufferSize;
|
|
|
- sendSecBuffer.pvBuffer = sendBuffer;
|
|
|
|
|
|
- outputBuffer.cBuffers = 1;
|
|
|
- outputBuffer.pBuffers = &sendSecBuffer;
|
|
|
+ ULONG contextAttr;
|
|
|
+ std::unique_ptr<CtxtHandle> context;
|
|
|
+ std::vector<char> inputBuffer;
|
|
|
+ std::vector<char> outputBuffer;
|
|
|
|
|
|
do
|
|
|
{
|
|
|
+ outputBuffer.resize(bufferSize);
|
|
|
+
|
|
|
bool recvData = false;
|
|
|
- auto ret = InitializeSecurityContext(&credHandle, inHandle, (char*) hostname.c_str(), ISC_REQ_STREAM, 0, 0, &inputBuffer, 0, outHandle, &outputBuffer, &contextAttr, nullptr);
|
|
|
+ auto ret = InitializeSecurityContext(&credHandle, context, hostname, ISC_REQ_STREAM, inputBuffer, outputBuffer, &contextAttr);
|
|
|
switch (ret)
|
|
|
{
|
|
|
- case SEC_I_COMPLETE_NEEDED:
|
|
|
+ /*case SEC_I_COMPLETE_NEEDED:
|
|
|
case SEC_I_COMPLETE_AND_CONTINUE:
|
|
|
- if (CompleteAuthToken(outHandle, &outputBuffer) != SEC_E_OK)
|
|
|
+ if (CompleteAuthToken(context.get(), &outputBuffer) != SEC_E_OK)
|
|
|
done = true;
|
|
|
else if (ret == SEC_I_COMPLETE_NEEDED)
|
|
|
success = done = true;
|
|
|
- break;
|
|
|
+ break;*/
|
|
|
case SEC_I_CONTINUE_NEEDED:
|
|
|
recvData = true;
|
|
|
break;
|
|
@@ -150,64 +191,33 @@ bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
|
|
|
if (!done)
|
|
|
contextCreated = true;
|
|
|
|
|
|
- inHandle = context;
|
|
|
- outHandle = nullptr;
|
|
|
+ debug << "Initialize done, with " << outputBuffer.size() << " bytes of output and status " << ret << "\n";
|
|
|
|
|
|
- debug << "Initialize done, with " << outputBuffer.cBuffers << " output buffers and status " << ret << "\n";
|
|
|
- for (unsigned int i = 0; i < outputBuffer.cBuffers && !success; ++i)
|
|
|
- {
|
|
|
- auto &buffer = outputBuffer.pBuffers[i];
|
|
|
- debug << "\tBuffer of size: " << buffer.cbBuffer << "\n";
|
|
|
- if (buffer.cbBuffer > 0 && buffer.BufferType == SECBUFFER_TOKEN)
|
|
|
- {
|
|
|
- socket.write((const char*) buffer.pvBuffer, buffer.cbBuffer);
|
|
|
- }
|
|
|
- else
|
|
|
- debug << "Got buffer with type " << buffer.BufferType << "\n";
|
|
|
-
|
|
|
- if (buffer.pvBuffer == sendBuffer)
|
|
|
- {
|
|
|
- memset(sendBuffer, 0, bufferSize);
|
|
|
- buffer.cbBuffer = bufferSize;
|
|
|
- }
|
|
|
- //FreeContextBuffer(&buffer);
|
|
|
- }
|
|
|
+ if (outputBuffer.size() > 0)
|
|
|
+ socket.write(outputBuffer.data(), outputBuffer.size());
|
|
|
|
|
|
if (recvData)
|
|
|
{
|
|
|
- debug << "Receiving data\n";
|
|
|
- if (!recvBuffer)
|
|
|
- recvBuffer = new char[bufferSize];
|
|
|
-
|
|
|
- recvSecBuffer.cbBuffer = socket.read(recvBuffer, bufferSize);
|
|
|
- recvSecBuffer.pvBuffer = recvBuffer;
|
|
|
+ inputBuffer.resize(bufferSize);
|
|
|
+ size_t actual = socket.read(inputBuffer.data(), bufferSize);
|
|
|
+ inputBuffer.resize(actual);
|
|
|
|
|
|
- inputBuffer.cBuffers = 1;
|
|
|
- inputBuffer.pBuffers = &recvSecBuffer;
|
|
|
- }
|
|
|
- else
|
|
|
- {
|
|
|
- inputBuffer.cBuffers = 0;
|
|
|
- inputBuffer.pBuffers = nullptr;
|
|
|
+ debug << "Received " << actual << " bytes of data\n";
|
|
|
+ if (actual == 0)
|
|
|
+ {
|
|
|
+ debug << "No data received, break\n";
|
|
|
+ break;
|
|
|
+ }
|
|
|
}
|
|
|
-
|
|
|
// TODO: A bunch of frees?
|
|
|
} while (!done);
|
|
|
|
|
|
- delete[] sendBuffer;
|
|
|
- delete[] recvBuffer;
|
|
|
-
|
|
|
debug << "Done!\n";
|
|
|
// TODO: Check resulting context attributes
|
|
|
if (success)
|
|
|
- {
|
|
|
- this->context = static_cast<void*>(context);
|
|
|
- }
|
|
|
+ this->context = static_cast<void*>(context.release());
|
|
|
else if (contextCreated)
|
|
|
- {
|
|
|
- DeleteSecurityContext(context);
|
|
|
- delete context;
|
|
|
- }
|
|
|
+ DeleteSecurityContext(context.get());
|
|
|
|
|
|
return success;
|
|
|
}
|