Explorar o código

Merge pull request #6 from MikuAuahDark/morehttpmethod

Supports HEAD, PUT, PATCH, and DELETE HTTP methods.
Alex Szpakowski %!s(int64=3) %!d(string=hai) anos
pai
achega
cc33bda552

+ 7 - 1
src/android/AndroidClient.cpp

@@ -97,6 +97,7 @@ HTTPSClient::Reply AndroidClient::request(const HTTPSClient::Request &req)
 
 	jmethodID constructor = env->GetMethodID(httpsClass, "<init>", "()V");
 	jmethodID setURL = env->GetMethodID(httpsClass, "setUrl", "(Ljava/lang/String;)V");
+	jmethodID setMethod = env->GetMethodID(httpsClass, "setMethod", "(Ljava/lang/String;)V");
 	jmethodID request = env->GetMethodID(httpsClass, "request", "()Z");
 	jmethodID getInterleavedHeaders = env->GetMethodID(httpsClass, "getInterleavedHeaders", "()[Ljava/lang/String;");
 	jmethodID getResponse = env->GetMethodID(httpsClass, "getResponse", "()[B");
@@ -109,8 +110,13 @@ HTTPSClient::Reply AndroidClient::request(const HTTPSClient::Request &req)
 	env->CallVoidMethod(httpsObject, setURL, url);
 	env->DeleteLocalRef(url);
 
+	// Set method
+	jstring method = env->NewStringUTF(req.method.c_str());
+	env->CallVoidMethod(httpsObject, setMethod, method);
+	env->DeleteLocalRef(method);
+
 	// Set post data
-	if (req.method == Request::POST)
+	if (req.postdata.size() > 0)
 	{
 		jmethodID setPostData = env->GetMethodID(httpsClass, "setPostData", "([B)V");
 		jbyteArray byteArray = env->NewByteArray((jsize) req.postdata.length());

+ 21 - 1
src/android/java/org/love2d/luahttps/LuaHTTPS.java

@@ -11,6 +11,7 @@ import java.io.InputStream;
 import java.io.OutputStream;
 import java.net.HttpURLConnection;
 import java.net.MalformedURLException;
+import java.net.ProtocolException;
 import java.net.URL;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -22,6 +23,7 @@ class LuaHTTPS {
     static private String TAG = "LuaHTTPS";
 
     private String urlString;
+    private String method;
     private byte[] postData;
     private byte[] response;
     private int responseCode;
@@ -34,6 +36,7 @@ class LuaHTTPS {
 
     public void reset() {
         urlString = null;
+        method = "GET";
         postData = null;
         response = null;
         responseCode = 0;
@@ -50,6 +53,11 @@ class LuaHTTPS {
         this.postData = postData;
     }
 
+    @Keep
+    public void setMethod(String method) {
+        this.method = method.toUpperCase();
+    }
+
     @Keep
     public void addHeader(String key, String value) {
         headers.put(key, value);
@@ -110,13 +118,21 @@ class LuaHTTPS {
             return false;
         }
 
+        // Set request method
+        try {
+            connection.setRequestMethod(method);
+        } catch (ProtocolException e) {
+            Log.e(TAG, "Error", e);
+            return false;
+        }
+
         // Set header
         for (Map.Entry<String, String> headerData: headers.entrySet()) {
             connection.setRequestProperty(headerData.getKey(), headerData.getValue());
         }
 
         // Set post data
-        if (postData != null) {
+        if (postData != null && canSendData()) {
             connection.setDoOutput(true);
             connection.setChunkedStreamingMode(0);
 
@@ -168,4 +184,8 @@ class LuaHTTPS {
         connection.disconnect();
         return true;
     }
+
+    private boolean canSendData() {
+        return !method.equals("GET") && !method.equals("HEAD");
+    }
 }

+ 4 - 8
src/apple/NSURLClient.mm

@@ -29,16 +29,12 @@ HTTPSClient::Reply NSURLClient::request(const HTTPSClient::Request &req)
 	NSMutableURLRequest *request = [NSMutableURLRequest requestWithURL:url];
 
 	NSData *bodydata = nil;
-	switch(req.method)
+	[request setHTTPMethod:@(req.method.c_str())];
+
+	if (req.postdata.size() > 0 && (req.method != "GET" && req.method != "HEAD"))
 	{
-	case Request::GET:
-		[request setHTTPMethod:@"GET"];
-		break;
-	case Request::POST:
 		bodydata = [NSData dataWithBytesNoCopy:(void*) req.postdata.data() length:req.postdata.size() freeWhenDone:NO];
-		[request setHTTPMethod:@"POST"];
 		[request setHTTPBody:bodydata];
-		break;
 	}
 
 	for (auto &header : req.headers)
@@ -63,7 +59,7 @@ HTTPSClient::Reply NSURLClient::request(const HTTPSClient::Request &req)
 	dispatch_semaphore_wait(sem, DISPATCH_TIME_FOREVER);
 
 	HTTPSClient::Reply reply;
-	reply.responseCode = 400;
+	reply.responseCode = 0;
 
 	if (body)
 	{

+ 10 - 4
src/common/HTTPRequest.cpp

@@ -35,7 +35,13 @@ HTTPSClient::Reply HTTPRequest::request(const HTTPSClient::Request &req)
 	// Build the request
 	{
 		std::stringstream request;
-		request << (req.method == HTTPSClient::Request::GET ? "GET " : "POST ") << info.query << " HTTP/1.1\r\n";
+		std::string method = req.method;
+		bool hasData = req.postdata.length() > 0;
+
+		if (method.length() == 0)
+			method = hasData ? "POST" : "GET";
+
+		request << method << " " << info.query << " HTTP/1.1\r\n";
 
 		for (auto &header : req.headers)
 			request << header.first << ": " << header.second << "\r\n";
@@ -44,15 +50,15 @@ HTTPSClient::Reply HTTPRequest::request(const HTTPSClient::Request &req)
 
 		request << "Host: " << info.hostname << "\r\n";
 
-		if (req.method == HTTPSClient::Request::POST && req.headers.count("Content-Type") == 0)
+		if (hasData && req.headers.count("Content-Type") == 0)
 			request << "Content-Type: application/x-www-form-urlencoded\r\n";
 
-		if (req.method == HTTPSClient::Request::POST)
+		if (hasData)
 			request << "Content-Length: " << req.postdata.size() << "\r\n";
 
 		request << "\r\n";
 
-		if (req.method == HTTPSClient::Request::POST)
+		if (hasData)
 			request << req.postdata;
 
 		// Send it

+ 1 - 1
src/common/HTTPSClient.cpp

@@ -31,7 +31,7 @@ bool HTTPSClient::ci_string_less::operator()(const std::string &lhs, const std::
 
 HTTPSClient::Request::Request(const std::string &url)
 	: url(url)
-	, method(GET)
+	, method("")
 {
 }
 

+ 1 - 6
src/common/HTTPSClient.h

@@ -20,12 +20,7 @@ public:
 		header_map headers;
 		std::string url;
 		std::string postdata;
-
-		enum Method
-		{
-			GET,
-			POST,
-		} method;
+		std::string method;
 	};
 
 	struct Reply

+ 8 - 2
src/generic/CurlClient.cpp

@@ -73,9 +73,15 @@ HTTPSClient::Reply CurlClient::request(const HTTPSClient::Request &req)
 	curl.easy_setopt(handle, CURLOPT_URL, req.url.c_str());
 	curl.easy_setopt(handle, CURLOPT_FOLLOWLOCATION, 1L);
 
-	if (req.method == Request::POST)
-	{
+	if (req.method == "PUT")
+		curl.easy_setopt(handle, CURLOPT_PUT, 1L);
+	else if (req.method == "POST")
 		curl.easy_setopt(handle, CURLOPT_POST, 1L);
+	else
+		curl.easy_setopt(handle, CURLOPT_CUSTOMREQUEST, req.method.c_str());
+
+	if (req.postdata.size() > 0 && (req.method != "GET" && req.method != "HEAD"))
+	{
 		curl.easy_setopt(handle, CURLOPT_POSTFIELDS, req.postdata.c_str());
 		curl.easy_setopt(handle, CURLOPT_POSTFIELDSIZE, req.postdata.size());
 	}

+ 22 - 11
src/lua/main.cpp

@@ -1,3 +1,6 @@
+#include <algorithm>
+#include <set>
+
 extern "C"
 {
 #include <lua.h>
@@ -7,6 +10,14 @@ extern "C"
 #include "../common/HTTPS.h"
 #include "../common/config.h"
 
+static std::string validMethod[] = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"};
+
+static int str_toupper(char c)
+{
+	unsigned char uc = (unsigned char) c;
+	return toupper(uc);
+}
+
 static std::string w_checkstring(lua_State *L, int idx)
 {
 	size_t len;
@@ -34,20 +45,20 @@ static void w_readheaders(lua_State *L, int idx, HTTPSClient::header_map &header
 	lua_pop(L, 1);
 }
 
-static HTTPSClient::Request::Method w_optmethod(lua_State *L, int idx, HTTPSClient::Request::Method defaultMethod)
+static std::string w_optmethod(lua_State *L, int idx, const std::string &defaultMethod)
 {
+	std::string *const validMethodEnd = validMethod + sizeof(validMethod) / sizeof(std::string);
+
 	if (lua_isnoneornil(L, idx))
 		return defaultMethod;
 
-	auto str = w_checkstring(L, idx);
-	if (str == "get")
-		return HTTPSClient::Request::GET;
-	else if (str == "post")
-		return HTTPSClient::Request::POST;
-	else
-		luaL_argerror(L, idx, "expected one of \"get\" or \"set\"");
+	std::string str = w_checkstring(L, idx);
+	std::transform(str.begin(), str.end(), str.begin(), str_toupper);
+
+	if (std::find(validMethod, validMethodEnd, str) == validMethodEnd)
+		luaL_argerror(L, idx, "expected one of \"get\", \"head\", \"post\", \"put\", \"delete\", or \"patch\"");
 
-	return defaultMethod;
+	return str;
 }
 
 static int w_request(lua_State *L)
@@ -61,13 +72,13 @@ static int w_request(lua_State *L)
 	{
 		advanced = true;
 
-		HTTPSClient::Request::Method defaultMethod = HTTPSClient::Request::GET;
+		std::string defaultMethod = "GET";
 
 		lua_getfield(L, 2, "data");
 		if (!lua_isnoneornil(L, -1))
 		{
 			req.postdata = w_checkstring(L, -1);
-			defaultMethod = HTTPSClient::Request::POST;
+			defaultMethod = "POST";
 		}
 		lua_pop(L, 1);
 

+ 6 - 2
src/windows/SChannelConnection.cpp

@@ -55,8 +55,12 @@ static size_t dequeue(std::vector<char> &buffer, char *data, size_t size)
 	size_t remaining = buffer.size() - size;
 
 	memcpy(data, &buffer[0], size);
-	memmove(&buffer[0], &buffer[size], remaining);
-	buffer.resize(remaining);
+
+	if (remaining > 0)
+	{
+		memmove(&buffer[0], &buffer[size], remaining);
+		buffer.resize(remaining);
+	}
 
 	return size;
 }