瀏覽代碼

[spod] Null fixes:

 * Test and fix types when passing null as a parameter - both directly and indirectly
 * Fix C# handling of null types on the driver
 * Minor cleanup on gencommon/gencs

Closes #4932
Closes #4931
Cauê Waneck 9 年之前
父節點
當前提交
1d396ee3cb

+ 0 - 15
src/generators/gencommon.ml

@@ -4108,19 +4108,6 @@ struct
 
 	let priority = max_dep -. 20.
 
-	let rec deep_follow gen t = match run_follow gen t with
-		| TInst(c,tl) ->
-			TInst(c,List.map (deep_follow gen) tl)
-		| TEnum(e,tl) ->
-			TEnum(e,List.map (deep_follow gen) tl)
-		| TAbstract(a,tl) ->
-			TAbstract(a,List.map (deep_follow gen) tl)
-		| TType(t,tl) ->
-			TType(t,List.map (deep_follow gen) tl)
-		| TFun(args,ret) ->
-			TFun(List.map (fun (n,o,t) -> n,o,deep_follow gen t) args, deep_follow gen ret)
-		| t -> t
-
 	(* this function will receive the original function argument, the applied function argument and the original function parameters. *)
 	(* from this info, it will infer the applied tparams for the function *)
 	(* this function is used by CastDetection module *)
@@ -4136,8 +4123,6 @@ struct
 
 			(try
 				List.iter2 (fun a o ->
-					let o = deep_follow gen o in
-					let a = deep_follow gen a in
 					unify a o
 					(* type_eq EqStrict a o *)
 				) applied original

+ 1 - 1
src/generators/gencs.ml

@@ -909,7 +909,7 @@ let configure gen =
 				else
 					(match real_type t with
 						| TInst( { cl_kind = KTypeParameter _ }, _ ) -> TInst(null_t, [t])
-						| _ when is_cs_basic_type t -> TInst(null_t, [t])
+						| t when is_cs_basic_type t -> TInst(null_t, [t])
 						| _ -> real_type t)
 			| TAbstract _
 			| TType _ -> t

+ 3 - 2
std/cs/db/AdoNet.hx

@@ -287,8 +287,9 @@ private class AdoResultSet implements ResultSet
 		for (i in 0...names.length)
 		{
 			var name = names[i], t = types[i], val:Dynamic = null;
-			if (t == cs.system.Single)
-			{
+			if (reader.IsDBNull(i)) {
+				val = null;
+			} else if (t == cs.system.Single) {
 				val = reader.GetDouble(i);
 			} else if (t == cs.system.DateTime || t == cs.system.TimeSpan) {
 				var d = reader.GetDateTime(i);

+ 9 - 0
std/sys/db/Manager.hx

@@ -525,6 +525,11 @@ class Manager<T : Object> {
 	}
 
 	public static function nullCompare( a : String, b : String, eq : Bool ) {
+		if (a == null || a == 'NULL') {
+			return eq ? '$b IS NULL' : '$b IS NOT NULL';
+		} else if (b == null || b == 'NULL') {
+			return eq ? '$a IS NULL' : '$a IS NOT NULL';
+		}
 		// we can't use a null-safe operator here
 		if( cnx.dbName() != "MySQL" )
 			return a + (eq ? " = " : " != ") + b;
@@ -718,6 +723,10 @@ class Manager<T : Object> {
 	/* ---------------------------- QUOTES -------------------------- */
 
 	public static function quoteAny( v : Dynamic ) {
+		if (v == null) {
+			return 'NULL';
+		}
+
 		var s = new StringBuf();
 		cnx.addValue(s, v);
 		return s.toString();

+ 22 - 14
std/sys/db/RecordMacros.hx

@@ -488,7 +488,7 @@ class RecordMacros {
 		return { expr : EBinop(OpAdd, sql, makeString(s,sql.pos)), pos : sql.pos };
 	}
 
-	function sqlQuoteValue( v : Expr, t : RecordType ) {
+	function sqlQuoteValue( v : Expr, t : RecordType, isNull : Bool ) {
 		switch( v.expr ) {
 		case EConst(c):
 			switch( c ) {
@@ -505,11 +505,11 @@ class RecordMacros {
 			}
 		default:
 		}
-		return { expr : ECall( { expr : EField(manager, "quoteAny"), pos : v.pos }, [ensureType(v,t)]), pos : v.pos }
+		return { expr : ECall( { expr : EField(manager, "quoteAny"), pos : v.pos }, [ensureType(v,t,isNull)]), pos : v.pos }
 	}
 
-	inline function sqlAddValue( sql : Expr, v : Expr, t : RecordType ) {
-		return { expr : EBinop(OpAdd, sql, sqlQuoteValue(v,t)), pos : sql.pos };
+	inline function sqlAddValue( sql : Expr, v : Expr, t : RecordType, isNull : Bool ) {
+		return { expr : EBinop(OpAdd, sql, sqlQuoteValue(v,t, isNull)), pos : sql.pos };
 	}
 
 	function unifyClass( t : RecordType ) {
@@ -639,8 +639,12 @@ class RecordMacros {
 							var epath = e.split('.');
 							var ename = epath.pop();
 							var etype = TPath({ name:ename, pack:epath });
-							var expr = macro std.Type.enumIndex( @:pos(e2.pos) ( $e2 : $etype ) ); //make sure we have the correct type
-							return { sql: makeOp(eq?" = ":" != ", r1.sql, expr, pos), t : DBool, n : r1.n };
+							if (r1.n) {
+								return { sql: macro $manager.nullCompare(${r1.sql}, { var tmp = @:pos(e2.pos) (${e2} : $etype); tmp == null ? null : std.Type.enumIndex(tmp) + ''; }, $v{eq}), t : DBool, n: true };
+							} else {
+								var expr = macro { @:pos(e2.pos) var tmp : $etype = $e2; (tmp == null ? null : std.Type.enumIndex(tmp)); };
+								return { sql: makeOp(eq?" = ":" != ", r1.sql, expr, pos), t : DBool, n : r1.n };
+							}
 						}
 					default:
 					}
@@ -673,7 +677,7 @@ class RecordMacros {
 		var t = typeof(cond);
 		isNull = false;
 		var d = try makeType(t) catch( e : String ) try makeType(follow(t)) catch( e : String ) error("Unsupported type " + Std.string(t), cond.pos);
-		return { sql : sqlQuoteValue(cond, d), t : d, n : isNull };
+		return { sql : sqlQuoteValue(cond, d, isNull), t : d, n : isNull };
 	}
 
 	function getField( f : { field : String, expr : Expr } ) {
@@ -691,7 +695,7 @@ class RecordMacros {
 					var m = getManager(typeof(mpath),p);
 					var getid = { expr : ECall( { expr : EField(mpath, "unsafeGetId"), pos : p }, [f.expr]), pos : p };
 					f.field = r.key;
-					f.expr = ensureType(getid, m.inf.hfields.get(m.inf.key[0]).t);
+					f.expr = ensureType(getid, m.inf.hfields.get(m.inf.key[0]).t, r.isNull);
 					return inf.hfields.get(r.key);
 				}
 			error("No database field '" + f.field+"'", f.expr.pos);
@@ -714,7 +718,7 @@ class RecordMacros {
 				else
 					sql = sqlAddString(sql, " AND ");
 				sql = sqlAddString(sql, quoteField(fi.name) + (fi.isNull ? " <=> " : " = "));
-				sql = sqlAddValue(sql, f.expr, fi.t);
+				sql = sqlAddValue(sql, f.expr, fi.t, fi.isNull);
 				if( fields.exists(fi.name) )
 					error("Duplicate field " + fi.name, p);
 				else
@@ -807,7 +811,7 @@ class RecordMacros {
 			switch( c ) {
 			case CInt(s): return { sql : makeString(s, p), t : DInt, n : false };
 			case CFloat(s): return { sql : makeString(s, p), t : DFloat, n : false };
-			case CString(s): return { sql : sqlQuoteValue(cond, DText), t : DString(s.length), n : false };
+			case CString(s): return { sql : sqlQuoteValue(cond, DText, false), t : DString(s.length), n : false };
 			case CRegexp(_): error("Unsupported", p);
 			case CIdent(n):
 				if( n.charCodeAt(0) == "$".code ) {
@@ -911,8 +915,12 @@ class RecordMacros {
 		return null;
 	}
 
-	function ensureType( e : Expr, rt : RecordType ) {
-		return { expr : ECheckType(e, convertType(rt)), pos : e.pos };
+	function ensureType( e : Expr, rt : RecordType, isNull : Bool ) {
+		var t = convertType(rt);
+		if (isNull) {
+			t = macro : Null<$t>;
+		}
+		return { expr : ECheckType(e, t), pos : e.pos };
 	}
 
 	function checkKeys( econd : Expr ) {
@@ -928,14 +936,14 @@ class RecordMacros {
 					else
 						error("Field " + f.field + " is not part of table key (" + inf.key.join(",") + ")", p);
 				}
-				f.expr = ensureType(f.expr, fi.t);
+				f.expr = ensureType(f.expr, fi.t, fi.isNull);
 			}
 			return econd;
 		default:
 			if( inf.key.length > 1 )
 				error("You can't use a single value on a table with multiple keys (" + inf.key.join(",") + ")", p);
 			var fi = inf.hfields.get(inf.key[0]);
-			return ensureType(econd, fi.t);
+			return ensureType(econd, fi.t, fi.isNull);
 		}
 	}
 

+ 1 - 1
tests/unit/compile-cs-travis.hxml

@@ -1,7 +1,7 @@
 compile-cs.hxml
 
 -net-lib cs_drivers/System.Data.dll
--net-lib cs_drivers/System.Xml.dll
+-net-lib cs_drivers/System.Xml.dll@std
 -net-lib cs_drivers/Mono.Data.Sqlite.dll
 
 -D travis

+ 1 - 0
tests/unit/compile-cs.hxml

@@ -7,3 +7,4 @@ compile-each.hxml
 -main unit.Test
 -cs bin/cs
 -net-lib native_cs/bin/native_cs.dll
+-D dump

+ 28 - 18
tests/unit/src/unit/MySpodClass.hx

@@ -4,33 +4,43 @@ import sys.db.Types;
 
 @:keep class MySpodClass extends Object
 {
-  public var theId:SId;
-  public var int:SInt;
-  public var double:SFloat;
-  public var boolean:SBool;
-  public var string:SString<255>;
-  public var date:SDateTime;
-  public var binary:SBinary;
+	public var theId:SId;
+	public var int:SInt;
+	public var double:SFloat;
+	public var boolean:SBool;
+	public var string:SString<255>;
+	public var date:SDateTime;
+	public var binary:SBinary;
 	public var abstractType:AbstractSpodTest<String>;
 
-  public var nullInt:SNull<Int>;
-  public var enumFlags:SFlags<SpodEnum>;
+	public var nullInt:SNull<Int>;
+	public var enumFlags:SFlags<SpodEnum>;
 
-  @:relation(rid) public var relation:OtherSpodClass;
-  @:relation(rnid) public var relationNullable:Null<OtherSpodClass>;
+	@:relation(rid) public var relation:OtherSpodClass;
+	@:relation(rnid) public var relationNullable:Null<OtherSpodClass>;
 	@:relation(spid) public var next:Null<MySpodClass>;
 
-  public var data:SData<Array<ComplexClass>>;
-  public var anEnum:SEnum<SpodEnum>;
+	public var data:SData<Array<ComplexClass>>;
+	public var anEnum:SEnum<SpodEnum>;
 }
 
 @:keep class NullableSpodClass extends Object
 {
 	public var theId:SId;
-  @:relation(rnid) public var relationNullable:Null<OtherSpodClass>;
-  public var data:Null<SData<Array<ComplexClass>>>;
-	public var abstractType:Null<AbstractSpodTest<String>>;
+	@:relation(rnid) public var relationNullable:Null<OtherSpodClass>;
+	public var data:Null<SData<Array<ComplexClass>>>;
 	public var anEnum:Null<SEnum<SpodEnum>>;
+
+	public var int:SNull<SInt>;
+	public var double:SNull<SFloat>;
+	public var boolean:SNull<SBool>;
+	public var string:SNull<SString<255>>;
+	public var date:SNull<SDateTime>;
+	public var binary:SNull<SBinary>;
+	public var abstractType:SNull<AbstractSpodTest<String>>;
+
+	public var nullInt:SNull<Int>;
+	public var enumFlags:SNull<SFlags<SpodEnum>>;
 }
 
 @:keep class ComplexClass
@@ -71,7 +81,7 @@ abstract AbstractSpodTest<A>(A) from A
 }
 
 @:id(name)
-@:keep class ClassWithStringId extends Object
+	@:keep class ClassWithStringId extends Object
 {
 	public var name:SString<255>;
 	public var field:SInt;
@@ -88,7 +98,7 @@ abstract AbstractSpodTest<A>(A) from A
 @:keep @:skip class BaseIssueC3828 extends sys.db.Object {
 	public var id : SInt;
 	@:relation(ruid)
-	public var refUser : SNull<IssueC3828>;
+		public var refUser : SNull<IssueC3828>;
 }
 
 @:keep class IssueC3828 extends BaseIssueC3828 {

+ 87 - 0
tests/unit/src/unit/TestSpod.hx

@@ -1,5 +1,7 @@
 package unit;
 import sys.db.*;
+import sys.db.Object;
+import sys.db.Types;
 import haxe.io.Bytes;
 import haxe.EnumFlags;
 import sys.db.Connection;
@@ -62,6 +64,91 @@ class TestSpod extends Test
 		return scls;
 	}
 
+	function getDefaultNull() {
+		var scls = new NullableSpodClass();
+		scls.int = 1;
+		scls.double = 2.0;
+		scls.boolean = true;
+		scls.string = "some string";
+		scls.date = new Date(2012, 7, 30, 0, 0, 0);
+		scls.abstractType = "other string";
+
+		var bytes = Bytes.ofString("\x01\n\r'\x02");
+		scls.binary = bytes;
+		scls.enumFlags = EnumFlags.ofInt(0);
+		scls.enumFlags.set(FirstValue);
+		scls.enumFlags.set(ThirdValue);
+
+		scls.data = [new ComplexClass( { name:"test", array:["this", "is", "a", "test"] } )];
+		scls.anEnum = SecondValue;
+		return scls;
+	}
+
+	public function testNull() {
+		setManager();
+		var n1 = getDefaultNull();
+		n1.insert();
+		var n2 = new NullableSpodClass();
+		n2.insert();
+		var id = n2.theId;
+
+		n1 = null; n2 = null;
+		Manager.cleanup();
+
+		var nullVal = getNull();
+		inline function checkReq(lst:List<NullableSpodClass>, ?nres=1, ?pos:haxe.PosInfos) {
+			eq(lst.length,nres, pos);
+			if (lst.length == 1) {
+				eq(lst.first().theId, id, pos);
+			}
+		}
+
+		checkReq(NullableSpodClass.manager.search($relationNullable == null), 2);
+		checkReq(NullableSpodClass.manager.search($data == null));
+		checkReq(NullableSpodClass.manager.search($anEnum == null));
+
+		checkReq(NullableSpodClass.manager.search($int == null));
+		checkReq(NullableSpodClass.manager.search($double == null));
+		checkReq(NullableSpodClass.manager.search($boolean == null));
+		checkReq(NullableSpodClass.manager.search($string == null));
+		checkReq(NullableSpodClass.manager.search($date == null));
+		checkReq(NullableSpodClass.manager.search($binary == null));
+		checkReq(NullableSpodClass.manager.search($abstractType == null));
+
+		checkReq(NullableSpodClass.manager.search($enumFlags == null));
+
+
+		var relationNullable:Null<OtherSpodClass> = getNull();
+		checkReq(NullableSpodClass.manager.search($relationNullable == relationNullable), 2);
+		var data:Null<Bytes> = getNull();
+		checkReq(NullableSpodClass.manager.search($data == data));
+		var anEnum:Null<SEnum<SpodEnum>> = getNull();
+		checkReq(NullableSpodClass.manager.search($anEnum == anEnum));
+
+		var int:Null<Int> = getNull();
+		checkReq(NullableSpodClass.manager.search($int == int));
+		var double:Null<Float> = getNull();
+		checkReq(NullableSpodClass.manager.search($double == double));
+		var boolean:Null<Bool> = getNull();
+		checkReq(NullableSpodClass.manager.search($boolean == boolean));
+		var string:SNull<SString<255>> = getNull();
+		checkReq(NullableSpodClass.manager.search($string == string));
+		var date:SNull<SDateTime> = getNull();
+		checkReq(NullableSpodClass.manager.search($date == date));
+		var binary:SNull<SBinary> = getNull();
+		checkReq(NullableSpodClass.manager.search($binary == binary));
+		var abstractType:SNull<String> = getNull();
+		checkReq(NullableSpodClass.manager.search($abstractType == abstractType));
+
+		for (val in NullableSpodClass.manager.all()) {
+			val.delete();
+		}
+	}
+
+	private function getNull<T>():Null<T> {
+		return null;
+	}
+
 	public function testIssue3828()
 	{
 		setManager();