Browse Source

[cs] First implementation of sys.db using ADO.NET

Cauê Waneck 10 years ago
parent
commit
d88b79b448

+ 48 - 0
std/cs/_std/sys/db/Sqlite.hx

@@ -0,0 +1,48 @@
+package sys.db;
+
+class Sqlite
+{
+	static var type:Class<cs.system.data.IDbConnection>;
+	/**
+		Opens a new SQLite connection on the specified path.
+		Note that you will need a SQLite ADO.NET Provider (see http://www.mono-project.com/docs/database-access/providers/sqlite/).
+		Also note that this will try to open an assembly named `Mono.Data.Sqlite` if it wasn't loaded yet.
+	**/
+	public static function open(file:String):sys.db.Connection
+	{
+		var cnxString = 'Data Source=$file';
+		if (type == null)
+		{
+			var t = null;
+			var assemblies = cs.system.AppDomain.CurrentDomain.GetAssemblies();
+			for (i in 0...assemblies.Length)
+			{
+				var a = assemblies[i];
+				t = a.GetType('Mono.Data.Sqlite.SqliteConnection');
+				if (t == null)
+					t = a.GetType('System.Data.SQLite.SQLiteConnection');
+				if (t != null)
+				{
+					break;
+				}
+			}
+
+			if (t == null)
+			{
+				var asm = cs.system.reflection.Assembly.Load('Mono.Data.Sqlite');
+				t = asm.GetType('Mono.Data.Sqlite.SqliteConnection');
+			}
+
+			if (t != null)
+				type = cast cs.Lib.fromNativeType(t);
+		}
+
+		if (type == null)
+		{
+			throw "No ADO.NET SQLite provider was found!";
+		}
+		var ret = Type.createInstance(type,[cnxString]);
+		ret.Open();
+		return cs.db.AdoNet.create(ret,'SQLite');
+	}
+}

+ 354 - 0
std/cs/db/AdoNet.hx

@@ -0,0 +1,354 @@
+package cs.db;
+import sys.db.*;
+import cs.system.data.*;
+
+class AdoNet
+{
+	public static function create(cnx:IDbConnection, dbName:String):Connection
+	{
+		return new AdoConnection(cnx,dbName);
+	}
+}
+
+private class AdoConnection implements Connection
+{
+	private static var ids = 0;
+	private var id:Int;
+
+	private var cnx:IDbConnection;
+	//escape handling
+	private var escapeRegex:EReg;
+	private var escapes:Array<IDbDataParameter>;
+	private var name:String;
+	private var command:IDbCommand;
+	private var transaction:IDbTransaction;
+
+	public function new(cnx,name:String)
+	{
+		this.id = cs.system.threading.Interlocked.Increment(ids);
+		this.cnx = cnx;
+		this.name = name;
+		this.escapes = [];
+		this.command = cnx.CreateCommand();
+		this.escapeRegex = ~/@@HX_ESCAPE(\d+)_(\d+)@@/;
+	}
+
+	public function close() : Void
+	{
+		cnx.Close();
+	}
+
+	public function escape(s:String):String
+	{
+		var param = command.CreateParameter();
+		var name = "@@HX_ESCAPE" + id + "_" +escapes.push(param) + "@@";
+		param.ParameterName = name;
+		param.Value = s;
+		return name;
+	}
+
+	public function quote(s:String):String
+	{
+		var param = command.CreateParameter();
+		var name = "@@HX_ESCAPE" + id + "_" +escapes.push(param) + "@@";
+		param.ParameterName = name;
+		param.Value = s;
+		return name;
+	}
+
+	public function addValue(s:StringBuf, v:Dynamic)
+	{
+		if (Std.is(v, Date))
+		{
+			v = Std.string(v);
+		} else if (Std.is(v, haxe.io.Bytes)) {
+			var bt:haxe.io.Bytes = v;
+			v = bt.getData();
+		}
+		var param = command.CreateParameter();
+		var name = "@@HX_ESCAPE" + id + "_" +escapes.push(param) + "@@";
+		param.ParameterName = name;
+		param.Value = v;
+		s.add(name);
+	}
+
+	public function lastInsertId():Int
+	{
+		var ret = cnx.CreateCommand();
+		ret.CommandText = 'SELECT @@IDENTITY';
+		ret.CommandType = CommandType.Text;
+		var r = cast ret.ExecuteScalar();
+		ret.Dispose();
+
+		return r;
+	}
+
+	public function dbName() : String
+	{
+		return name;
+	}
+
+	public function startTransaction() : Void
+	{
+		if (this.transaction != null)
+			throw 'Transaction already active';
+		this.transaction = cnx.BeginTransaction();
+	}
+
+	public function commit() : Void
+	{
+		if (this.transaction == null)
+			throw 'No transaction was initiated';
+		this.transaction.Commit();
+	}
+
+	public function rollback() : Void
+	{
+		if (this.transaction == null)
+			throw 'No transaction was initiated';
+		this.transaction.Rollback();
+	}
+
+	private static function getFirstStatement(s:String)
+	{
+		var buf = new StringBuf();
+		var hasData = false;
+		var chr = 0,
+				i = 0;
+		inline function getch() return chr = StringTools.fastCodeAt(s,i++);
+		while ( !StringTools.isEof(getch()) )
+		{
+			inline function peek() { var c = StringTools.fastCodeAt(s,i); if (StringTools.isEof(c)) break; return c; }
+			switch(chr)
+			{
+				case ' '.code | '\t'.code | '\n'.code:
+					if (hasData)
+						return buf.toString();
+				case '-'.code if (peek() == '-'.code):
+					if (hasData)
+						return buf.toString();
+					while (!StringTools.isEof(getch()))
+					{
+						if (chr == '\n'.code) break;
+					}
+				case '#'.code:
+					if (hasData)
+						return buf.toString();
+					while (!StringTools.isEof(getch()))
+					{
+						if (chr == '\n'.code) break;
+					}
+				case '/'.code if (peek() == '*'.code):
+					i++;
+					if (hasData)
+						return buf.toString();
+					while (!StringTools.isEof(getch()))
+					{
+						if (chr == '*'.code && peek() == '/'.code)
+						{
+							i++;
+							break;
+						}
+					}
+				case _:
+					hasData = true;
+					buf.addChar(chr);
+			}
+		}
+		return buf.toString();
+	}
+
+	public function request( s : String ) : ResultSet
+	{
+		//cycle through the request string, adding any @@HX_ESCAPE@@ reference to the sentArray
+		var ret:ResultSet = null;
+		var r = escapeRegex;
+		var myid = id + "", escapes = escapes, elen = escapes.length;
+		var cmd = this.command;
+		try
+		{
+			var check = s;
+			//check escape names
+			while (r.match(check))
+			{
+				var id = r.matched(1);
+				if (id != myid) throw "Request quotes are only valid for one single request; They can't be cached.";
+
+				var eid = Std.parseInt(r.matched(2));
+				if (eid == null || eid > elen)
+					throw "Invalid request quote ID " + eid;
+				check = r.matchedRight();
+			}
+
+			trace(s);
+			for (param in escapes)
+			{
+				trace(param.ParameterName);
+				cmd.Parameters.Add(param);
+			}
+			trace(cmd);
+			cmd.CommandText = s;
+
+			var stmt = getFirstStatement(s).toLowerCase();
+			if (stmt == 'insert')
+			{
+				ret = new AdoResultSet( cmd.ExecuteReader() );
+			} else {
+				cmd.ExecuteNonQuery();
+				ret = EmptyResultSet.empty;
+			}
+
+			if (escapes.length != 0)
+				this.escapes = [];
+			this.id = cs.system.threading.Interlocked.Increment(ids);
+			cmd.Dispose();
+			this.command = cnx.CreateCommand();
+			return ret;
+		}
+		catch(e:Dynamic)
+		{
+			trace(escapes.length);
+			if (escapes.length != 0)
+				this.escapes = [];
+			this.id = cs.system.threading.Interlocked.Increment(ids);
+			try { cmd.Dispose(); } catch(e:Dynamic) {}
+			this.command = cnx.CreateCommand();
+			cs.Lib.rethrow(e);
+		}
+		return null;
+	}
+}
+
+private class AdoResultSet implements ResultSet
+{
+	public var length(get,null) : Int;
+	public var nfields(get,null) : Int;
+
+	private var reader:IDataReader;
+	private var readFirst = false;
+	private var names:Array<String>;
+	private var types:Array<Class<Dynamic>>;
+
+	public function new(reader)
+	{
+		this.reader = reader;
+		this.names = [ for (i in 0...reader.FieldCount) reader.GetName(i) ];
+		this.types = [ for (i in 0...names.length) cs.Lib.fromNativeType(reader.GetFieldType(i)) ];
+	}
+
+	private function get_length()
+	{
+		return reader.Depth;
+	}
+
+	private function get_nfields()
+	{
+		return names.length;
+	}
+
+	public function hasNext() : Bool
+	{
+		if (!readFirst)
+		{
+			readFirst = true;
+			return reader.Depth > 0;
+		} else {
+			return reader.NextResult();
+		}
+	}
+
+	public function next() : Dynamic
+	{
+		var ret = {}, names = names, types = types;
+		for (i in 0...names.length)
+		{
+			var name = names[i], t = types[i], val:Dynamic = null;
+			trace(name,t);
+			if (t == cs.system.Single)
+			{
+				val = reader.GetDouble(i);
+			} else if (t == cs.system.DateTime || t == cs.system.TimeSpan) {
+				// if (dbName == "SQLite")
+				// {
+				// 	var str = rs.getString(i+1);
+				// 	if (str != null)
+				// 	{
+				// 		var d:Date = Date.fromString(str);
+				// 		val = d;
+				// 	}
+				// } else {
+					var d = reader.GetDateTime(i);
+					if (d != null)
+						val = Date.fromTime(cast(d.Ticks,Float) / cast(cs.system.TimeSpan.TicksPerMillisecond,Float));
+				// }
+			// } else if (t == Types.LONGVARBINARY || t == Types.VARBINARY || t == Types.BINARY || t == Types.BLOB) {
+			// 	var b = rs.getBytes(i+1);
+			// 	if (b != null)
+			// 		val = Bytes.ofData(b);
+			} else {
+				val = reader.GetValue(i);
+				// untyped __java__("val = rs.getObject(i + 1)"); //type parameter constraint + overloads
+			}
+			Reflect.setField(ret, name, val);
+		}
+		return ret;
+	}
+
+	public function results() : List<Dynamic>
+	{
+		var l = new List();
+		while (hasNext())
+			l.add(next());
+		return l;
+	}
+
+	public function getResult( n : Int ) : String
+	{
+		return reader.GetString(n);
+	}
+
+	public function getIntResult( n : Int ) : Int
+	{
+		return reader.GetInt32(n);
+	}
+
+	public function getFloatResult( n : Int ) : Float
+	{
+		return reader.GetDouble(n);
+	}
+
+	public function getFieldsNames() : Null<Array<String>>
+	{
+		return names;
+	}
+
+}
+
+private class EmptyResultSet implements ResultSet
+{
+	public static var empty = new EmptyResultSet();
+	public function new()
+	{
+	}
+
+	public var length(get,null) : Int;
+	public var nfields(get,null) : Int;
+
+	private function get_length()
+	{
+		return 0;
+	}
+
+	private function get_nfields()
+	{
+		return 0;
+	}
+
+	public function hasNext() : Bool return false;
+	public function next() : Dynamic return null;
+	public function results() : List<Dynamic> return new List();
+	public function getResult( n : Int ) : String return null;
+	public function getIntResult( n : Int ) : Int return 0;
+	public function getFloatResult( n : Int ) : Float return 0;
+	public function getFieldsNames() : Null<Array<String>> return null;
+}

+ 3 - 0
tests/unit/compile-cs.hxml

@@ -7,3 +7,6 @@ compile-each.hxml
 -main unit.Test
 -cs bin/cs
 -net-lib native_cs/bin/native_cs.dll
+-net-lib cs_drivers/System.Data.dll
+-net-lib cs_drivers/System.Xml.dll
+-net-lib cs_drivers/Mono.Data.Sqlite.dll

BIN
tests/unit/cs_drivers/Mono.Data.Sqlite.dll


BIN
tests/unit/cs_drivers/System.Data.dll


BIN
tests/unit/cs_drivers/System.Xml.dll


+ 2 - 2
tests/unit/src/unit/Test.hx

@@ -300,8 +300,8 @@ class Test #if swf_mark implements mt.Protect #end {
 			//new TestRemoting(),
 		];
 		// SPOD tests
-		#if ( (neko || php || java || cpp) && !macro && !interp)
-		#if !cpp
+		#if ( (neko || php || java || cpp || cs) && !macro && !interp)
+		#if !(cpp || cs)
 		if (Sys.getEnv("CI") != null && Sys.systemName() == "Linux")
 		{
 			classes.push(new TestSpod(sys.db.Mysql.connect({