Преглед на файлове

Implement an easier wrapper around InitializeSecurityContext and use it

Since calling it is a lot easier now, it seems to actually finish
negotiating a connection.. sometimes
Bart van Strien преди 6 години
родител
ревизия
ff531385cc
променени са 1 файла, в които са добавени 77 реда и са изтрити 67 реда
  1. 77 67
      src/windows/SChannelConnection.cpp

+ 77 - 67
src/windows/SChannelConnection.cpp

@@ -4,7 +4,9 @@
 #include <schnlsp.h>
 #include <assert.h>
 #include <algorithm>
+#include <memory>
 
+#include "common/config.h"
 #include "SChannelConnection.h"
 
 #ifndef SCH_USE_STRONG_CRYPTO
@@ -70,6 +72,57 @@ SChannelConnection::~SChannelConnection()
 	}
 }
 
+SECURITY_STATUS InitializeSecurityContext(CredHandle *phCredential, std::unique_ptr<CtxtHandle>& phContext, const std::string& szTargetName, ULONG fContextReq, const std::vector<char>& inputBuffer, std::vector<char>& outputBuffer, ULONG *pfContextAttr)
+{
+	std::array<SecBuffer, 2> recvBuffers;
+	recvBuffers[0].BufferType = SECBUFFER_TOKEN;
+	recvBuffers[0].pvBuffer = outputBuffer.data();
+	recvBuffers[0].cbBuffer = outputBuffer.size();
+
+	recvBuffers[1].BufferType = SECBUFFER_EMPTY;
+	recvBuffers[1].pvBuffer = nullptr;
+	recvBuffers[1].cbBuffer = 0;
+
+	SecBuffer sendBuffer;
+	sendBuffer.BufferType = SECBUFFER_TOKEN;
+	sendBuffer.pvBuffer = const_cast<char*>(inputBuffer.data());
+	sendBuffer.cbBuffer = inputBuffer.size();
+
+	SecBufferDesc recvBufferDesc, sendBufferDesc;
+	recvBufferDesc.ulVersion = sendBufferDesc.ulVersion = SECBUFFER_VERSION;
+	recvBufferDesc.pBuffers = &recvBuffers[0];
+	recvBufferDesc.cBuffers = recvBuffers.size();
+
+	if (inputBuffer.size() > 0)
+	{
+		sendBufferDesc.pBuffers = &sendBuffer;
+		sendBufferDesc.cBuffers = 1;
+	}
+	else
+	{
+		sendBufferDesc.pBuffers = nullptr;
+		sendBufferDesc.cBuffers = 0;
+	}
+
+	CtxtHandle* phOldContext = nullptr;
+	CtxtHandle* phNewContext = nullptr;
+	if (!phContext)
+	{
+		phContext = std::make_unique<CtxtHandle>();
+		phNewContext = phContext.get();
+	}
+	else
+	{
+		phOldContext = phContext.get();
+	}
+
+	auto ret = InitializeSecurityContext(phCredential, phOldContext, const_cast<char*>(szTargetName.c_str()), fContextReq, 0, 0, &sendBufferDesc, 0, phNewContext, &recvBufferDesc, pfContextAttr, nullptr);
+
+	outputBuffer.resize(recvBuffers[0].cbBuffer);
+
+	return ret;
+}
+
 bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
 {
 	debug << "Trying to connect to " << hostname << ":" << port << "\n";
@@ -93,42 +146,30 @@ bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
 	}
 	debug << "Acquired handle\n";
 
-	CtxtHandle *context = new CtxtHandle;
-	CtxtHandle *inHandle = nullptr, *outHandle = context;
-
-	SecBufferDesc inputBuffer, outputBuffer;
-	inputBuffer.ulVersion = outputBuffer.ulVersion = SECBUFFER_VERSION;
-	inputBuffer.cBuffers = outputBuffer.cBuffers = 0;
-	inputBuffer.pBuffers = outputBuffer.pBuffers = nullptr;
-
-	ULONG contextAttr;
 
 	static constexpr size_t bufferSize = 8192;
 	bool done = false, success = false, contextCreated = false;
-	char *recvBuffer = nullptr;
-	char *sendBuffer = new char[2*bufferSize];
-
-	SecBuffer recvSecBuffer, sendSecBuffer;
-	recvSecBuffer.BufferType = sendSecBuffer.BufferType = SECBUFFER_TOKEN;
-	sendSecBuffer.cbBuffer = bufferSize;
-	sendSecBuffer.pvBuffer = sendBuffer;
 
-	outputBuffer.cBuffers = 1;
-	outputBuffer.pBuffers = &sendSecBuffer;
+	ULONG contextAttr;
+	std::unique_ptr<CtxtHandle> context;
+	std::vector<char> inputBuffer;
+	std::vector<char> outputBuffer;
 
 	do
 	{
+		outputBuffer.resize(bufferSize);
+
 		bool recvData = false;
-		auto ret = InitializeSecurityContext(&credHandle, inHandle, (char*) hostname.c_str(), ISC_REQ_STREAM, 0, 0, &inputBuffer, 0, outHandle, &outputBuffer, &contextAttr, nullptr);
+		auto ret = InitializeSecurityContext(&credHandle, context, hostname, ISC_REQ_STREAM, inputBuffer, outputBuffer, &contextAttr);
 		switch (ret)
 		{
-		case SEC_I_COMPLETE_NEEDED:
+		/*case SEC_I_COMPLETE_NEEDED:
 		case SEC_I_COMPLETE_AND_CONTINUE:
-			if (CompleteAuthToken(outHandle, &outputBuffer) != SEC_E_OK)
+			if (CompleteAuthToken(context.get(), &outputBuffer) != SEC_E_OK)
 				done = true;
 			else if (ret == SEC_I_COMPLETE_NEEDED)
 				success = done = true;
-			break;
+			break;*/
 		case SEC_I_CONTINUE_NEEDED:
 			recvData = true;
 			break;
@@ -150,64 +191,33 @@ bool SChannelConnection::connect(const std::string &hostname, uint16_t port)
 		if (!done)
 			contextCreated = true;
 
-		inHandle = context;
-		outHandle = nullptr;
+		debug << "Initialize done, with " << outputBuffer.size() << " bytes of output and status " << ret << "\n";
 
-		debug << "Initialize done, with " << outputBuffer.cBuffers << " output buffers and status " << ret << "\n";
-		for (unsigned int i = 0; i < outputBuffer.cBuffers && !success; ++i)
-		{
-			auto &buffer = outputBuffer.pBuffers[i];
-			debug << "\tBuffer of size: " << buffer.cbBuffer << "\n";
-			if (buffer.cbBuffer > 0 && buffer.BufferType == SECBUFFER_TOKEN)
-			{
-				socket.write((const char*) buffer.pvBuffer, buffer.cbBuffer);
-			}
-			else
-				debug << "Got buffer with type " << buffer.BufferType << "\n";
-
-			if (buffer.pvBuffer == sendBuffer)
-			{
-				memset(sendBuffer, 0, bufferSize);
-				buffer.cbBuffer = bufferSize;
-			}
-			//FreeContextBuffer(&buffer);
-		}
+		if (outputBuffer.size() > 0)
+			socket.write(outputBuffer.data(), outputBuffer.size());
 
 		if (recvData)
 		{
-			debug << "Receiving data\n";
-			if (!recvBuffer)
-				recvBuffer = new char[bufferSize];
-
-			recvSecBuffer.cbBuffer = socket.read(recvBuffer, bufferSize);
-			recvSecBuffer.pvBuffer = recvBuffer;
+			inputBuffer.resize(bufferSize);
+			size_t actual = socket.read(inputBuffer.data(), bufferSize);
+			inputBuffer.resize(actual);
 
-			inputBuffer.cBuffers = 1;
-			inputBuffer.pBuffers = &recvSecBuffer;
-		}
-		else
-		{
-			inputBuffer.cBuffers = 0;
-			inputBuffer.pBuffers = nullptr;
+			debug << "Received " << actual << " bytes of data\n";
+			if (actual == 0)
+			{
+				debug << "No data received, break\n";
+				break;
+			}
 		}
-
 		// TODO: A bunch of frees?
 	} while (!done);
 
-	delete[] sendBuffer;
-	delete[] recvBuffer;
-
 	debug << "Done!\n";
 	// TODO: Check resulting context attributes
 	if (success)
-	{
-		this->context = static_cast<void*>(context);
-	}
+		this->context = static_cast<void*>(context.release());
 	else if (contextCreated)
-	{
-		DeleteSecurityContext(context);
-		delete context;
-	}
+		DeleteSecurityContext(context.get());
 
 	return success;
 }