Browse Source

use protocol instance.

Nicolas Cannasse 18 years ago
parent
commit
09dd581f09

+ 16 - 12
std/haxe/remoting/NekoSocketConnection.hx

@@ -37,42 +37,46 @@ class NekoSocketConnection extends Connection {
 	}
 
 	override public function call( params : Array<Dynamic> ) : Dynamic {
-		var sock = getSocket();
-		SocketProtocol.sendRequest(sock,__path,params);
+		var proto = getProtocol();
+		proto.sendRequest(__path,params);
 		while( true ) {
-			var data = SocketProtocol.readMessage(sock.input);
-			if( SocketProtocol.isRequest(data) ) {
+			var data = proto.readMessage();
+			if( proto.isRequest(data) ) {
 				if( __r == null )
 					throw "Request received";
-				SocketProtocol.processRequest(sock,data,__r.resolvePath,onRequestError);
+				proto.processRequest(data,__r.resolvePath,onRequestError);
 				continue;
 			}
-			return SocketProtocol.decodeAnswer(data);
+			return proto.decodeAnswer(data);
 		}
 		return null;
 	}
 
 	public function processRequest() {
-		var sock = getSocket();
+		var proto = getProtocol();
 		if( __r == null )
 			throw "No RemotingServer defined";
-		var data = SocketProtocol.readMessage(sock.input);
-		SocketProtocol.processRequest(sock,data,__r.resolvePath,onRequestError);
+		var data = proto.readMessage();
+		proto.processRequest(data,__r.resolvePath,onRequestError);
 	}
 
 	public function onRequestError( path : Array<String>, method : String, args : Array<Dynamic>, exc : Dynamic ) {
 	}
 
-	public function getSocket() : Socket {
+	public function setProtocol( p : SocketProtocol ) {
+		__data = p;
+	}
+
+	public function getProtocol() : SocketProtocol {
 		return __data;
 	}
 
 	public function closeConnection() {
-		try getSocket().close() catch( e : Dynamic ) { };
+		try getProtocol().socket.close() catch( e : Dynamic ) { };
 	}
 
 	public static function socketConnect( s : neko.net.Socket, ?r : neko.net.RemotingServer ) {
-		var sc = new NekoSocketConnection(s,[]);
+		var sc = new NekoSocketConnection(new SocketProtocol(s),[]);
 		sc.__r = r;
 		return sc;
 	}

+ 24 - 9
std/haxe/remoting/SocketConnection.hx

@@ -45,25 +45,30 @@ class SocketConnection extends AsyncConnection {
 
 	override public function call( params : Array<Dynamic>, ?onData : Dynamic -> Void ) : Void {
 		try {
-			SocketProtocol.sendRequest(getSocket(),__path,params);
+			getProtocol().sendRequest(__path,params);
 			__funs.add(onData);
 		} catch( e : Dynamic ) {
 			__error.ref(e);
 		}
 	}
 
-	public function getSocket() : Socket {
+	public function setProtocol( p : SocketProtocol ) {
+		__data = p;
+	}
+
+	public function getProtocol() : SocketProtocol {
 		return __data;
 	}
 
 	public function closeConnection() {
-		try getSocket().close() catch( e : Dynamic ) { };
+		try getProtocol().socket.close() catch( e : Dynamic ) { };
 	}
 
 	public function processMessage( data : String ) {
 		var request;
+		var proto = getProtocol();
 		try {
-			request = SocketProtocol.isRequest(data);
+			request = proto.isRequest(data);
 		} catch( e : Dynamic ) {
 			__error.ref(e); // protocol error
 			return;
@@ -81,7 +86,7 @@ class SocketConnection extends AsyncConnection {
 						function(path:Array<String>) { return js.Lib.eval(path.join(".")); }
 					#end
 				;
-				SocketProtocol.processRequest(getSocket(),data,eval,function(path,name,args,e) {
+				proto.processRequest(data,eval,function(path,name,args,e) {
 					// exception inside the called method
 					var astr, estr;
 					try astr = args.join(",") catch( e : Dynamic ) astr = "???";
@@ -100,7 +105,7 @@ class SocketConnection extends AsyncConnection {
 			if( __funs.isEmpty() )
 				throw "No response excepted ("+data+")";
 			f = __funs.pop();
-			v = SocketProtocol.decodeAnswer(data);
+			v = proto.decodeAnswer(data);
 		} catch( e : Dynamic ) {
 			__error.ref(e); // protocol error or answer exception
 			return;
@@ -112,7 +117,7 @@ class SocketConnection extends AsyncConnection {
 	#if neko
 
 	public static function socketConnect( s : neko.net.Socket, r : neko.net.RemotingServer ) {
-		var sc = new SocketConnection(s,[]);
+		var sc = new SocketConnection(new SocketProtocol(s),[]);
 		sc.__funs = new List();
 		sc.__r = r;
 		return sc;
@@ -121,10 +126,16 @@ class SocketConnection extends AsyncConnection {
 	#else (flash || js)
 
 	public static function socketConnect( s : Socket ) {
-		var sc = new SocketConnection(s,[]);
+		var sc = new SocketConnection(new SocketProtocol(s),[]);
 		sc.__funs = new List();
 		#if flash9
 		s.addEventListener(flash.events.DataEvent.DATA, function(e : flash.events.DataEvent) {
+			var data = e.data;
+			var msgLen = sc.getProtocol().messageLength(data.charCodeAt(0),data.charCodeAt(1));
+			if( msgLen == null || data.length != msgLen - 1 ) {
+				sc.__error.ref("Invalid message header");
+				return;
+			}
 			sc.processMessage(e.data.substr(2,e.data.length-2));
 		});
 		#else true
@@ -134,8 +145,12 @@ class SocketConnection extends AsyncConnection {
 		// where a new onData is called is a parallel thread
 		// ...with the buffer of the previous onData (!)
 		s.onData = function(data : String) {
-			trace(data);
 			haxe.Timer.queue(function() {
+				var msgLen = sc.getProtocol().messageLength(data.charCodeAt(0),data.charCodeAt(1));
+				if( msgLen == null || data.length != msgLen - 1 ) {
+					sc.__error.ref("Invalid message header");
+					return;
+				}
 				sc.processMessage(data.substr(2,data.length-2));
 			});
 		};

+ 33 - 24
std/haxe/remoting/SocketProtocol.hx

@@ -51,7 +51,13 @@ typedef Socket =
 **/
 class SocketProtocol {
 
-	public static function decodeChar(c) : Null<Int> {
+	public var socket : Socket;
+
+	public function new( sock ) {
+		this.socket = sock;
+	}
+
+	function decodeChar(c) : Null<Int> {
 		// A...Z
 		if( c >= 65 && c <= 90 )
 			return c - 65;
@@ -70,7 +76,7 @@ class SocketProtocol {
 		return null;
 	}
 
-	public static function encodeChar(c) : Null<Int> {
+	function encodeChar(c) : Null<Int> {
 		if( c < 0 )
 			return null;
 		// A...Z
@@ -91,49 +97,49 @@ class SocketProtocol {
 		return null;
 	}
 
-	public static function dataLength( c1 : Int, c2 : Int ) {
+	public function messageLength( c1 : Int, c2 : Int ) {
 		var e1 = decodeChar(c1);
 		var e2 = decodeChar(c2);
 		if( e1 == null || e2 == null )
-			throw "Invalid header";
-		return ((e1 << 6) | e2) - 3;
+			return null;
+		return (e1 << 6) | e2;
 	}
 
-	public static function sendRequest( sock : Socket, path : Array<String>, params : Array<Dynamic> ) {
+	public function sendRequest( path : Array<String>, params : Array<Dynamic> ) {
 		var s = new haxe.Serializer();
 		s.serialize(true);
 		s.serialize(path);
 		s.serialize(params);
-		sendMessage(sock,s.toString());
+		sendMessage(s.toString());
 	}
 
-	public static function sendAnswer( sock : Socket, answer : Dynamic, ?isException : Bool ) {
+	public function sendAnswer( answer : Dynamic, ?isException : Bool ) {
 		var s = new haxe.Serializer();
 		s.serialize(false);
 		if( isException )
 			s.serializeException(answer);
 		else
 			s.serialize(answer);
-		sendMessage(sock,s.toString());
+		sendMessage(s.toString());
 	}
 
-	public static function sendMessage( sock : Socket, msg : String ) {
+	public function sendMessage( msg : String ) {
 		var len = msg.length + 3;
 		var c1 = encodeChar(len>>6);
 		if( c1 == null )
 			throw "Message is too big";
 		var c2 = encodeChar(len&63);
 		#if neko
-		sock.output.writeChar(c1);
-		sock.output.writeChar(c2);
-		sock.output.write(msg);
-		sock.output.writeChar(0);
+		socket.output.writeChar(c1);
+		socket.output.writeChar(c2);
+		socket.output.write(msg);
+		socket.output.writeChar(0);
 		#else true
-		sock.send(Std.chr(c1)+Std.chr(c2)+msg);
+		socket.send(Std.chr(c1)+Std.chr(c2)+msg);
 		#end
 	}
 
-	public static function isRequest( data : String ) {
+	public function isRequest( data : String ) {
 		return switch( haxe.Unserializer.run(data) ) {
 		case true: true;
 		case false: false;
@@ -141,7 +147,7 @@ class SocketProtocol {
 		}
 	}
 
-	public static function processRequest( sock : Socket, data : String, eval : Array<String> -> Dynamic, ?onError : Array<String> -> String -> Array<Dynamic> -> Dynamic -> Void ) {
+	public function processRequest( data : String, eval : Array<String> -> Dynamic, ?onError : Array<String> -> String -> Array<Dynamic> -> Dynamic -> Void ) {
 		var s = new haxe.Unserializer(data);
 		var result : Dynamic;
 		var isException = false;
@@ -171,10 +177,10 @@ class SocketProtocol {
 			s.serializeException(result);
 		else
 			s.serialize(result);
-		sendMessage(sock,s.toString());
+		sendMessage(s.toString());
 	}
 
-	public static function decodeAnswer( data : String ) : Dynamic {
+	public function decodeAnswer( data : String ) : Dynamic {
 		var s = new haxe.Unserializer(data);
 		if( s.unserialize() != false )
 			throw "Not an answer";
@@ -183,11 +189,14 @@ class SocketProtocol {
 
 	#if neko
 
-	public static function readMessage( i : neko.io.Input ) {
-		var c1 = i.readChar();
-		var c2 = i.readChar();
-		var data = i.read(dataLength(c1,c2));
-		if( i.readChar() != 0 )
+	public function readMessage() {
+		var c1 = socket.input.readChar();
+		var c2 = socket.input.readChar();
+		var len = messageLength(c1,c2);
+		if( len == null )
+			throw "Invalid header";
+		var data = socket.input.read(len - 3);
+		if( socket.input.readChar() != 0 )
 			throw "Invalid message";
 		return data;
 	}

+ 4 - 6
std/neko/net/ThreadRemotingServer.hx

@@ -48,10 +48,9 @@ class ThreadRemotingServer extends ThreadServer<haxe.remoting.SocketConnection,S
 		return cnx;
 	}
 
-	public override function readClientMessage( cnx, buf : String, pos : Int, len : Int ) {
-		var c1 = haxe.remoting.SocketProtocol.decodeChar(buf.charCodeAt(pos));
-		var c2 = haxe.remoting.SocketProtocol.decodeChar(buf.charCodeAt(pos+1));
-		if( c1 == null || c2 == null ) {
+	public override function readClientMessage( cnx : haxe.remoting.SocketConnection, buf : String, pos : Int, len : Int ) {
+		var msgLen = cnx.getProtocol().messageLength(buf.charCodeAt(pos),buf.charCodeAt(pos+1));
+		if( msgLen == null ) {
 			if( buf.charCodeAt(pos) != 60 )
 				throw "Invalid remoting message '"+buf.substr(pos,len)+"'";
 			// XML handling
@@ -63,7 +62,6 @@ class ThreadRemotingServer extends ThreadServer<haxe.remoting.SocketConnection,S
 				bytes : p - pos + 1,
 			};
 		}
-		var msgLen = (c1 << 6) | c2;
 		if( len < msgLen )
 			return null;
 		if( buf.charCodeAt(pos + msgLen-1) != 0 )
@@ -88,7 +86,7 @@ class ThreadRemotingServer extends ThreadServer<haxe.remoting.SocketConnection,S
 		} catch( e : Dynamic ) {
 			if( !Std.is(e,neko.io.Eof) && !Std.is(e,neko.io.Error) )
 				logError(e);
-			stopClient(cnx.getSocket());
+			stopClient(cnx.getProtocol().socket);
 		}
 	}