Browse Source

Extract library loading code into its own file

Makes it easier to replace, centralises the platform-dependent logic, and means we can more easily re-use things like the templated LoadSymbol.
Bart van Strien 1 year ago
parent
commit
0668b0fd66

+ 1 - 0
src/CMakeLists.txt

@@ -41,6 +41,7 @@ add_library (https-common STATIC
 	common/HTTPRequest.cpp
 	common/HTTPSClient.cpp
 	common/PlaintextConnection.cpp
+	common/LibraryLoader.cpp
 )
 
 add_library (https-curl STATIC EXCLUDE_FROM_ALL

+ 4 - 3
src/android/AndroidClient.cpp

@@ -5,7 +5,7 @@
 #include <sstream>
 #include <type_traits>
 
-#include <dlfcn.h>
+#include "../common/LibraryLoader.h"
 
 // We want std::string that contains null byte, hence length of 1.
 // NOLINTNEXTLINE
@@ -52,10 +52,11 @@ static std::string getStringUTF(JNIEnv *env, jstring str)
 AndroidClient::AndroidClient()
 : HTTPSClient()
 {
+	LibraryLoader::handle *library = LibraryLoader::GetCurrentProcessHandle();
 	// Look for SDL_AndroidGetJNIEnv
-	SDL_AndroidGetJNIEnv = (decltype(SDL_AndroidGetJNIEnv)) dlsym(RTLD_DEFAULT, "SDL_AndroidGetJNIEnv");
+	LibraryLoader::LoadSymbol(SDL_AndroidGetJNIEnv, library, "SDL_AndroidGetJNIEnv");
 	// Look for SDL_AndroidGetActivity
-	SDL_AndroidGetActivity = (decltype(SDL_AndroidGetActivity)) dlsym(RTLD_DEFAULT, "SDL_AndroidGetActivity");
+	LibraryLoader::LoadSymbol(SDL_AndroidGetActivity, library, "SDL_AndroidGetActivity");
 }
 
 bool AndroidClient::valid() const

+ 4 - 0
src/common/HTTPS.cpp

@@ -1,6 +1,7 @@
 #include "HTTPS.h"
 #include "config.h"
 #include "ConnectionClient.h"
+#include "LibraryLoader.h"
 
 #include <stdexcept>
 
@@ -65,6 +66,9 @@ static HTTPSClient *clients[] = {
 	nullptr,
 };
 
+// Call into the library loader to make sure it is linked in
+static LibraryLoader::handle* dummyProcessHandle = LibraryLoader::GetCurrentProcessHandle();
+
 HTTPSClient::Reply request(const HTTPSClient::Request &req)
 {
 	for (size_t i = 0; clients[i]; ++i)

+ 54 - 0
src/common/LibraryLoader.cpp

@@ -0,0 +1,54 @@
+#include "config.h"
+#include "LibraryLoader.h"
+
+#ifdef _WIN32
+#define NOMINMAX
+#define WIN32_LEAN_AND_MEAN
+
+#include <windows.h>
+#else
+#include <dlfcn.h>
+#endif
+
+namespace LibraryLoader
+{
+	handle *OpenLibrary(const char *name)
+	{
+#ifdef _WIN32
+		return reinterpret_cast<handle *>(LoadLibraryA(name));
+#else
+		return dlopen(name, RTLD_LAZY);
+#endif
+	}
+
+	void CloseLibrary(handle *handle)
+	{
+		if (handle)
+		{
+#ifdef _WIN32
+			FreeLibrary(handle);
+#else
+			dlclose(handle);
+#endif
+		}
+	}
+
+	handle* GetCurrentProcessHandle()
+	{
+#ifdef _WIN32
+		return reinterpret_cast<handle *>(GetModuleHandle(nullptr));
+#else
+		return RTLD_DEFAULT;
+#endif
+	}
+
+	function *GetFunction(handle *handle, const char *name)
+	{
+#ifdef _WIN32
+		HMODULE nativeHandle = reinterpret_cast<HMODULE>(handle);
+		return reinterpret_cast<function *>(GetProcAddress(nativeHandle, name));
+#else
+		return reinterpret_cast<function *>(dlsym(handle, name));
+#endif
+	}
+}

+ 20 - 0
src/common/LibraryLoader.h

@@ -0,0 +1,20 @@
+#pragma once
+
+namespace LibraryLoader
+{
+	using handle = void;
+	using function = void();
+
+	handle *OpenLibrary(const char *name);
+	void CloseLibrary(handle *handle);
+	handle* GetCurrentProcessHandle();
+
+	function *GetFunction(handle *handle, const char *name);
+
+	template<class T>
+	inline bool LoadSymbol(T& var, handle *handle, const char *name)
+	{
+		var = reinterpret_cast<T>(GetFunction(handle, name));
+		return var != nullptr;
+	}
+}

+ 14 - 39
src/generic/CurlClient.cpp

@@ -1,8 +1,3 @@
-#ifdef _WIN32
-#define NOMINMAX
-#define WIN32_LEAN_AND_MEAN
-#endif
-
 #include "CurlClient.h"
 
 #ifdef HTTPS_BACKEND_CURL
@@ -12,30 +7,12 @@
 #include <sstream>
 #include <vector>
 
-// Dynamic library loader
-#ifdef _WIN32
-#include <windows.h>
-#else
-#include <dlfcn.h>
-#endif
-
 typedef struct StringReader
 {
 	const std::string *str;
 	size_t pos;
 } StringReader;
 
-template <class T>
-static inline bool loadSymbol(T &var, void *handle, const char *name)
-{
-#ifdef _WIN32
-	var = (T) GetProcAddress((HMODULE) handle, name);
-#else
-	var = (T) dlsym(handle, name);
-#endif
-	return var != nullptr;
-}
-
 CurlClient::Curl::Curl()
 : handle(nullptr)
 , loaded(false)
@@ -48,33 +25,35 @@ CurlClient::Curl::Curl()
 , slist_append(nullptr)
 , slist_free_all(nullptr)
 {
+	using namespace LibraryLoader;
+
 #ifdef _WIN32
-	handle = (void *) LoadLibraryA("libcurl.dll");
+	handle = OpenLibrary("libcurl.dll");
 #else
-	handle = dlopen("libcurl.so.4", RTLD_LAZY);
+	handle = OpenLibrary("libcurl.so.4");
 #endif
 	if (!handle)
 		return;
 
 	// Load symbols
 	decltype(&curl_global_init) global_init = nullptr;
-	if (!loadSymbol(global_init, handle, "curl_global_init"))
+	if (!LoadSymbol(global_init, handle, "curl_global_init"))
 		return;
-	if (!loadSymbol(global_cleanup, handle, "curl_global_cleanup"))
+	if (!LoadSymbol(global_cleanup, handle, "curl_global_cleanup"))
 		return;
-	if (!loadSymbol(easy_init, handle, "curl_easy_init"))
+	if (!LoadSymbol(easy_init, handle, "curl_easy_init"))
 		return;
-	if (!loadSymbol(easy_cleanup, handle, "curl_easy_cleanup"))
+	if (!LoadSymbol(easy_cleanup, handle, "curl_easy_cleanup"))
 		return;
-	if (!loadSymbol(easy_setopt, handle, "curl_easy_setopt"))
+	if (!LoadSymbol(easy_setopt, handle, "curl_easy_setopt"))
 		return;
-	if (!loadSymbol(easy_perform, handle, "curl_easy_perform"))
+	if (!LoadSymbol(easy_perform, handle, "curl_easy_perform"))
 		return;
-	if (!loadSymbol(easy_getinfo, handle, "curl_easy_getinfo"))
+	if (!LoadSymbol(easy_getinfo, handle, "curl_easy_getinfo"))
 		return;
-	if (!loadSymbol(slist_append, handle, "curl_slist_append"))
+	if (!LoadSymbol(slist_append, handle, "curl_slist_append"))
 		return;
-	if (!loadSymbol(slist_free_all, handle, "curl_slist_free_all"))
+	if (!LoadSymbol(slist_free_all, handle, "curl_slist_free_all"))
 		return;
 
 	global_init(CURL_GLOBAL_DEFAULT);
@@ -87,11 +66,7 @@ CurlClient::Curl::~Curl()
 		global_cleanup();
 
 	if (handle)
-#ifdef _WIN32
-		FreeLibrary((HMODULE) handle);
-#else
-		dlclose(handle);
-#endif
+		LibraryLoader::CloseLibrary(handle);
 }
 
 static char toUppercase(char c)

+ 2 - 1
src/generic/CurlClient.h

@@ -7,6 +7,7 @@
 #include <curl/curl.h>
 
 #include "../common/HTTPSClient.h"
+#include "../common/LibraryLoader.h"
 
 class CurlClient : public HTTPSClient
 {
@@ -19,7 +20,7 @@ private:
 	{
 		Curl();
 		~Curl();
-		void *handle;
+		LibraryLoader::handle *handle;
 		bool loaded;
 
 		decltype(&curl_global_cleanup) global_cleanup;

+ 32 - 37
src/generic/OpenSSLConnection.cpp

@@ -2,65 +2,60 @@
 
 #ifdef HTTPS_BACKEND_OPENSSL
 
-#include <dlfcn.h>
+#include "../common/LibraryLoader.h"
 
 // Not present in openssl 1.1 headers
 #define SSL_CTRL_OPTIONS 32
 
-template <class T>
-static inline bool loadSymbol(T &var, void *handle, const char *name)
-{
-	var = reinterpret_cast<T>(dlsym(handle, name));
-	return var != nullptr;
-}
-
 OpenSSLConnection::SSLFuncs::SSLFuncs()
 {
+	using namespace LibraryLoader;
+
 	valid = false;
 
 	// Try OpenSSL 1.1
-	void *sslhandle = dlopen("libssl.so.1.1", RTLD_LAZY);
-	void *cryptohandle = dlopen("libcrypto.so.1.1", RTLD_LAZY);
+	handle *sslhandle = OpenLibrary("libssl.so.1.1");
+	handle *cryptohandle = OpenLibrary("libcrypto.so.1.1");
 	// Try OpenSSL 1.0
 	if (!sslhandle || !cryptohandle)
 	{
-		sslhandle = dlopen("libssl.so.1.0.0", RTLD_LAZY);
-		cryptohandle = dlopen("libcrypto.so.1.0.0", RTLD_LAZY);
+		sslhandle = OpenLibrary("libssl.so.1.0.0");
+		cryptohandle = OpenLibrary("libcrypto.so.1.0.0");
 	}
 	// Try OpenSSL without version
 	if (!sslhandle || !cryptohandle)
 	{
-		sslhandle = dlopen("libssl.so", RTLD_LAZY);
-		cryptohandle = dlopen("libcrypto.so", RTLD_LAZY);
+		sslhandle = OpenLibrary("libssl.so");
+		cryptohandle = OpenLibrary("libcrypto.so");
 	}
 	// Give up
 	if (!sslhandle || !cryptohandle)
 		return;
 
 	valid = true;
-	valid = valid && (loadSymbol(library_init, sslhandle, "SSL_library_init") ||
-			loadSymbol(init_ssl, sslhandle, "OPENSSL_init_ssl"));
-
-	valid = valid && loadSymbol(CTX_new, sslhandle, "SSL_CTX_new");
-	valid = valid && loadSymbol(CTX_ctrl, sslhandle, "SSL_CTX_ctrl");
-	valid = valid && loadSymbol(CTX_set_verify, sslhandle, "SSL_CTX_set_verify");
-	valid = valid && loadSymbol(CTX_set_default_verify_paths, sslhandle, "SSL_CTX_set_default_verify_paths");
-	valid = valid && loadSymbol(CTX_free, sslhandle, "SSL_CTX_free");
-
-	valid = valid && loadSymbol(SSL_new, sslhandle, "SSL_new");
-	valid = valid && loadSymbol(SSL_free, sslhandle, "SSL_free");
-	valid = valid && loadSymbol(set_fd, sslhandle, "SSL_set_fd");
-	valid = valid && loadSymbol(connect, sslhandle, "SSL_connect");
-	valid = valid && loadSymbol(read, sslhandle, "SSL_read");
-	valid = valid && loadSymbol(write, sslhandle, "SSL_write");
-	valid = valid && loadSymbol(shutdown, sslhandle, "SSL_shutdown");
-	valid = valid && loadSymbol(get_verify_result, sslhandle, "SSL_get_verify_result");
-	valid = valid && loadSymbol(get_peer_certificate, sslhandle, "SSL_get_peer_certificate");
-
-	valid = valid && (loadSymbol(SSLv23_method, sslhandle, "SSLv23_method") ||
-			loadSymbol(SSLv23_method, sslhandle, "TLS_method"));
-
-	valid = valid && loadSymbol(check_host, cryptohandle, "X509_check_host");
+	valid = valid && (LoadSymbol(library_init, sslhandle, "SSL_library_init") ||
+			LoadSymbol(init_ssl, sslhandle, "OPENSSL_init_ssl"));
+
+	valid = valid && LoadSymbol(CTX_new, sslhandle, "SSL_CTX_new");
+	valid = valid && LoadSymbol(CTX_ctrl, sslhandle, "SSL_CTX_ctrl");
+	valid = valid && LoadSymbol(CTX_set_verify, sslhandle, "SSL_CTX_set_verify");
+	valid = valid && LoadSymbol(CTX_set_default_verify_paths, sslhandle, "SSL_CTX_set_default_verify_paths");
+	valid = valid && LoadSymbol(CTX_free, sslhandle, "SSL_CTX_free");
+
+	valid = valid && LoadSymbol(SSL_new, sslhandle, "SSL_new");
+	valid = valid && LoadSymbol(SSL_free, sslhandle, "SSL_free");
+	valid = valid && LoadSymbol(set_fd, sslhandle, "SSL_set_fd");
+	valid = valid && LoadSymbol(connect, sslhandle, "SSL_connect");
+	valid = valid && LoadSymbol(read, sslhandle, "SSL_read");
+	valid = valid && LoadSymbol(write, sslhandle, "SSL_write");
+	valid = valid && LoadSymbol(shutdown, sslhandle, "SSL_shutdown");
+	valid = valid && LoadSymbol(get_verify_result, sslhandle, "SSL_get_verify_result");
+	valid = valid && LoadSymbol(get_peer_certificate, sslhandle, "SSL_get_peer_certificate");
+
+	valid = valid && (LoadSymbol(SSLv23_method, sslhandle, "SSLv23_method") ||
+			LoadSymbol(SSLv23_method, sslhandle, "TLS_method"));
+
+	valid = valid && LoadSymbol(check_host, cryptohandle, "X509_check_host");
 
 	if (library_init)
 		library_init();