Browse Source

Update lua-https to latest commit (love2d/lua-https@6dbce69)

Sasha Szpakowski 2 years ago
parent
commit
ddf88c64bd

+ 10 - 1
CMakeLists.txt

@@ -1619,6 +1619,8 @@ set(LOVE_SRC_3P_LUAHTTPS_LUA
 set(LOVE_SRC_3P_LUAHTTPS_WINDOWS
 	src/libraries/luahttps/src/windows/SChannelConnection.cpp
 	src/libraries/luahttps/src/windows/SChannelConnection.h
+	src/libraries/luahttps/src/windows/WinINetClient.cpp
+	src/libraries/luahttps/src/windows/WinINetClient.h
 )
 
 # These are platform-dependent but have ifdef guards to make sure they only
@@ -1638,10 +1640,17 @@ endif()
 set(LOVE_LINK_L3P_LUAHTTPS)
 if(MSVC)
 	set(LOVE_LINK_L3P_LUAHTTPS
-		${LOVE_LINK_L3P_LUASOCKET_LIBLUASOCKET}
+		${LOVE_LINK_L3P_LUAHTTPS}
 		ws2_32
 		secur32
 	)
+
+	if(NOT CMAKE_SYSTEM_NAME STREQUAL "WindowsStore")
+		set(LOVE_LINK_L3P_LUAHTTPS
+			${LOVE_LINK_L3P_LUAHTTPS}
+			wininet
+		)
+	endif()
 endif()
 
 add_library(love_3p_luahttps ${LOVE_SRC_3P_LUAHTTPS})

+ 0 - 3
src/libraries/luahttps/src/common/HTTPRequest.cpp

@@ -50,9 +50,6 @@ HTTPSClient::Reply HTTPRequest::request(const HTTPSClient::Request &req)
 
 		request << "Host: " << info.hostname << "\r\n";
 
-		if (hasData && req.headers.count("Content-Type") == 0)
-			request << "Content-Type: application/x-www-form-urlencoded\r\n";
-
 		if (hasData)
 			request << "Content-Length: " << req.postdata.size() << "\r\n";
 

+ 9 - 9
src/libraries/luahttps/src/common/HTTPRequest.h

@@ -8,14 +8,6 @@
 class HTTPRequest
 {
 public:
-	typedef std::function<Connection *()> ConnectionFactory;
-	HTTPRequest(ConnectionFactory factory);
-
-	HTTPSClient::Reply request(const HTTPSClient::Request &req);
-
-private:
-	ConnectionFactory factory;
-
 	struct DissectedURL
 	{
 		bool valid;
@@ -25,6 +17,14 @@ private:
 		std::string query;
 		// TODO: Auth?
 	};
+	typedef std::function<Connection *()> ConnectionFactory;
+
+	HTTPRequest(ConnectionFactory factory);
+
+	HTTPSClient::Reply request(const HTTPSClient::Request &req);
+
+	static DissectedURL parseUrl(const std::string &url);
 
-	DissectedURL parseUrl(const std::string &url);
+private:
+	ConnectionFactory factory;
 };

+ 10 - 0
src/libraries/luahttps/src/common/HTTPS.cpp

@@ -19,6 +19,9 @@
 #ifdef HTTPS_BACKEND_ANDROID
 #	include "../android/AndroidClient.h"
 #endif
+#ifdef HTTPS_BACKEND_WININET
+#	include "../windows/WinINetClient.h"
+#endif
 
 #ifdef HTTPS_BACKEND_CURL
 	static CurlClient curlclient;
@@ -35,6 +38,9 @@
 #ifdef HTTPS_BACKEND_ANDROID
 	static AndroidClient androidclient;
 #endif
+#ifdef HTTPS_BACKEND_WININET
+	static WinINetClient wininetclient;
+#endif
 
 static HTTPSClient *clients[] = {
 #ifdef HTTPS_BACKEND_CURL
@@ -42,6 +48,10 @@ static HTTPSClient *clients[] = {
 #endif
 #ifdef HTTPS_BACKEND_OPENSSL
 	&opensslclient,
+#endif
+	// WinINet must be above SChannel
+#ifdef HTTPS_BACKEND_WININET
+	&wininetclient,
 #endif
 #ifdef HTTPS_BACKEND_SCHANNEL
 	&schannelclient,

+ 2 - 2
src/libraries/luahttps/src/common/HTTPSClient.cpp

@@ -30,8 +30,8 @@ bool HTTPSClient::ci_string_less::operator()(const std::string &lhs, const std::
 }
 
 HTTPSClient::Request::Request(const std::string &url)
-	: url(url)
-	, method("")
+: url(url)
+, method("GET")
 {
 }
 

+ 5 - 0
src/libraries/luahttps/src/common/config.h

@@ -5,6 +5,11 @@
 #elif defined(WIN32) || defined(_WIN32)
 	#define HTTPS_BACKEND_SCHANNEL
 	#define HTTPS_USE_WINSOCK
+	#include <winapifamily.h>
+	#if !defined(WINAPI_FAMILY) || (WINAPI_FAMILY == WINAPI_FAMILY_DESKTOP_APP)
+		// WinINet is only supported on desktop.
+		#define HTTPS_BACKEND_WININET
+	#endif
 #elif defined(__ANDROID__)
 	#define HTTPS_BACKEND_ANDROID
 #elif defined(__APPLE__)

+ 117 - 31
src/libraries/luahttps/src/generic/CurlClient.cpp

@@ -1,49 +1,129 @@
+#ifdef _WIN32
+#define NOMINMAX
+#define WIN32_LEAN_AND_MEAN
+#endif
+
 #include "CurlClient.h"
 
 #ifdef HTTPS_BACKEND_CURL
 
-#include <dlfcn.h>
+#include <algorithm>
 #include <stdexcept>
 #include <sstream>
 #include <vector>
 
+// Dynamic library loader
+#ifdef _WIN32
+#include <windows.h>
+#else
+#include <dlfcn.h>
+#endif
+
+typedef struct StringReader
+{
+	const std::string *str;
+	size_t pos;
+} StringReader;
+
+template <class T>
+static inline bool loadSymbol(T &var, void *handle, const char *name)
+{
+#ifdef _WIN32
+	var = (T) GetProcAddress((HMODULE) handle, name);
+#else
+	var = (T) dlsym(handle, name);
+#endif
+	return var != nullptr;
+}
+
 CurlClient::Curl::Curl()
+: handle(nullptr)
+, loaded(false)
+, global_cleanup(nullptr)
+, easy_init(nullptr)
+, easy_cleanup(nullptr)
+, easy_setopt(nullptr)
+, easy_perform(nullptr)
+, easy_getinfo(nullptr)
+, slist_append(nullptr)
+, slist_free_all(nullptr)
 {
-	void *handle = dlopen("libcurl.so", RTLD_LAZY);
+#ifdef _WIN32
+	handle = (void *) LoadLibraryA("libcurl.dll");
+#else
+	handle = dlopen("libcurl.so.4", RTLD_LAZY);
+#endif
 	if (!handle)
-	{
-		loaded = false;
 		return;
-	}
 
-	void (*global_init)() = (void(*)()) dlsym(handle, "curl_global_init");
-	easy_init = (CURL*(*)()) dlsym(handle, "curl_easy_init");
-	easy_cleanup = (void(*)(CURL*)) dlsym(handle, "curl_easy_cleanup");
-	easy_setopt = (CURLcode(*)(CURL*,CURLoption,...)) dlsym(handle, "curl_easy_setopt");
-	easy_perform = (CURLcode(*)(CURL*)) dlsym(handle, "curl_easy_perform");
-	easy_getinfo = (CURLcode(*)(CURL*,CURLINFO,...)) dlsym(handle, "curl_easy_getinfo");
-	slist_append = (curl_slist*(*)(curl_slist*,const char*)) dlsym(handle, "curl_slist_append");
-	slist_free_all = (void(*)(curl_slist*)) dlsym(handle, "curl_slist_free_all");
+	// Load symbols
+	decltype(&curl_global_init) global_init = nullptr;
+	if (!loadSymbol(global_init, handle, "curl_global_init"))
+		return;
+	if (!loadSymbol(global_cleanup, handle, "curl_global_cleanup"))
+		return;
+	if (!loadSymbol(easy_init, handle, "curl_easy_init"))
+		return;
+	if (!loadSymbol(easy_cleanup, handle, "curl_easy_cleanup"))
+		return;
+	if (!loadSymbol(easy_setopt, handle, "curl_easy_setopt"))
+		return;
+	if (!loadSymbol(easy_perform, handle, "curl_easy_perform"))
+		return;
+	if (!loadSymbol(easy_getinfo, handle, "curl_easy_getinfo"))
+		return;
+	if (!loadSymbol(slist_append, handle, "curl_slist_append"))
+		return;
+	if (!loadSymbol(slist_free_all, handle, "curl_slist_free_all"))
+		return;
+
+	global_init(CURL_GLOBAL_DEFAULT);
+	loaded = true;
+}
+
+CurlClient::Curl::~Curl()
+{
+	if (loaded)
+		global_cleanup();
+
+	if (handle)
+#ifdef _WIN32
+		FreeLibrary((HMODULE) handle);
+#else
+		dlclose(handle);
+#endif
+}
+
+static char toUppercase(char c)
+{
+	int ch = (unsigned char) c;
+	return toupper(ch);
+}
 
-	loaded = (global_init && easy_init && easy_cleanup && easy_setopt && easy_perform && easy_getinfo && slist_append && slist_free_all);
+static size_t stringReader(char *ptr, size_t size, size_t nmemb, StringReader *reader)
+{
+	const char *data = reader->str->data();
+	size_t len = reader->str->length();
+	size_t maxCount = (len - reader->pos) / size;
+	size_t desiredCount = std::min(maxCount, nmemb);
+	size_t desiredBytes = desiredCount * size;
 
-	if (!loaded)
-		return;
+	std::copy(data + reader->pos, data + desiredBytes, ptr);
+	reader->pos += desiredBytes;
 
-	global_init();
+	return desiredCount;
 }
 
-static size_t stringstreamWriter(char *ptr, size_t size, size_t nmemb, void *userdata)
+static size_t stringstreamWriter(char *ptr, size_t size, size_t nmemb, std::stringstream *ss)
 {
-	std::stringstream *ss = (std::stringstream*) userdata;
 	size_t count = size*nmemb;
 	ss->write(ptr, count);
 	return count;
 }
 
-static size_t headerWriter(char *ptr, size_t size, size_t nmemb, void *userdata)
+static size_t headerWriter(char *ptr, size_t size, size_t nmemb, std::map<std::string,std::string> *userdata)
 {
-	std::map<std::string, std::string> &headers = *((std::map<std::string,std::string>*) userdata);
+	std::map<std::string, std::string> &headers = *userdata;
 	size_t count = size*nmemb;
 	std::string line(ptr, count);
 	size_t split = line.find(':');
@@ -64,7 +144,10 @@ bool CurlClient::valid() const
 HTTPSClient::Reply CurlClient::request(const HTTPSClient::Request &req)
 {
 	Reply reply;
-	reply.responseCode = 400;
+	reply.responseCode = 0;
+
+	// Use sensible default header for later
+	HTTPSClient::header_map newHeaders = req.headers;
 
 	CURL *handle = curl.easy_init();
 	if (!handle)
@@ -72,23 +155,26 @@ HTTPSClient::Reply CurlClient::request(const HTTPSClient::Request &req)
 
 	curl.easy_setopt(handle, CURLOPT_URL, req.url.c_str());
 	curl.easy_setopt(handle, CURLOPT_FOLLOWLOCATION, 1L);
+	curl.easy_setopt(handle, CURLOPT_CUSTOMREQUEST, req.method.c_str());
 
-	if (req.method == "PUT")
-		curl.easy_setopt(handle, CURLOPT_PUT, 1L);
-	else if (req.method == "POST")
-		curl.easy_setopt(handle, CURLOPT_POST, 1L);
-	else
-		curl.easy_setopt(handle, CURLOPT_CUSTOMREQUEST, req.method.c_str());
+	StringReader reader {};
 
 	if (req.postdata.size() > 0 && (req.method != "GET" && req.method != "HEAD"))
 	{
-		curl.easy_setopt(handle, CURLOPT_POSTFIELDS, req.postdata.c_str());
-		curl.easy_setopt(handle, CURLOPT_POSTFIELDSIZE, req.postdata.size());
+		reader.str = &req.postdata;
+		reader.pos = 0;
+		curl.easy_setopt(handle, CURLOPT_UPLOAD, 1L);
+		curl.easy_setopt(handle, CURLOPT_READFUNCTION, stringReader);
+		curl.easy_setopt(handle, CURLOPT_READDATA, &reader);
+		curl.easy_setopt(handle, CURLOPT_INFILESIZE_LARGE, (curl_off_t) req.postdata.length());
 	}
 
+	if (req.method == "HEAD")
+		curl.easy_setopt(handle, CURLOPT_NOBODY, 1L);
+
 	// Curl doesn't copy memory, keep the strings around
 	std::vector<std::string> lines;
-	for (auto &header : req.headers)
+	for (auto &header : newHeaders)
 	{
 		std::stringstream line;
 		line << header.first << ": " << header.second;

+ 11 - 7
src/libraries/luahttps/src/generic/CurlClient.h

@@ -18,16 +18,20 @@ private:
 	static struct Curl
 	{
 		Curl();
+		~Curl();
+		void *handle;
 		bool loaded;
 
-		CURL *(*easy_init)();
-		void (*easy_cleanup)(CURL *handle);
-		CURLcode (*easy_setopt)(CURL *handle, CURLoption option, ...);
-		CURLcode (*easy_perform)(CURL *easy_handle);
-		CURLcode (*easy_getinfo)(CURL *curl, CURLINFO info, ...);
+		decltype(&curl_global_cleanup) global_cleanup;
 
-		curl_slist *(*slist_append)(curl_slist *list, const char *string);
-		void (*slist_free_all)(curl_slist *list);
+		decltype(&curl_easy_init) easy_init;
+		decltype(&curl_easy_cleanup) easy_cleanup;
+		decltype(&curl_easy_setopt) easy_setopt;
+		decltype(&curl_easy_perform) easy_perform;
+		decltype(&curl_easy_getinfo) easy_getinfo;
+
+		decltype(&curl_slist_append) slist_append;
+		decltype(&curl_slist_free_all) slist_free_all;
 	} curl;
 };
 

+ 1 - 0
src/libraries/luahttps/src/lua/main.cpp

@@ -78,6 +78,7 @@ static int w_request(lua_State *L)
 		if (!lua_isnoneornil(L, -1))
 		{
 			req.postdata = w_checkstring(L, -1);
+			req.headers["Content-Type"] = "application/x-www-form-urlencoded";
 			defaultMethod = "POST";
 		}
 		lua_pop(L, 1);

+ 224 - 0
src/libraries/luahttps/src/windows/WinINetClient.cpp

@@ -0,0 +1,224 @@
+#include "WinINetClient.h"
+
+#ifdef HTTPS_BACKEND_WININET
+
+#include <algorithm>
+#include <stdexcept>
+#include <sstream>
+#include <vector>
+
+#include <Windows.h>
+#include <wininet.h>
+
+#include "../common/HTTPRequest.h"
+
+class LazyHInternetLoader final
+{
+public:
+	LazyHInternetLoader(): hInternet(nullptr) { }
+	~LazyHInternetLoader()
+	{
+		if (hInternet)
+			InternetCloseHandle(hInternet);
+	}
+
+	HINTERNET getInstance()
+	{
+		if (!init)
+		{
+			hInternet = InternetOpenA("", INTERNET_OPEN_TYPE_PRECONFIG, nullptr, nullptr, 0);
+			if (hInternet)
+			{
+				// Try to enable HTTP2
+				DWORD httpProtocol = HTTP_PROTOCOL_FLAG_HTTP2;
+				InternetSetOptionA(hInternet, INTERNET_OPTION_ENABLE_HTTP_PROTOCOL, &httpProtocol, sizeof(DWORD));
+				SetLastError(0); // If it errors, ignore.
+			}
+		}
+
+		return hInternet;
+	}
+
+private:
+	bool init;
+	HINTERNET hInternet;
+};
+
+static thread_local LazyHInternetLoader hInternetCache;
+
+bool WinINetClient::valid() const
+{
+	// Allow disablement of WinINet backend.
+	const char *disabler = getenv("LUAHTTPS_DISABLE_WININET");
+	if (disabler && strcmp(disabler, "1") == 0)
+		return false;
+
+	return hInternetCache.getInstance() != nullptr;
+}
+
+HTTPSClient::Reply WinINetClient::request(const HTTPSClient::Request &req)
+{
+	Reply reply;
+	reply.responseCode = 0;
+
+	// Parse URL
+	auto parsedUrl = HTTPRequest::parseUrl(req.url);
+
+	// Default flags
+	DWORD inetFlags =
+		INTERNET_FLAG_NO_AUTH |
+		INTERNET_FLAG_NO_CACHE_WRITE |
+		INTERNET_FLAG_NO_COOKIES |
+		INTERNET_FLAG_NO_UI;
+
+	if (parsedUrl.schema == "https")
+		inetFlags |= INTERNET_FLAG_SECURE;
+	else if (parsedUrl.schema != "http")
+		return reply;
+
+	// Keep-Alive
+	auto connectHeader = req.headers.find("Connection");
+	auto headerEnd = req.headers.end();
+	if ((connectHeader != headerEnd && connectHeader->second != "close") || connectHeader == headerEnd)
+		inetFlags |= INTERNET_FLAG_KEEP_CONNECTION;
+
+	// Open internet
+	HINTERNET hInternet = hInternetCache.getInstance();
+	if (hInternet == nullptr)
+		return reply;
+
+	// Connect
+	HINTERNET hConnect = InternetConnectA(
+		hInternet,
+		parsedUrl.hostname.c_str(),
+		parsedUrl.port,
+		nullptr, nullptr,
+		INTERNET_SERVICE_HTTP,
+		INTERNET_FLAG_EXISTING_CONNECT,
+		(DWORD_PTR) this
+	);
+	if (!hConnect)
+		return reply;
+
+	std::string httpMethod = req.method;
+	std::transform(
+		httpMethod.begin(),
+		httpMethod.end(),
+		httpMethod.begin(),
+		[](char c) {return (char)toupper((unsigned char) c); }
+	);
+
+	// Open HTTP request
+	HINTERNET hHTTP = HttpOpenRequestA(
+		hConnect,
+		httpMethod.c_str(),
+		parsedUrl.query.c_str(),
+		nullptr,
+		nullptr,
+		nullptr,
+		inetFlags,
+		(DWORD_PTR) this
+	);
+	if (!hHTTP)
+	{
+		InternetCloseHandle(hConnect);
+		return reply;
+	}
+
+	// Send additional headers
+	HttpAddRequestHeadersA(hHTTP, "User-Agent:", 0, HTTP_ADDREQ_FLAG_REPLACE);
+	for (const auto &header: req.headers)
+	{
+		std::string headerString = header.first + ": " + header.second + "\r\n";
+		HttpAddRequestHeadersA(hHTTP, headerString.c_str(), headerString.length(), HTTP_ADDREQ_FLAG_ADD | HTTP_ADDREQ_FLAG_REPLACE);
+	}
+
+	// POST data
+	const char *postData = nullptr;
+	if (req.postdata.length() > 0 && (httpMethod != "GET" && httpMethod != "HEAD"))
+	{
+		char temp[48];
+		int len = sprintf(temp, "Content-Length: %u\r\n", (unsigned int) req.postdata.length());
+		postData = req.postdata.c_str();
+
+		HttpAddRequestHeadersA(hHTTP, temp, len, HTTP_ADDREQ_FLAG_ADD | HTTP_ADDREQ_FLAG_REPLACE);
+	}
+
+	// Send away!
+	BOOL result = HttpSendRequestA(hHTTP, nullptr, 0, (void *) postData, (DWORD) req.postdata.length());
+	if (!result)
+	{
+		InternetCloseHandle(hHTTP);
+		InternetCloseHandle(hConnect);
+		return reply;
+	}
+
+	DWORD bufferLength = sizeof(DWORD);
+	DWORD headerCounter = 0;
+
+	// Status code
+	DWORD statusCode = 0;
+	if (!HttpQueryInfoA(hHTTP, HTTP_QUERY_STATUS_CODE | HTTP_QUERY_FLAG_NUMBER, &statusCode, &bufferLength, &headerCounter))
+	{
+		InternetCloseHandle(hHTTP);
+		InternetCloseHandle(hConnect);
+		return reply;
+	}
+
+	// Query headers
+	std::vector<char> responseHeaders;
+	bufferLength = 0;
+	HttpQueryInfoA(hHTTP, HTTP_QUERY_RAW_HEADERS, responseHeaders.data(), &bufferLength, &headerCounter);
+	if (GetLastError() != ERROR_INSUFFICIENT_BUFFER)
+	{
+		InternetCloseHandle(hHTTP);
+		InternetCloseHandle(hConnect);
+		return reply;
+	}
+
+	responseHeaders.resize(bufferLength);
+	if (!HttpQueryInfoA(hHTTP, HTTP_QUERY_RAW_HEADERS, responseHeaders.data(), &bufferLength, &headerCounter))
+	{
+		InternetCloseHandle(hHTTP);
+		InternetCloseHandle(hConnect);
+		return reply;
+	}
+
+	for (const char *headerData = responseHeaders.data(); *headerData; headerData += strlen(headerData) + 1)
+	{
+		const char *value = strchr(headerData, ':');
+		if (value)
+		{
+			ptrdiff_t keyLen = (ptrdiff_t) (value - headerData);
+			reply.headers[std::string(headerData, keyLen)] = value + 2; // +2, colon and 1 space character.
+		}
+	}
+	responseHeaders.resize(1);
+
+	// Read response
+	std::stringstream responseData;
+	for (;;)
+	{
+		constexpr DWORD BUFFER_SIZE = 4096;
+		char buffer[BUFFER_SIZE];
+		DWORD readed = 0;
+
+		BOOL ret = InternetQueryDataAvailable(hHTTP, &readed, 0, 0);
+		if (!ret || readed == 0)
+			break;
+
+		if (!InternetReadFile(hHTTP, buffer, BUFFER_SIZE, &readed))
+			break;
+
+		responseData.write(buffer, readed);
+	}
+
+	reply.body = responseData.str();
+	reply.responseCode = statusCode;
+
+	InternetCloseHandle(hHTTP);
+	InternetCloseHandle(hConnect);
+	return reply;
+}
+
+#endif // HTTPS_BACKEND_WININET

+ 16 - 0
src/libraries/luahttps/src/windows/WinINetClient.h

@@ -0,0 +1,16 @@
+#pragma once
+
+#include "../common/config.h"
+
+#ifdef HTTPS_BACKEND_WININET
+
+#include "../common/HTTPSClient.h"
+
+class WinINetClient: public HTTPSClient
+{
+public:
+	bool valid() const override;
+	HTTPSClient::Reply request(const HTTPSClient::Request &req) override;
+};
+
+#endif // HTTPS_BACKEND_WININET