Selaa lähdekoodia

[spirv] Add more visitors to LiteralTypeVisitor.

Ehsan Nasiri 6 vuotta sitten
vanhempi
commit
246f19a4c9

+ 52 - 9
tools/clang/lib/SPIRV/LiteralTypeVisitor.cpp

@@ -14,15 +14,16 @@
 namespace clang {
 namespace spirv {
 
-// -- SpirvReturn (OpReturnValue)
-// -- SpirvCompositeExtract
-// -- SpirvCompositeInsert
-// -- SpirvExtInst
-// -- SpirvImageOp
-// -- SpirvImageQuery
-// -- SpirvImageTexelPointer
-// -- SpirvSpecConstantBinaryOp
-// -- SpirvSpecConstantUnaryOp
+bool LiteralTypeVisitor::visit(SpirvFunction *fn, Phase phase) {
+  assert(fn);
+
+  // Before going through the function instructions
+  if (phase == Visitor::Phase::Init) {
+    curFnAstReturnType = fn->getAstReturnType();
+  }
+
+  return true;
+}
 
 bool LiteralTypeVisitor::isLiteralLargerThan32Bits(
     SpirvConstantInteger *constant) {
@@ -148,6 +149,8 @@ bool LiteralTypeVisitor::visit(SpirvBinaryOp *inst) {
       op == spv::Op::OpShiftLeftLogical) {
     // Base (arg1) should have the same type as result type
     updateTypeForInstruction(inst->getOperand1(), resultType);
+    // The shitf amount (arg2) cannot be a 64-bit type for a 32-bit base!
+    updateTypeForInstruction(inst->getOperand2(), resultType);
     return true;
   }
 
@@ -267,6 +270,11 @@ bool LiteralTypeVisitor::visit(SpirvStore *inst) {
     if (const auto *ptrType = type->getAs<PointerType>())
       type = ptrType->getPointeeType();
     updateTypeForInstruction(object, type);
+  } else if (pointer->hasResultType()) {
+    if (auto *ptrType = dyn_cast<HybridPointerType>(pointer->getResultType())) {
+      QualType type = ptrType->getPointeeType();
+      updateTypeForInstruction(object, type);
+    }
   }
   return true;
 }
@@ -375,5 +383,40 @@ bool LiteralTypeVisitor::visit(SpirvAccessChain *inst) {
   return true;
 }
 
+bool LiteralTypeVisitor::visit(SpirvExtInst *inst) {
+  // Result type of the instruction can provide a hint about its operands. e.g.
+  // OpExtInst %float %glsl_set Pow %double_2 %double_12
+  // should be evaluated as:
+  // OpExtInst %float %glsl_set Pow %float_2 %float_12
+  const auto resultType = inst->getAstResultType();
+  for (auto *operand : inst->getOperands())
+    updateTypeForInstruction(operand, resultType);
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvReturn *inst) {
+  if (inst->hasReturnValue()) {
+    updateTypeForInstruction(inst->getReturnValue(), curFnAstReturnType);
+  }
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvCompositeInsert *inst) {
+  const auto resultType = inst->getAstResultType();
+  updateTypeForInstruction(inst->getComposite(), resultType);
+  updateTypeForInstruction(inst->getObject(),
+                           getElementType(astContext, resultType));
+  return true;
+}
+
+bool LiteralTypeVisitor::visit(SpirvImageOp *inst) {
+  if (inst->isImageWrite() && inst->hasAstResultType()) {
+    const auto sampledType =
+        hlsl::GetHLSLResourceResultType(inst->getAstResultType());
+    updateTypeForInstruction(inst->getTexelToWrite(), sampledType);
+  }
+  return true;
+}
+
 } // end namespace spirv
 } // namespace clang

+ 16 - 1
tools/clang/lib/SPIRV/LiteralTypeVisitor.h

@@ -22,7 +22,9 @@ class LiteralTypeVisitor : public Visitor {
 public:
   LiteralTypeVisitor(const ASTContext &ctx, SpirvContext &spvCtx,
                      const SpirvCodeGenOptions &opts)
-      : Visitor(opts, spvCtx), astContext(ctx) {}
+      : Visitor(opts, spvCtx), astContext(ctx), curFnAstReturnType({}) {}
+
+  bool visit(SpirvFunction *, Phase);
 
   bool visit(SpirvVariable *);
   bool visit(SpirvAtomic *);
@@ -39,6 +41,18 @@ public:
   bool visit(SpirvComposite *);
   bool visit(SpirvCompositeExtract *);
   bool visit(SpirvAccessChain *);
+  bool visit(SpirvExtInst *);
+  bool visit(SpirvReturn *);
+  bool visit(SpirvCompositeInsert *);
+  bool visit(SpirvImageOp *);
+
+  // Note: We currently don't do anything to deduce literal types for the
+  // following instructions:
+  //
+  // SpirvImageQuery
+  // SpirvImageTexelPointer
+  // SpirvSpecConstantBinaryOp
+  // SpirvSpecConstantUnaryOp
 
   /// The "sink" visit function for all instructions.
   ///
@@ -67,6 +81,7 @@ private:
 
 private:
   const ASTContext &astContext;
+  QualType curFnAstReturnType;
 };
 
 } // end namespace spirv

+ 2 - 2
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -319,10 +319,10 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
         // LowerTypeVisitor is invoked. We should error out if we encounter a
         // literal type.
         case BuiltinType::LitInt:
-          emitError("found literal int type when lowering types", srcLoc);
+          //emitError("found literal int type when lowering types", srcLoc);
           return spvContext.getUIntType(64);
         case BuiltinType::LitFloat: {
-          emitError("found literal float type when lowering types", srcLoc);
+          //emitError("found literal float type when lowering types", srcLoc);
           return spvContext.getFloatType(64);
 
         default:

+ 3 - 17
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -8572,23 +8572,9 @@ SpirvConstant *SPIRVEmitter::getMaskForBitwidthValue(QualType type) {
   if (isScalarType(type, &elemType) || isVectorType(type, &elemType, &count)) {
     const auto bitwidth = getElementSpirvBitwidth(
         astContext, elemType, spirvOptions.enable16BitTypes);
-    SpirvConstant *mask = nullptr;
-    switch (bitwidth) {
-    case 16:
-      elemType = astContext.UnsignedShortTy;
-      mask = spvBuilder.getConstantInt(elemType, llvm::APInt(16, bitwidth - 1));
-      break;
-    case 32:
-      elemType = astContext.UnsignedIntTy;
-      mask = spvBuilder.getConstantInt(elemType, llvm::APInt(32, bitwidth - 1));
-      break;
-    case 64:
-      elemType = astContext.UnsignedLongLongTy;
-      mask = spvBuilder.getConstantInt(elemType, llvm::APInt(64, bitwidth - 1));
-      break;
-    default:
-      assert(false && "this method only supports 16-, 32-, and 64-bit types");
-    }
+    SpirvConstant *mask = spvBuilder.getConstantInt(
+        elemType,
+        llvm::APInt(bitwidth, bitwidth - 1, elemType->isSignedIntegerType()));
 
     if (count == 1)
       return mask;

+ 3 - 3
tools/clang/test/CodeGenSPIRV/binary-op.bitwise-assign.shift-left.hlsl

@@ -1,7 +1,7 @@
 // Run: %dxc -T ps_6_2 -E main -enable-16bit-types
 
 // CHECK: [[v2c31:%\d+]] = OpConstantComposite %v2uint %uint_31 %uint_31
-// CHECK: [[v3c63:%\d+]] = OpConstantComposite %v3ulong %ulong_63 %ulong_63 %ulong_63
+// CHECK: [[v3c63:%\d+]] = OpConstantComposite %v3long %long_63 %long_63 %long_63
 // CHECK: [[v4c15:%\d+]] = OpConstantComposite %v4ushort %ushort_15 %ushort_15 %ushort_15 %ushort_15
 void main() {
     int       a, b;
@@ -14,7 +14,7 @@ void main() {
     uint16_t4 p, q;
 
 // CHECK:        [[b:%\d+]] = OpLoad %int %b
-// CHECK:      [[rhs:%\d+]] = OpBitwiseAnd %int [[b]] %uint_31
+// CHECK:      [[rhs:%\d+]] = OpBitwiseAnd %int [[b]] %int_31
 // CHECK-NEXT:                OpShiftLeftLogical %int {{%\d+}} [[rhs]]
     a <<= b;
 
@@ -34,7 +34,7 @@ void main() {
     j <<= k;
 
 // CHECK:        [[n:%\d+]] = OpLoad %short %n
-// CHECK:      [[rhs:%\d+]] = OpBitwiseAnd %short [[n]] %ushort_15
+// CHECK:      [[rhs:%\d+]] = OpBitwiseAnd %short [[n]] %short_15
 // CHECK-NEXT:                OpShiftLeftLogical %short {{%\d+}} [[rhs]]
     m <<= n;
 

+ 3 - 3
tools/clang/test/CodeGenSPIRV/binary-op.bitwise-assign.shift-right.hlsl

@@ -1,7 +1,7 @@
 // Run: %dxc -T ps_6_2 -E main -enable-16bit-types
 
 // CHECK: [[v2c31:%\d+]] = OpConstantComposite %v2uint %uint_31 %uint_31
-// CHECK: [[v3c63:%\d+]] = OpConstantComposite %v3ulong %ulong_63 %ulong_63 %ulong_63
+// CHECK: [[v3c63:%\d+]] = OpConstantComposite %v3long %long_63 %long_63 %long_63
 // CHECK: [[v4c15:%\d+]] = OpConstantComposite %v4ushort %ushort_15 %ushort_15 %ushort_15 %ushort_15
 void main() {
     int       a, b;
@@ -14,7 +14,7 @@ void main() {
     uint16_t4 p, q;
 
 // CHECK:        [[b:%\d+]] = OpLoad %int %b
-// CHECK:      [[rhs:%\d+]] = OpBitwiseAnd %int [[b]] %uint_31
+// CHECK:      [[rhs:%\d+]] = OpBitwiseAnd %int [[b]] %int_31
 // CHECK-NEXT:                OpShiftRightArithmetic %int {{%\d+}} [[rhs]]
     a >>= b;
 
@@ -34,7 +34,7 @@ void main() {
     j >>= k;
 
 // CHECK:        [[n:%\d+]] = OpLoad %short %n
-// CHECK:      [[rhs:%\d+]] = OpBitwiseAnd %short [[n]] %ushort_15
+// CHECK:      [[rhs:%\d+]] = OpBitwiseAnd %short [[n]] %short_15
 // CHECK-NEXT:                OpShiftRightArithmetic %short {{%\d+}} [[rhs]]
     m >>= n;
 

+ 3 - 3
tools/clang/test/CodeGenSPIRV/binary-op.bitwise.shift-left.hlsl

@@ -1,7 +1,7 @@
 // Run: %dxc -T ps_6_2 -E main -enable-16bit-types
 
 // CHECK: [[v2c31:%\d+]] = OpConstantComposite %v2uint %uint_31 %uint_31
-// CHECK: [[v3c63:%\d+]] = OpConstantComposite %v3ulong %ulong_63 %ulong_63 %ulong_63
+// CHECK: [[v3c63:%\d+]] = OpConstantComposite %v3long %long_63 %long_63 %long_63
 // CHECK: [[v4c15:%\d+]] = OpConstantComposite %v4ushort %ushort_15 %ushort_15 %ushort_15 %ushort_15
 void main() {
     int       a, b, c;
@@ -14,7 +14,7 @@ void main() {
     uint16_t4 p, q, r;
 
 // CHECK:        [[b:%\d+]] = OpLoad %int %b
-// CHECK-NEXT: [[rhs:%\d+]] = OpBitwiseAnd %int [[b]] %uint_31
+// CHECK-NEXT: [[rhs:%\d+]] = OpBitwiseAnd %int [[b]] %int_31
 // CHECK-NEXT:                OpShiftLeftLogical %int {{%\d+}} [[rhs]]
     c = a << b;
 
@@ -34,7 +34,7 @@ void main() {
     l = j << k;
 
 // CHECK:        [[n:%\d+]] = OpLoad %short %n
-// CHECK-NEXT: [[rhs:%\d+]] = OpBitwiseAnd %short [[n]] %ushort_15
+// CHECK-NEXT: [[rhs:%\d+]] = OpBitwiseAnd %short [[n]] %short_15
 // CHECK-NEXT:                OpShiftLeftLogical %short {{%\d+}} [[rhs]]
     o = m << n;
 

+ 3 - 3
tools/clang/test/CodeGenSPIRV/binary-op.bitwise.shift-right.hlsl

@@ -1,7 +1,7 @@
 // Run: %dxc -T ps_6_2 -E main -enable-16bit-types
 
 // CHECK: [[v2c31:%\d+]] = OpConstantComposite %v2uint %uint_31 %uint_31
-// CHECK: [[v3c63:%\d+]] = OpConstantComposite %v3ulong %ulong_63 %ulong_63 %ulong_63
+// CHECK: [[v3c63:%\d+]] = OpConstantComposite %v3long %long_63 %long_63 %long_63
 // CHECK: [[v4c15:%\d+]] = OpConstantComposite %v4ushort %ushort_15 %ushort_15 %ushort_15 %ushort_15
 void main() {
     int       a, b, c;
@@ -14,7 +14,7 @@ void main() {
     uint16_t4 p, q, r;
 
 // CHECK:        [[b:%\d+]] = OpLoad %int %b
-// CHECK-NEXT: [[rhs:%\d+]] = OpBitwiseAnd %int [[b]] %uint_31
+// CHECK-NEXT: [[rhs:%\d+]] = OpBitwiseAnd %int [[b]] %int_31
 // CHECK-NEXT:                OpShiftRightArithmetic %int {{%\d+}} [[rhs]]
     c = a >> b;
 
@@ -34,7 +34,7 @@ void main() {
     l = j >> k;
 
 // CHECK:        [[n:%\d+]] = OpLoad %short %n
-// CHECK-NEXT: [[rhs:%\d+]] = OpBitwiseAnd %short [[n]] %ushort_15
+// CHECK-NEXT: [[rhs:%\d+]] = OpBitwiseAnd %short [[n]] %short_15
 // CHECK-NEXT:                OpShiftRightArithmetic %short {{%\d+}} [[rhs]]
     o = m >> n;
 

+ 153 - 141
tools/clang/test/CodeGenSPIRV/ternary-op.cond-op.hlsl

@@ -8,147 +8,159 @@ Texture2D gTex;
 
 uint foo() { return 1; }
 float bar() { return 3.0; }
+uint zoo();
 
 void main() {
-// CHECK-LABEL: %bb_entry = OpLabel
-
-// CHECK: %temp_var_ternary = OpVariable %_ptr_Function_mat2v3float Function
-
-    bool b0;
-    int m, n, o;
-    // Plain assign (scalar)
-// CHECK:      [[b0:%\d+]] = OpLoad %bool %b0
-// CHECK-NEXT: [[m0:%\d+]] = OpLoad %int %m
-// CHECK-NEXT: [[n0:%\d+]] = OpLoad %int %n
-// CHECK-NEXT: [[s0:%\d+]] = OpSelect %int [[b0]] [[m0]] [[n0]]
-// CHECK-NEXT: OpStore %o [[s0]]
-    o = b0 ? m : n;
-
-
-    bool1 b1;
-    bool3 b3;
-    uint1 p, q, r;
-    float3 x, y, z;
-    // Plain assign (vector)
-// CHECK-NEXT: [[b1:%\d+]] = OpLoad %bool %b1
-// CHECK-NEXT: [[p0:%\d+]] = OpLoad %uint %p
-// CHECK-NEXT: [[q0:%\d+]] = OpLoad %uint %q
-// CHECK-NEXT: [[s1:%\d+]] = OpSelect %uint [[b1]] [[p0]] [[q0]]
-// CHECK-NEXT: OpStore %r [[s1]]
-    r = b1 ? p : q;
-// CHECK-NEXT: [[b3:%\d+]] = OpLoad %v3bool %b3
-// CHECK-NEXT: [[x0:%\d+]] = OpLoad %v3float %x
-// CHECK-NEXT: [[y0:%\d+]] = OpLoad %v3float %y
-// CHECK-NEXT: [[s2:%\d+]] = OpSelect %v3float [[b3]] [[x0]] [[y0]]
-// CHECK-NEXT: OpStore %z [[s2]]
-    z = b3 ? x : y;
-
-    // Try condition with various type.
-    // Note: the SPIR-V OpSelect selection argument must be the same size as the return type.
-    int3 u, v, w;
-    bool  cond;
-    bool3 cond3;
-    float floatCond;
-    int3 int3Cond;
-
-// CHECK:      [[cond3:%\d+]] = OpLoad %v3bool %cond3
-// CHECK-NEXT:     [[u:%\d+]] = OpLoad %v3int %u
-// CHECK-NEXT:     [[v:%\d+]] = OpLoad %v3int %v
-// CHECK-NEXT:       {{%\d+}} = OpSelect %v3int [[cond3]] [[u]] [[v]]
-    w = cond3 ? u : v;
-
-// CHECK:       [[cond:%\d+]] = OpLoad %bool %cond
-// CHECK-NEXT:     [[u:%\d+]] = OpLoad %v3int %u
-// CHECK-NEXT:     [[v:%\d+]] = OpLoad %v3int %v
-// CHECK-NEXT: [[splat:%\d+]] = OpCompositeConstruct %v3bool [[cond]] [[cond]] [[cond]]
-// CHECK-NEXT:       {{%\d+}} = OpSelect %v3int [[splat]] [[u]] [[v]]
-    w = cond  ? u : v;
-
-// CHECK:      [[floatCond:%\d+]] = OpLoad %float %floatCond
-// CHECK-NEXT:  [[boolCond:%\d+]] = OpFOrdNotEqual %bool [[floatCond]] %float_0
-// CHECK-NEXT: [[bool3Cond:%\d+]] = OpCompositeConstruct %v3bool [[boolCond]] [[boolCond]] [[boolCond]]
-// CHECK-NEXT:         [[u:%\d+]] = OpLoad %v3int %u
-// CHECK-NEXT:         [[v:%\d+]] = OpLoad %v3int %v
-// CHECK-NEXT:           {{%\d+}} = OpSelect %v3int [[bool3Cond]] [[u]] [[v]]
-    w = floatCond ? u : v;
-
-// CHECK:       [[int3Cond:%\d+]] = OpLoad %v3int %int3Cond
-// CHECK-NEXT: [[bool3Cond:%\d+]] = OpINotEqual %v3bool [[int3Cond]] [[v3i0]]
-// CHECK-NEXT:         [[u:%\d+]] = OpLoad %v3int %u
-// CHECK-NEXT:         [[v:%\d+]] = OpLoad %v3int %v
-// CHECK-NEXT:           {{%\d+}} = OpSelect %v3int [[bool3Cond]] [[u]] [[v]]
-    w = int3Cond ? u : v;
-
-// Make sure literal types are handled correctly in ternary ops
-
-// CHECK: [[b_float:%\d+]] = OpSelect %float {{%\d+}} %float_1_5 %float_2_5
-// CHECK-NEXT:    {{%\d+}} = OpConvertFToS %int [[b_float]]
-    int   b = cond ? 1.5 : 2.5;
-
-// CHECK:      [[a_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
-// CHECK-NEXT:       {{%\d+}} = OpConvertSToF %float [[a_int]]
-    float a = cond ? 1 : 0;
-
-// CHECK:      [[c_long:%\d+]] = OpSelect %long {{%\d+}} %long_3000000000 %long_4000000000
-    double c = cond ? 3000000000 : 4000000000;
-
-// CHECK:      [[d_int:%\d+]] = OpSelect %uint {{%\d+}} %uint_1 %uint_0
-    uint d = cond ? 1 : 0;
-
-    float2x3 e;
-    float2x3 f;
-// CHECK:     [[cond:%\d+]] = OpLoad %bool %cond
-// CHECK-NEXT:   [[e:%\d+]] = OpLoad %mat2v3float %e
-// CHECK-NEXT:   [[f:%\d+]] = OpLoad %mat2v3float %f
-// CHECK-NEXT:                OpSelectionMerge %if_merge None
-// CHECK-NEXT:                OpBranchConditional [[cond]] %if_true %if_false
-// CHECK-NEXT:     %if_true = OpLabel
-// CHECK-NEXT:                OpStore %temp_var_ternary [[e]]
-// CHECK-NEXT:                OpBranch %if_merge
-// CHECK-NEXT:    %if_false = OpLabel
-// CHECK-NEXT:                OpStore %temp_var_ternary [[f]]
-// CHECK-NEXT:                OpBranch %if_merge
-// CHECK-NEXT:    %if_merge = OpLabel
-// CHECK-NEXT:[[temp:%\d+]] = OpLoad %mat2v3float %temp_var_ternary
-// CHECK-NEXT:                OpStore %g [[temp]]
-    float2x3 g = cond ? e : f;
-
-// CHECK:      [[inner:%\d+]] = OpSelect %uint {{%\d+}} %uint_1 %uint_2
-// CHECK-NEXT:       {{%\d+}} = OpSelect %uint {{%\d+}} %uint_9 [[inner]]
-    uint h = cond ? 9 : (cond ? 1 : 2);
-
-//CHECK:      [[i_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
-//CHECK-NEXT:       {{%\d+}} = OpINotEqual %bool [[i_int]] %int_0
-    bool i = cond ? 1 : 0;
-
-// CHECK:     [[foo:%\d+]] = OpFunctionCall %uint %foo
-// CHECKNEXT:     {{%\d+}} = OpSelect %uint {{%\d+}} %uint_3 [[foo]]
-    uint j = cond ? 3 : foo();
-
-// CHECK:          [[bar:%\d+]] = OpFunctionCall %float %bar
-// CHECK-NEXT: [[k_float:%\d+]] = OpSelect %float {{%\d+}} %float_4 [[bar]]
-// CHECK-NEXT:         {{%\d+}} = OpConvertFToU %uint [[k_float]]
-    uint k = cond ? 4 : bar();
-
-// AST looks like:
-// |-ConditionalOperator 'SamplerState'
-// | |-DeclRefExpr 'bool' lvalue Var 0x1476949e328 'cond' 'bool'
-// | |-DeclRefExpr 'SamplerState' lvalue Var 0x1476742e498 'gSS1' 'SamplerState'
-// | `-DeclRefExpr 'SamplerState' lvalue Var 0x1476742e570 'gSS2' 'SamplerState'
-
-// CHECK:      [[cond:%\d+]] = OpLoad %bool %cond
-// CHECK-NEXT: [[gSS1:%\d+]] = OpLoad %type_sampler %gSS1
-// CHECK-NEXT: [[gSS2:%\d+]] = OpLoad %type_sampler %gSS2
-// CHECK-NEXT:                 OpSelectionMerge %if_merge_0 None
-// CHECK-NEXT:                 OpBranchConditional [[cond]] %if_true_0 %if_false_0
-// CHECK-NEXT:    %if_true_0 = OpLabel
-// CHECK-NEXT:                 OpStore %temp_var_ternary_0 [[gSS1]]
-// CHECK-NEXT:                 OpBranch %if_merge_0
-// CHECK-NEXT:   %if_false_0 = OpLabel
-// CHECK-NEXT:                 OpStore %temp_var_ternary_0 [[gSS2]]
-// CHECK-NEXT:                 OpBranch %if_merge_0
-// CHECK-NEXT:   %if_merge_0 = OpLabel
-// CHECK-NEXT:   [[ss:%\d+]] = OpLoad %type_sampler %temp_var_ternary_0
-// CHECK-NEXT:      {{%\d+}} = OpSampledImage %type_sampled_image {{%\d+}} [[ss]]
-    float4 l = gTex.Sample(cond ? gSS1 : gSS2, float2(1., 2.));
+  // CHECK-LABEL: %bb_entry = OpLabel
+
+  // CHECK: %temp_var_ternary = OpVariable %_ptr_Function_mat2v3float Function
+
+  bool b0;
+  int m, n, o;
+  // Plain assign (scalar)
+  // CHECK:      [[b0:%\d+]] = OpLoad %bool %b0
+  // CHECK-NEXT: [[m0:%\d+]] = OpLoad %int %m
+  // CHECK-NEXT: [[n0:%\d+]] = OpLoad %int %n
+  // CHECK-NEXT: [[s0:%\d+]] = OpSelect %int [[b0]] [[m0]] [[n0]]
+  // CHECK-NEXT: OpStore %o [[s0]]
+  o = b0 ? m : n;
+
+  bool1 b1;
+  bool3 b3;
+  uint1 p, q, r;
+  float3 x, y, z;
+  // Plain assign (vector)
+  // CHECK-NEXT: [[b1:%\d+]] = OpLoad %bool %b1
+  // CHECK-NEXT: [[p0:%\d+]] = OpLoad %uint %p
+  // CHECK-NEXT: [[q0:%\d+]] = OpLoad %uint %q
+  // CHECK-NEXT: [[s1:%\d+]] = OpSelect %uint [[b1]] [[p0]] [[q0]]
+  // CHECK-NEXT: OpStore %r [[s1]]
+  r = b1 ? p : q;
+  // CHECK-NEXT: [[b3:%\d+]] = OpLoad %v3bool %b3
+  // CHECK-NEXT: [[x0:%\d+]] = OpLoad %v3float %x
+  // CHECK-NEXT: [[y0:%\d+]] = OpLoad %v3float %y
+  // CHECK-NEXT: [[s2:%\d+]] = OpSelect %v3float [[b3]] [[x0]] [[y0]]
+  // CHECK-NEXT: OpStore %z [[s2]]
+  z = b3 ? x : y;
+
+  // Try condition with various type.
+  // Note: the SPIR-V OpSelect selection argument must be the same size as the return type.
+  int3 u, v, w;
+  bool cond;
+  bool3 cond3;
+  float floatCond;
+  int3 int3Cond;
+
+  // CHECK:      [[cond3:%\d+]] = OpLoad %v3bool %cond3
+  // CHECK-NEXT:     [[u:%\d+]] = OpLoad %v3int %u
+  // CHECK-NEXT:     [[v:%\d+]] = OpLoad %v3int %v
+  // CHECK-NEXT:       {{%\d+}} = OpSelect %v3int [[cond3]] [[u]] [[v]]
+  w = cond3 ? u : v;
+
+  // CHECK:       [[cond:%\d+]] = OpLoad %bool %cond
+  // CHECK-NEXT:     [[u:%\d+]] = OpLoad %v3int %u
+  // CHECK-NEXT:     [[v:%\d+]] = OpLoad %v3int %v
+  // CHECK-NEXT: [[splat:%\d+]] = OpCompositeConstruct %v3bool [[cond]] [[cond]] [[cond]]
+  // CHECK-NEXT:       {{%\d+}} = OpSelect %v3int [[splat]] [[u]] [[v]]
+  w = cond ? u : v;
+
+  // CHECK:      [[floatCond:%\d+]] = OpLoad %float %floatCond
+  // CHECK-NEXT:  [[boolCond:%\d+]] = OpFOrdNotEqual %bool [[floatCond]] %float_0
+  // CHECK-NEXT: [[bool3Cond:%\d+]] = OpCompositeConstruct %v3bool [[boolCond]] [[boolCond]] [[boolCond]]
+  // CHECK-NEXT:         [[u:%\d+]] = OpLoad %v3int %u
+  // CHECK-NEXT:         [[v:%\d+]] = OpLoad %v3int %v
+  // CHECK-NEXT:           {{%\d+}} = OpSelect %v3int [[bool3Cond]] [[u]] [[v]]
+  w = floatCond ? u : v;
+
+  // CHECK:       [[int3Cond:%\d+]] = OpLoad %v3int %int3Cond
+  // CHECK-NEXT: [[bool3Cond:%\d+]] = OpINotEqual %v3bool [[int3Cond]] [[v3i0]]
+  // CHECK-NEXT:         [[u:%\d+]] = OpLoad %v3int %u
+  // CHECK-NEXT:         [[v:%\d+]] = OpLoad %v3int %v
+  // CHECK-NEXT:           {{%\d+}} = OpSelect %v3int [[bool3Cond]] [[u]] [[v]]
+  w = int3Cond ? u : v;
+
+  // Make sure literal types are handled correctly in ternary ops
+
+  // CHECK: [[b_float:%\d+]] = OpSelect %float {{%\d+}} %float_1_5 %float_2_5
+  // CHECK-NEXT:    {{%\d+}} = OpConvertFToS %int [[b_float]]
+  int b = cond ? 1.5 : 2.5;
+
+  // CHECK:      [[a_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
+  // CHECK-NEXT:       {{%\d+}} = OpConvertSToF %float [[a_int]]
+  float a = cond ? 1 : 0;
+
+  // CHECK:      [[c_long:%\d+]] = OpSelect %long {{%\d+}} %long_3000000000 %long_4000000000
+  double c = cond ? 3000000000 : 4000000000;
+
+  // CHECK:      [[d_int:%\d+]] = OpSelect %uint {{%\d+}} %uint_1 %uint_0
+  uint d = cond ? 1 : 0;
+
+  float2x3 e;
+  float2x3 f;
+  // CHECK:     [[cond:%\d+]] = OpLoad %bool %cond
+  // CHECK-NEXT:   [[e:%\d+]] = OpLoad %mat2v3float %e
+  // CHECK-NEXT:   [[f:%\d+]] = OpLoad %mat2v3float %f
+  // CHECK-NEXT:                OpSelectionMerge %if_merge None
+  // CHECK-NEXT:                OpBranchConditional [[cond]] %if_true %if_false
+  // CHECK-NEXT:     %if_true = OpLabel
+  // CHECK-NEXT:                OpStore %temp_var_ternary [[e]]
+  // CHECK-NEXT:                OpBranch %if_merge
+  // CHECK-NEXT:    %if_false = OpLabel
+  // CHECK-NEXT:                OpStore %temp_var_ternary [[f]]
+  // CHECK-NEXT:                OpBranch %if_merge
+  // CHECK-NEXT:    %if_merge = OpLabel
+  // CHECK-NEXT:[[temp:%\d+]] = OpLoad %mat2v3float %temp_var_ternary
+  // CHECK-NEXT:                OpStore %g [[temp]]
+  float2x3 g = cond ? e : f;
+
+  // CHECK:      [[inner:%\d+]] = OpSelect %uint {{%\d+}} %uint_1 %uint_2
+  // CHECK-NEXT:       {{%\d+}} = OpSelect %uint {{%\d+}} %uint_9 [[inner]]
+  uint h = cond ? 9 : (cond ? 1 : 2);
+
+  //CHECK:      [[i_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
+  //CHECK-NEXT:       {{%\d+}} = OpINotEqual %bool [[i_int]] %int_0
+  bool i = cond ? 1 : 0;
+
+  // CHECK:     [[foo:%\d+]] = OpFunctionCall %uint %foo
+  // CHECKNEXT:     {{%\d+}} = OpSelect %uint {{%\d+}} %uint_3 [[foo]]
+  uint j = cond ? 3 : foo();
+
+  // CHECK:          [[bar:%\d+]] = OpFunctionCall %float %bar
+  // CHECK-NEXT: [[k_float:%\d+]] = OpSelect %float {{%\d+}} %float_4 [[bar]]
+  // CHECK-NEXT:         {{%\d+}} = OpConvertFToU %uint [[k_float]]
+  uint k = cond ? 4 : bar();
+
+  // AST looks like:
+  // |-ConditionalOperator 'SamplerState'
+  // | |-DeclRefExpr 'bool' lvalue Var 0x1476949e328 'cond' 'bool'
+  // | |-DeclRefExpr 'SamplerState' lvalue Var 0x1476742e498 'gSS1' 'SamplerState'
+  // | `-DeclRefExpr 'SamplerState' lvalue Var 0x1476742e570 'gSS2' 'SamplerState'
+
+  // CHECK:      [[cond:%\d+]] = OpLoad %bool %cond
+  // CHECK-NEXT: [[gSS1:%\d+]] = OpLoad %type_sampler %gSS1
+  // CHECK-NEXT: [[gSS2:%\d+]] = OpLoad %type_sampler %gSS2
+  // CHECK-NEXT:                 OpSelectionMerge %if_merge_0 None
+  // CHECK-NEXT:                 OpBranchConditional [[cond]] %if_true_0 %if_false_0
+  // CHECK-NEXT:    %if_true_0 = OpLabel
+  // CHECK-NEXT:                 OpStore %temp_var_ternary_0 [[gSS1]]
+  // CHECK-NEXT:                 OpBranch %if_merge_0
+  // CHECK-NEXT:   %if_false_0 = OpLabel
+  // CHECK-NEXT:                 OpStore %temp_var_ternary_0 [[gSS2]]
+  // CHECK-NEXT:                 OpBranch %if_merge_0
+  // CHECK-NEXT:   %if_merge_0 = OpLabel
+  // CHECK-NEXT:   [[ss:%\d+]] = OpLoad %type_sampler %temp_var_ternary_0
+  // CHECK-NEXT:      {{%\d+}} = OpSampledImage %type_sampled_image {{%\d+}} [[ss]]
+  float4 l = gTex.Sample(cond ? gSS1 : gSS2, float2(1., 2.));
+
+  zoo();
 }
+
+//
+// The literal integer type should be deduced from the function return type.
+//
+// CHECK: OpSelect %uint %174 %uint_1 %uint_2
+uint zoo() {
+  bool cond;
+  return cond ? 1 : 2;
+}
+